@@ -603,54 +603,9 @@ def __init__(self, symbol: pybamm.Symbol):
603603 static_argnums = self ._static_argnums ,
604604 )
605605
606- def _demote_constants (self ):
607- """Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)"""
608- if not pybamm .demote_expressions_to_32bit :
609- return # pragma: no cover
610- self ._constants = EvaluatorJax ._demote_64_to_32 (self ._constants )
611-
612- @classmethod
613- def _demote_64_to_32 (cls , c ):
614- """Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)"""
615-
616- if not pybamm .demote_expressions_to_32bit :
617- return c
618- if isinstance (c , float ):
619- c = jax .numpy .float32 (c )
620- if isinstance (c , int ):
621- c = jax .numpy .int32 (c )
622- if isinstance (c , np .int64 ):
623- c = c .astype (jax .numpy .int32 )
624- if isinstance (c , np .ndarray ):
625- if c .dtype == np .float64 :
626- c = c .astype (jax .numpy .float32 )
627- if c .dtype == np .int64 :
628- c = c .astype (jax .numpy .int32 )
629- if isinstance (c , jax .numpy .ndarray ):
630- if c .dtype == jax .numpy .float64 :
631- c = c .astype (jax .numpy .float32 )
632- if c .dtype == jax .numpy .int64 :
633- c = c .astype (jax .numpy .int32 )
634- if isinstance (
635- c , pybamm .expression_tree .operations .evaluate_python .JaxCooMatrix
636- ):
637- if c .data .dtype == np .float64 :
638- c .data = c .data .astype (jax .numpy .float32 )
639- if c .row .dtype == np .int64 :
640- c .row = c .row .astype (jax .numpy .int32 )
641- if c .col .dtype == np .int64 :
642- c .col = c .col .astype (jax .numpy .int32 )
643- if isinstance (c , dict ):
644- c = {key : EvaluatorJax ._demote_64_to_32 (value ) for key , value in c .items ()}
645- if isinstance (c , tuple ):
646- c = tuple (EvaluatorJax ._demote_64_to_32 (value ) for value in c )
647- if isinstance (c , list ):
648- c = [EvaluatorJax ._demote_64_to_32 (value ) for value in c ]
649- return c
650-
651606 @property
652607 def _constants (self ):
653- return tuple ( map ( EvaluatorJax . _demote_64_to_32 , self .__constants ))
608+ return self .__constants
654609
655610 @_constants .setter
656611 def _constants (self , value ):
0 commit comments