@@ -114,6 +114,10 @@ def _is_none(x: Any) -> bool:
114114 return x is None
115115
116116
117+ class TermAndSolverIncompatible (ValueError ):
118+ pass
119+
120+
117121def _assert_term_compatible (
118122 t : FloatScalarLike ,
119123 y : PyTree [ArrayLike ],
@@ -137,7 +141,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
137141 ):
138142 _assert_term_compatible (t , yi , args , term , arg , term_contr_kwarg )
139143 else :
140- raise ValueError (
144+ raise TermAndSolverIncompatible (
141145 f"Term { term } is not a MultiTerm but is expected to be."
142146 )
143147 else :
@@ -147,7 +151,9 @@ def _check(term_cls, term, term_contr_kwargs, yi):
147151 if origin_cls is None :
148152 origin_cls = term_cls
149153 if not isinstance (term , origin_cls ):
150- raise ValueError (f"Term { term } is not an instance of { origin_cls } ." )
154+ raise TermAndSolverIncompatible (
155+ f"Term { term } is not an instance of { origin_cls } ."
156+ )
151157
152158 # Now check the generic parametrization of `term_cls`; can be one of:
153159 # -----------------------------------------
@@ -167,7 +173,9 @@ def _check(term_cls, term, term_contr_kwargs, yi):
167173 better_isinstance , vf_type , vf_type_expected
168174 )
169175 if not vf_type_compatible :
170- raise ValueError (f"Vector field term { term } is incompatible." )
176+ raise TermAndSolverIncompatible (
177+ f"Vector field term { term } is incompatible."
178+ )
171179
172180 contr = ft .partial (term .contr , ** term_contr_kwargs )
173181 # Work around https://github.com/google/jax/issues/21825
@@ -176,7 +184,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
176184 better_isinstance , control_type , control_type_expected
177185 )
178186 if not control_type_compatible :
179- raise ValueError (
187+ raise TermAndSolverIncompatible (
180188 "Control term is incompatible: the returned control (e.g. "
181189 f"Brownian motion for an SDE) was { control_type } , but this "
182190 f"solver expected { control_type_expected } ."
@@ -188,11 +196,10 @@ def _check(term_cls, term, term_contr_kwargs, yi):
188196 try :
189197 with jax .numpy_dtype_promotion ("standard" ):
190198 jtu .tree_map (_check , term_structure , terms , contr_kwargs , y )
191- except ValueError as e :
192- # ValueError may also arise from mismatched tree structures
199+ except TermAndSolverIncompatible as e :
193200 pretty_term = wl .pformat (terms )
194201 pretty_expected = wl .pformat (term_structure )
195- raise ValueError (
202+ raise TermAndSolverIncompatible (
196203 f"Terms are not compatible with solver! Got:\n { pretty_term } \n but expected:"
197204 f"\n { pretty_expected } \n Note that terms are checked recursively: if you "
198205 "scroll up you may find a root-cause error that is more specific."
0 commit comments