Skip to content

Commit 085b573

Browse files
authored
Merge branch 'develop' into develop
2 parents 575c006 + 509def8 commit 085b573

File tree

7 files changed

+6
-38
lines changed

7 files changed

+6
-38
lines changed

docs/source/api/util.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,4 @@ Utility functions
1616

1717
.. autofunction:: pybamm.has_jax
1818

19-
.. autofunction:: pybamm.is_jax_compatible
20-
2119
.. autofunction:: pybamm.set_logging_level

pyproject.toml

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ plot = [
7979
"matplotlib>=3.6.0",
8080
]
8181
cite = [
82-
"setuptools", # Fix for a pybtex issue
83-
"pybtex>=0.24.0",
82+
"pybtex>=0.25.0",
8483
]
8584
# Battery Parameter eXchange format
8685
bpx = [
@@ -112,13 +111,9 @@ dev = [
112111
"importlib-metadata; python_version < '3.10'",
113112
# For property based testing
114113
"hypothesis",
115-
116114
]
117-
# For the Jax solver.
118-
# Note: These must be kept in sync with the versions defined in pybamm/util.py
119115
jax = [
120-
"jax==0.4.27",
121-
"jaxlib==0.4.27",
116+
"jax>=0.4.36,<0.6.0",
122117
]
123118
# Contains all optional dependencies, except for jax, and dev dependencies
124119
all = [

src/pybamm/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
get_parameters_filepath,
1313
has_jax,
1414
import_optional_dependency,
15-
is_jax_compatible,
1615
)
1716
from .logger import logger, set_logging_level, get_new_logger
1817
from .settings import settings

src/pybamm/solvers/idaklu_jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ def _jaxify(
600600
self._register_callbacks() # Register python methods as callbacks in IDAKLU-JAX
601601

602602
for _name, _value in idaklu.registrations().items():
603+
# todo: This has been removed from jax v0.6.0
603604
xla_client.register_custom_call_target(
604605
f"{_name}_{self._unique_name()}", _value, platform="cpu"
605606
)

src/pybamm/solvers/jax_bdf_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ def merge(l1, l2):
910910
return out, merge
911911

912912
def abstractify(x):
913-
return core.raise_to_shaped(core.get_aval(x))
913+
return core.get_aval(x)
914914

915915
def ravel_first_arg(f, unravel):
916916
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped

src/pybamm/util.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@
1212

1313
import pybamm
1414

15-
# Versions of jax and jaxlib compatible with PyBaMM. Note: these are also defined in
16-
# the extras dependencies in pyproject.toml, and therefore must be kept in sync.
17-
JAX_VERSION = "0.4.27"
18-
JAXLIB_VERSION = "0.4.27"
19-
2015

2116
def root_dir():
2217
"""return the root directory of the PyBaMM install directory"""
@@ -354,27 +349,11 @@ def has_jax():
354349
True if jax and jaxlib are installed with the correct versions, False if otherwise
355350
356351
"""
357-
return (
358-
(importlib.util.find_spec("jax") is not None)
359-
and (importlib.util.find_spec("jaxlib") is not None)
360-
and is_jax_compatible()
352+
return (importlib.util.find_spec("jax") is not None) and (
353+
importlib.util.find_spec("jaxlib") is not None
361354
)
362355

363356

364-
def is_jax_compatible():
365-
"""
366-
Check if the available versions of jax and jaxlib are compatible with PyBaMM
367-
368-
Returns
369-
-------
370-
bool
371-
True if jax and jaxlib are compatible with PyBaMM, False if otherwise
372-
"""
373-
return importlib.metadata.distribution("jax").version.startswith(
374-
JAX_VERSION
375-
) and importlib.metadata.distribution("jaxlib").version.startswith(JAXLIB_VERSION)
376-
377-
378357
def is_constant_and_can_evaluate(symbol):
379358
"""
380359
Returns True if symbol is constant and evaluation does not raise any errors.

tests/unit/test_util.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,6 @@ def test_get_parameters_filepath(self, tmp_path):
9090
os.path.join(pybamm.root_dir(), "src", "pybamm", temppath)
9191
)
9292

93-
@pytest.mark.skipif(not pybamm.has_jax(), reason="JAX is not installed")
94-
def test_is_jax_compatible(self):
95-
assert pybamm.is_jax_compatible()
96-
9793
def test_import_optional_dependency(self):
9894
optional_distribution_deps = get_optional_distribution_deps("pybamm")
9995
present_optional_import_deps = get_present_optional_import_deps(

0 commit comments

Comments
 (0)