Skip to content

Commit d237f21

Browse files
authored
Merge branch 'develop' into develop
2 parents 9837ff4 + 09c6cc0 commit d237f21

File tree

7 files changed

+136
-602
lines changed

7 files changed

+136
-602
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
## Breaking changes
1515

16+
- Removed the IREE code from the IDAKLU solver ([#5080](https://github.com/pybamm-team/PyBaMM/pull/5080))
1617
- Removed support for Python 3.9 ([#5052](https://github.com/pybamm-team/PyBaMM/pull/5052))
1718

1819
# [v25.6.0](https://github.com/pybamm-team/PyBaMM/tree/v25.6.0) - 2025-05-27

src/pybamm/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from pybamm.version import __version__
22

3-
# Demote expressions to 32-bit floats/ints - option used for IDAKLU-MLIR compilation
4-
demote_expressions_to_32bit = False
5-
63
# Utility classes and methods
74
from .util import root_dir
85
from .util import Timer, TimerTime, FuzzyDict
@@ -174,7 +171,7 @@
174171
from .solvers.jax_bdf_solver import jax_bdf_integrate
175172

176173
from .solvers.idaklu_jax import IDAKLUJax
177-
from .solvers.idaklu_solver import IDAKLUSolver, has_iree
174+
from .solvers.idaklu_solver import IDAKLUSolver
178175

179176
# Experiments
180177
from .experiment.experiment import Experiment

src/pybamm/expression_tree/operations/evaluate_python.py

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -603,54 +603,9 @@ def __init__(self, symbol: pybamm.Symbol):
603603
static_argnums=self._static_argnums,
604604
)
605605

606-
def _demote_constants(self):
607-
"""Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)"""
608-
if not pybamm.demote_expressions_to_32bit:
609-
return # pragma: no cover
610-
self._constants = EvaluatorJax._demote_64_to_32(self._constants)
611-
612-
@classmethod
613-
def _demote_64_to_32(cls, c):
614-
"""Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)"""
615-
616-
if not pybamm.demote_expressions_to_32bit:
617-
return c
618-
if isinstance(c, float):
619-
c = jax.numpy.float32(c)
620-
if isinstance(c, int):
621-
c = jax.numpy.int32(c)
622-
if isinstance(c, np.int64):
623-
c = c.astype(jax.numpy.int32)
624-
if isinstance(c, np.ndarray):
625-
if c.dtype == np.float64:
626-
c = c.astype(jax.numpy.float32)
627-
if c.dtype == np.int64:
628-
c = c.astype(jax.numpy.int32)
629-
if isinstance(c, jax.numpy.ndarray):
630-
if c.dtype == jax.numpy.float64:
631-
c = c.astype(jax.numpy.float32)
632-
if c.dtype == jax.numpy.int64:
633-
c = c.astype(jax.numpy.int32)
634-
if isinstance(
635-
c, pybamm.expression_tree.operations.evaluate_python.JaxCooMatrix
636-
):
637-
if c.data.dtype == np.float64:
638-
c.data = c.data.astype(jax.numpy.float32)
639-
if c.row.dtype == np.int64:
640-
c.row = c.row.astype(jax.numpy.int32)
641-
if c.col.dtype == np.int64:
642-
c.col = c.col.astype(jax.numpy.int32)
643-
if isinstance(c, dict):
644-
c = {key: EvaluatorJax._demote_64_to_32(value) for key, value in c.items()}
645-
if isinstance(c, tuple):
646-
c = tuple(EvaluatorJax._demote_64_to_32(value) for value in c)
647-
if isinstance(c, list):
648-
c = [EvaluatorJax._demote_64_to_32(value) for value in c]
649-
return c
650-
651606
@property
652607
def _constants(self):
653-
return tuple(map(EvaluatorJax._demote_64_to_32, self.__constants))
608+
return self.__constants
654609

655610
@_constants.setter
656611
def _constants(self, value):

0 commit comments

Comments
 (0)