Skip to content

Commit 09c6cc0

Browse files
authored
Remove IREE helper functions (pybamm-team#5082)
* Remove most of IREE from the solvers * Cleanup * Remove most other jax and IREE code blocks * Update changelog * Remove some helper functions
1 parent 7155b62 commit 09c6cc0

File tree

5 files changed

+1
-123
lines changed

5 files changed

+1
-123
lines changed

src/pybamm/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from pybamm.version import __version__
22

3-
# Demote expressions to 32-bit floats/ints - option used for IDAKLU-MLIR compilation
4-
demote_expressions_to_32bit = False
5-
63
# Utility classes and methods
74
from .util import root_dir
85
from .util import Timer, TimerTime, FuzzyDict

src/pybamm/expression_tree/operations/evaluate_python.py

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -603,54 +603,9 @@ def __init__(self, symbol: pybamm.Symbol):
603603
static_argnums=self._static_argnums,
604604
)
605605

606-
def _demote_constants(self):
607-
"""Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)"""
608-
if not pybamm.demote_expressions_to_32bit:
609-
return # pragma: no cover
610-
self._constants = EvaluatorJax._demote_64_to_32(self._constants)
611-
612-
@classmethod
613-
def _demote_64_to_32(cls, c):
614-
"""Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)"""
615-
616-
if not pybamm.demote_expressions_to_32bit:
617-
return c
618-
if isinstance(c, float):
619-
c = jax.numpy.float32(c)
620-
if isinstance(c, int):
621-
c = jax.numpy.int32(c)
622-
if isinstance(c, np.int64):
623-
c = c.astype(jax.numpy.int32)
624-
if isinstance(c, np.ndarray):
625-
if c.dtype == np.float64:
626-
c = c.astype(jax.numpy.float32)
627-
if c.dtype == np.int64:
628-
c = c.astype(jax.numpy.int32)
629-
if isinstance(c, jax.numpy.ndarray):
630-
if c.dtype == jax.numpy.float64:
631-
c = c.astype(jax.numpy.float32)
632-
if c.dtype == jax.numpy.int64:
633-
c = c.astype(jax.numpy.int32)
634-
if isinstance(
635-
c, pybamm.expression_tree.operations.evaluate_python.JaxCooMatrix
636-
):
637-
if c.data.dtype == np.float64:
638-
c.data = c.data.astype(jax.numpy.float32)
639-
if c.row.dtype == np.int64:
640-
c.row = c.row.astype(jax.numpy.int32)
641-
if c.col.dtype == np.int64:
642-
c.col = c.col.astype(jax.numpy.int32)
643-
if isinstance(c, dict):
644-
c = {key: EvaluatorJax._demote_64_to_32(value) for key, value in c.items()}
645-
if isinstance(c, tuple):
646-
c = tuple(EvaluatorJax._demote_64_to_32(value) for value in c)
647-
if isinstance(c, list):
648-
c = [EvaluatorJax._demote_64_to_32(value) for value in c]
649-
return c
650-
651606
@property
652607
def _constants(self):
653-
return tuple(map(EvaluatorJax._demote_64_to_32, self.__constants))
608+
return self.__constants
654609

655610
@_constants.setter
656611
def _constants(self, value):

src/pybamm/solvers/idaklu_solver.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -526,13 +526,6 @@ def __setstate__(self, d):
526526
options=self._options,
527527
)
528528

529-
def _check_mlir_conversion(self, name, mlir: str):
530-
if mlir.count("f64") > 0: # pragma: no cover
531-
warnings.warn(f"f64 found in {name} (x{mlir.count('f64')})", stacklevel=2)
532-
533-
def _demote_64_to_32(self, x: pybamm.EvaluatorJax):
534-
return pybamm.EvaluatorJax._demote_64_to_32(x)
535-
536529
@property
537530
def supports_parallel_solve(self):
538531
return True

src/pybamm/solvers/processed_variable_computed.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,6 @@ def _unroll_nnz(self, realdata=None):
140140
nnz = sp.nnz()
141141
numel = sp.numel()
142142
row = sp.row()
143-
elif "nnz" in dir(self.base_variables_casadi[0]): # IREE fcn
144-
sp = self.base_variables_casadi[0]
145-
nnz = sp.nnz
146-
numel = sp.numel
147-
row = sp.row
148143
if nnz != numel:
149144
data = [None] * len(realdata)
150145
for datak in range(len(realdata)):

tests/unit/test_expression_tree/test_operations/test_evaluate_python.py

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -672,68 +672,6 @@ def test_evaluator_jax_inputs(self):
672672
result = evaluator(inputs={"a": 2})
673673
assert result == 4
674674

675-
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
676-
def test_evaluator_jax_demotion(self):
677-
for demote in [True, False]:
678-
pybamm.demote_expressions_to_32bit = demote # global flag
679-
target_dtype = "32" if demote else "64"
680-
if demote:
681-
# Test only works after conversion to jax.numpy
682-
for c in [
683-
1.0,
684-
1,
685-
]:
686-
assert (
687-
str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:]
688-
== target_dtype
689-
)
690-
for c in [
691-
np.float64(1.0),
692-
np.int64(1),
693-
np.array([1.0], dtype=np.float64),
694-
np.array([1], dtype=np.int64),
695-
jax.numpy.array([1.0], dtype=np.float64),
696-
jax.numpy.array([1], dtype=np.int64),
697-
]:
698-
assert (
699-
str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:]
700-
== target_dtype
701-
)
702-
for c in [
703-
{key: np.float64(1.0) for key in ["a", "b"]},
704-
]:
705-
expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c)
706-
assert all(
707-
str(c_v.dtype)[-2:] == target_dtype
708-
for c_k, c_v in expr_demoted.items()
709-
)
710-
for c in [
711-
(np.float64(1.0), np.float64(2.0)),
712-
[np.float64(1.0), np.float64(2.0)],
713-
]:
714-
expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c)
715-
assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in expr_demoted)
716-
for dtype in [
717-
np.float64,
718-
jax.numpy.float64,
719-
]:
720-
c = pybamm.JaxCooMatrix([0, 1], [0, 1], dtype([1.0, 2.0]), (2, 2))
721-
c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c)
722-
assert all(
723-
str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.data
724-
)
725-
for dtype in [
726-
np.int64,
727-
jax.numpy.int64,
728-
]:
729-
c = pybamm.JaxCooMatrix(
730-
dtype([0, 1]), dtype([0, 1]), [1.0, 2.0], (2, 2)
731-
)
732-
c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c)
733-
assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.row)
734-
assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.col)
735-
pybamm.demote_expressions_to_32bit = False
736-
737675
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
738676
def test_jax_coo_matrix(self):
739677
A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2))

0 commit comments

Comments
 (0)