Skip to content

Commit 1c1dd75

Browse files
pipligginspre-commit-ci[bot]martinjrobins
authored
Create a base processed variable class (pybamm-team#5076)
* Creates an abstract base class for variables and can convert a processed to computed variable. * sort imports * Propegate sensitivities in PVC * style: pre-commit fixes * working version with some tests * Use entries instead, add integration test * style: pre-commit fixes * move stub solution * skip testing line & add note * Edit changelog * make _update public --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Martin Robinson <[email protected]>
1 parent 1ce4ef2 commit 1c1dd75

File tree

9 files changed

+287
-25
lines changed

9 files changed

+287
-25
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Features
44

5+
- Creates `BaseProcessedVariable` to enable object combination when adding solutions together ([#5076](https://github.com/pybamm-team/PyBaMM/pull/5076))
56
- Added a `Constant` symbol for named constants. This is a subclass of `Scalar` and is used to represent named constants such as the gas constant. This avoids constants being simplified out when constructing expressions. ([#5070](https://github.com/pybamm-team/PyBaMM/pull/5070))
67
- Generalise `pybamm.DiscreteTimeSum` to allow it to be embedded in other expressions ([#5044](https://github.com/pybamm-team/PyBaMM/pull/5044))
78
- Adds `all` key-value pair to `output_variables` sensitivity dictionaries, accessible through `solution[var].sensitivities['all']`. Aligns shape with conventional solution sensitivities object. ([#5067](https://github.com/pybamm-team/PyBaMM/pull/5067))
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from __future__ import annotations
2+
3+
import abc
4+
5+
import pybamm
6+
7+
8+
class BaseProcessedVariable(abc.ABC):
9+
"""
10+
Shared API for both 'on-demand' and 'computed' processed variables.
11+
"""
12+
13+
@abc.abstractmethod
14+
def __call__(self, **coords): ...
15+
16+
@abc.abstractmethod
17+
def as_computed(self) -> pybamm.ProcessedVariableComputed: ...
18+
19+
def update(
20+
self,
21+
other: BaseProcessedVariable,
22+
new_sol: pybamm.Solution,
23+
) -> pybamm.ProcessedVariableComputed:
24+
"""
25+
Return a *new* CPV that is `self` followed by `other`.
26+
Works no matter which concrete types we start with.
27+
"""
28+
return self.as_computed()._concat(other.as_computed(), new_sol)

src/pybamm/solvers/processed_variable.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
import pybamm
99

10+
from .base_processed_variable import BaseProcessedVariable
1011

11-
class ProcessedVariable:
12+
13+
class ProcessedVariable(BaseProcessedVariable):
1214
"""
1315
An object that can be evaluated at arbitrary (scalars or vectors) t and x, and
1416
returns the (interpolated) value of the base variable at that t and x.
@@ -494,6 +496,58 @@ def _is_discrete_time_method(self):
494496
def hermite_interpolation(self):
495497
return self.all_yps is not None
496498

499+
def as_computed(self):
500+
"""
501+
Allows a ProcessedVariable to be converted to a ComputedProcessedVariable for
502+
use together, e.g. when using a last state solution with a new simulation running
503+
with output variables in the solver.
504+
"""
505+
506+
def _stub_solution(self):
507+
"""
508+
Return a lightweight object that looks like the parts of
509+
`pybamm.Solution` required by ProcessedVariableComputed, but without
510+
keeping the full state vector in memory.
511+
"""
512+
513+
class StubSolution:
514+
def __init__(self, ts, ys, inputs, inputs_casadi, sensitivities, t_pts):
515+
self.all_ts = ts
516+
self.all_ys = ys
517+
self.all_inputs = inputs
518+
self.all_inputs_casadi = inputs_casadi
519+
self.sensitivities = sensitivities
520+
self.t = t_pts
521+
522+
return StubSolution(
523+
self.all_ts,
524+
self.all_ys,
525+
self.all_inputs,
526+
self.all_inputs_casadi,
527+
self.sensitivities,
528+
self.t_pts,
529+
)
530+
531+
entries = self.entries # shape: (..., n_t)
532+
533+
# Move time to axis 0, then flatten spatial dims per timestep
534+
reshaped = np.moveaxis(entries, -1, 0) # shape: (n_t, ...)
535+
base_data = [reshaped.reshape(reshaped.shape[0], -1)] # (n_t, n_vars)
536+
537+
cpv = pybamm.ProcessedVariableComputed(
538+
self.base_variables,
539+
self.base_variables_casadi,
540+
base_data,
541+
_stub_solution(self),
542+
)
543+
544+
# add sensitivities if they exist
545+
if self.sensitivities:
546+
# TODO: test once #5058 is fixed
547+
cpv._sensitivities = self.sensitivities # pragma: no cover
548+
549+
return cpv
550+
497551

498552
class ProcessedVariable0D(ProcessedVariable):
499553
def __init__(

src/pybamm/solvers/processed_variable_computed.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
import pybamm
1212

13+
from .base_processed_variable import BaseProcessedVariable
1314

14-
class ProcessedVariableComputed:
15+
16+
class ProcessedVariableComputed(BaseProcessedVariable):
1517
"""
1618
An object that can be evaluated at arbitrary (scalars or vectors) t and x, and
1719
returns the (interpolated) value of the base variable at that t and x.
@@ -133,6 +135,9 @@ def __init__(
133135

134136
raise NotImplementedError(f"Shape not recognized for {base_variables[0]}")
135137

138+
def as_computed(self) -> ProcessedVariableComputed:
139+
return self
140+
136141
def add_sensitivity(self, param, data):
137142
# unroll from sparse representation into n-d matrix
138143
# then flatten for consistency with full state-vector
@@ -694,7 +699,7 @@ def sensitivities(self):
694699
return {}
695700
return self._sensitivities
696701

697-
def _update(
702+
def _concat(
698703
self, other: pybamm.ProcessedVariableComputed, new_sol: pybamm.Solution
699704
) -> pybamm.ProcessedVariableComputed:
700705
"""
@@ -716,4 +721,9 @@ def _update(
716721

717722
new_var = self.__class__(bv, bvc, bvd, new_sol)
718723

724+
new_var._sensitivities = {
725+
k: np.concatenate((self.sensitivities[k], other.sensitivities[k]))
726+
for k in self.sensitivities.keys()
727+
}
728+
719729
return new_var

src/pybamm/solvers/solution.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -801,17 +801,10 @@ def __add__(self, other):
801801
new_sol._sub_solutions = self.sub_solutions + other.sub_solutions
802802

803803
# update variables which were derived at the solver stage
804-
if other._variables and all(
805-
isinstance(v, pybamm.ProcessedVariableComputed)
806-
for v in other._variables.values()
807-
):
808-
if not self._variables:
809-
new_sol._variables = other._variables.copy()
810-
else:
811-
new_sol._variables = {
812-
v: self._variables[v]._update(other._variables[v], new_sol)
813-
for v in self._variables.keys()
814-
}
804+
if any([self.variables_returned, other.variables_returned]):
805+
vars = {*self._variables.keys(), *other._variables.keys()}
806+
new_sol._variables = {v: self[v].update(other[v], new_sol) for v in vars}
807+
new_sol.variables_returned = True
815808

816809
return new_sol
817810

tests/integration/test_solvers/test_idaklu.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,36 @@ def test_with_experiments(self):
209209
sols[0].cycles[-1]["Current [A]"].data,
210210
sols[1].cycles[-1]["Current [A]"].data,
211211
)
212+
213+
def test_outvars_with_experiments_multi_simulation(self):
214+
model = pybamm.lithium_ion.SPM()
215+
216+
experiment = pybamm.Experiment(
217+
[
218+
"Discharge at C/2 for 10 minutes",
219+
"Rest for 10 minutes",
220+
]
221+
)
222+
223+
solver = pybamm.IDAKLUSolver(
224+
output_variables=[
225+
"Discharge capacity [A.h]", # 0D variables
226+
"Time [s]",
227+
"Current [A]",
228+
"Voltage [V]",
229+
"Pressure [Pa]", # 1D variable
230+
"Positive particle effective diffusivity [m2.s-1]", # 2D variable
231+
]
232+
)
233+
234+
sim = pybamm.Simulation(model, experiment=experiment, solver=solver)
235+
solution = sim.solve()
236+
237+
sim2 = pybamm.Simulation(model, experiment=experiment, solver=solver)
238+
239+
new_sol1 = sim2.solve(starting_solution=solution.copy())
240+
new_sol2 = sim2.solve(starting_solution=solution.copy().last_state)
241+
242+
np.testing.assert_array_equal(
243+
new_sol1["Voltage [V]"].entries[63:], new_sol2["Voltage [V]"].entries
244+
)

tests/unit/test_solvers/test_processed_variable.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,3 +1486,155 @@ def spl(t):
14861486

14871487
# Check that the unsorted and sorted arrays are the same
14881488
assert np.all(y_unsorted == y_sorted[idxs_unsort])
1489+
1490+
def test_as_computed_0D(self):
1491+
# 0D
1492+
t = pybamm.t
1493+
y = pybamm.StateVector(slice(0, 1))
1494+
var = t * y
1495+
model = pybamm.BaseModel()
1496+
t_sol = np.linspace(0, 1)
1497+
y_sol = np.array([np.linspace(0, 5)])
1498+
yp_sol = self._get_yps(y_sol, False)
1499+
var_casadi = to_casadi(var, y_sol)
1500+
processed_var = pybamm.process_variable(
1501+
"test",
1502+
[var],
1503+
[var_casadi],
1504+
self._sol_default(t_sol, y_sol, yp_sol, model),
1505+
)
1506+
computed_var = processed_var.as_computed()
1507+
1508+
np.testing.assert_array_equal(computed_var.entries, t_sol * y_sol[0])
1509+
1510+
def test_as_computed_1D(self):
1511+
var = pybamm.Variable("var", domain=["negative electrode", "separator"])
1512+
x = pybamm.SpatialVariable("x", domain=["negative electrode", "separator"])
1513+
1514+
disc = tests.get_discretisation_for_testing()
1515+
disc.set_variable_slices([var])
1516+
x_sol = disc.process_symbol(x).entries[:, 0]
1517+
var_sol = disc.process_symbol(var)
1518+
t_sol = np.linspace(0, 1)
1519+
y_sol = x_sol[:, np.newaxis] * (5 * t_sol)
1520+
yp_sol = self._get_yps(y_sol, False, values=5)
1521+
1522+
var_casadi = to_casadi(var_sol, y_sol)
1523+
processed_var = pybamm.process_variable(
1524+
"test",
1525+
[var_sol],
1526+
[var_casadi],
1527+
self._sol_default(t_sol, y_sol, yp_sol),
1528+
)
1529+
1530+
computed_var = processed_var.as_computed()
1531+
1532+
# 2 vectors
1533+
np.testing.assert_allclose(
1534+
computed_var(t_sol, x_sol), y_sol, rtol=1e-7, atol=1e-6
1535+
)
1536+
# 1 vector, 1 scalar
1537+
np.testing.assert_allclose(
1538+
computed_var(0.5, x_sol), 2.5 * x_sol, rtol=1e-7, atol=1e-6
1539+
)
1540+
np.testing.assert_allclose(
1541+
computed_var(t_sol, x_sol[-1]),
1542+
x_sol[-1] * np.linspace(0, 5),
1543+
rtol=1e-7,
1544+
atol=1e-6,
1545+
)
1546+
# 2 scalars
1547+
np.testing.assert_allclose(
1548+
computed_var(0.5, x_sol[-1]), 2.5 * x_sol[-1], rtol=1e-7, atol=1e-6
1549+
)
1550+
1551+
def test_as_computed_2D(self):
1552+
var = pybamm.Variable(
1553+
"var",
1554+
domain=["negative particle"],
1555+
auxiliary_domains={"secondary": ["negative electrode"]},
1556+
)
1557+
x = pybamm.SpatialVariable("x", domain=["negative electrode"])
1558+
r = pybamm.SpatialVariable(
1559+
"r",
1560+
domain=["negative particle"],
1561+
auxiliary_domains={"secondary": ["negative electrode"]},
1562+
)
1563+
1564+
disc = tests.get_p2d_discretisation_for_testing()
1565+
disc.set_variable_slices([var])
1566+
x_sol = disc.process_symbol(x).entries[:, 0]
1567+
r_sol = disc.process_symbol(r).entries[:, 0]
1568+
# Keep only the first iteration of entries
1569+
r_sol = r_sol[: len(r_sol) // len(x_sol)]
1570+
var_sol = disc.process_symbol(var)
1571+
t_sol = np.linspace(0, 1)
1572+
y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] * np.linspace(0, 5)
1573+
yp_sol = self._get_yps(y_sol, False)
1574+
1575+
var_casadi = to_casadi(var_sol, y_sol)
1576+
processed_var = pybamm.process_variable(
1577+
"test",
1578+
[var_sol],
1579+
[var_casadi],
1580+
self._sol_default(t_sol, y_sol, yp_sol),
1581+
)
1582+
1583+
computed_var = processed_var.as_computed()
1584+
# 3 vectors
1585+
np.testing.assert_array_equal(
1586+
computed_var(t_sol, x_sol, r_sol).shape, (10, 40, 50)
1587+
)
1588+
np.testing.assert_allclose(
1589+
computed_var(t_sol, x_sol, r_sol),
1590+
np.reshape(y_sol, [len(r_sol), len(x_sol), len(t_sol)]),
1591+
rtol=1e-7,
1592+
atol=1e-6,
1593+
)
1594+
# 2 vectors, 1 scalar
1595+
np.testing.assert_array_equal(computed_var(0.5, x_sol, r_sol).shape, (10, 40))
1596+
np.testing.assert_array_equal(computed_var(t_sol, 0.2, r_sol).shape, (10, 50))
1597+
np.testing.assert_array_equal(computed_var(t_sol, x_sol, 0.5).shape, (40, 50))
1598+
# 1 vectors, 2 scalar
1599+
np.testing.assert_array_equal(computed_var(0.5, 0.2, r_sol).shape, (10,))
1600+
np.testing.assert_array_equal(computed_var(0.5, x_sol, 0.5).shape, (40,))
1601+
np.testing.assert_array_equal(computed_var(t_sol, 0.2, 0.5).shape, (50,))
1602+
# 3 scalars
1603+
np.testing.assert_array_equal(computed_var(0.2, 0.2, 0.2).shape, ())
1604+
1605+
def test_as_computed_3D(self):
1606+
disc = tests.get_size_distribution_disc_for_testing(xpts=5, rpts=6, Rpts=7)
1607+
var = pybamm.Variable(
1608+
"var",
1609+
domain=["negative particle"],
1610+
auxiliary_domains={
1611+
"secondary": ["negative particle size"],
1612+
"tertiary": ["negative electrode"],
1613+
},
1614+
)
1615+
x_sol = disc.mesh["negative electrode"].nodes
1616+
disc.set_variable_slices([var])
1617+
1618+
Nx = len(x_sol)
1619+
R_sol = disc.mesh["negative particle size"].nodes
1620+
r_sol = disc.mesh["negative particle"].nodes
1621+
var_sol = disc.process_symbol(var)
1622+
t_sol = np.linspace(0, 1)
1623+
y_sol = np.ones(len(x_sol) * len(R_sol) * len(r_sol))[:, np.newaxis] * t_sol
1624+
yp_sol = self._get_yps(y_sol, False)
1625+
1626+
var_casadi = to_casadi(var_sol, y_sol)
1627+
geometry_options = {"options": {"particle size": "distribution"}}
1628+
model = tests.get_base_model_with_battery_geometry(**geometry_options)
1629+
processed_var = pybamm.process_variable(
1630+
"test",
1631+
[var_sol],
1632+
[var_casadi],
1633+
self._sol_default(t_sol, y_sol, yp_sol, model),
1634+
)
1635+
computed_var = processed_var.as_computed()
1636+
1637+
# 4 vectors
1638+
np.testing.assert_array_equal(
1639+
computed_var(t=t_sol, x=x_sol, R=R_sol, r=r_sol).shape, (6, 7, Nx, 50)
1640+
)

tests/unit/test_solvers/test_processed_variable_computed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_processed_variable_0D(self):
107107
)
108108

109109
comb_sol = sol + sol_2
110-
comb_var = processed_var._update(processed_var2, comb_sol)
110+
comb_var = processed_var.update(processed_var2, comb_sol)
111111
np.testing.assert_array_equal(comb_var.entries, np.append(y_sol, y_sol2))
112112

113113
# check empty sensitivity works
@@ -275,7 +275,7 @@ def test_processed_variable_1D_update(self):
275275
)
276276

277277
comb_sol = sol1 + sol2
278-
comb_var = processed_var1._update(var_2, comb_sol)
278+
comb_var = processed_var1.update(var_2, comb_sol)
279279

280280
# Ordering from idaklu with output_variables set is different to
281281
# the full solver

tests/unit/test_solvers/test_solution.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,6 @@ def test_add_solutions_with_computed_variables(self):
195195
# check solution still tagged as 'variables_returned'
196196
assert sol_sum.variables_returned is True
197197

198-
# add a solution with computed variable to an empty solution
199-
empty_sol = pybamm.Solution(
200-
sol1.all_ts, sol1["2u"].base_variables_data, model, {u: 0, v: 1}
201-
)
202-
203-
sol4 = empty_sol + sol2
204-
assert sol4["2u"] == sol2["2u"]
205-
assert sol4.variables_returned is True
206-
207198
def test_copy(self):
208199
# Set up first solution
209200
t1 = [np.linspace(0, 1), np.linspace(1, 2, 5)]

0 commit comments

Comments
 (0)