@@ -119,6 +119,7 @@ def _is_none(x: Any) -> bool:
119119
120120
121121def _assert_term_compatible (
122+ t : FloatScalarLike ,
122123 y : PyTree [ArrayLike ],
123124 args : PyTree [Any ],
124125 terms : PyTree [AbstractTerm ],
@@ -138,7 +139,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
138139 for term , arg , term_contr_kwarg in zip (
139140 term .terms , get_args (_tmp ), term_contr_kwargs
140141 ):
141- _assert_term_compatible (yi , args , term , arg , term_contr_kwarg )
142+ _assert_term_compatible (t , yi , args , term , arg , term_contr_kwarg )
142143 else :
143144 raise ValueError (
144145 f"Term { term } is not a MultiTerm but is expected to be."
@@ -166,7 +167,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
166167 elif n_term_args == 2 :
167168 vf_type_expected , control_type_expected = term_args
168169 try :
169- vf_type = eqx .filter_eval_shape (term .vf , 0.0 , yi , args )
170+ vf_type = eqx .filter_eval_shape (term .vf , t , yi , args )
170171 except Exception as e :
171172 raise ValueError (f"Error while tracing { term } .vf: " + str (e ))
172173 vf_type_compatible = eqx .filter_eval_shape (
@@ -178,7 +179,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
178179 contr = ft .partial (term .contr , ** term_contr_kwargs )
179180 # Work around https://github.com/google/jax/issues/21825
180181 try :
181- control_type = eqx .filter_eval_shape (contr , 0.0 , 0.0 )
182+ control_type = eqx .filter_eval_shape (contr , t , t )
182183 except Exception as e :
183184 raise ValueError (f"Error while tracing { term } .contr: " + str (e ))
184185 control_type_compatible = eqx .filter_eval_shape (
@@ -1077,6 +1078,7 @@ def _promote(yi):
10771078 if isinstance (solver , (EulerHeun , ItoMilstein , StratonovichMilstein )):
10781079 try :
10791080 _assert_term_compatible (
1081+ t0 ,
10801082 y0 ,
10811083 args ,
10821084 terms ,
@@ -1098,6 +1100,7 @@ def _promote(yi):
10981100
10991101 # Error checking for term compatibility
11001102 _assert_term_compatible (
1103+ t0 ,
11011104 y0 ,
11021105 args ,
11031106 terms ,
0 commit comments