44from pathlib import Path
55from types import MappingProxyType
66
7+ import pandas as pd
78from jax import Array
89
910from lcm .ages import AgeGrid
11+ from lcm .exceptions import InvalidParamsError
1012from lcm .grids import DiscreteGrid
1113from lcm .model_processing import (
1214 build_regimes_and_template ,
1315 validate_model_inputs ,
1416)
17+ from lcm .pandas_utils import (
18+ convert_series_in_params ,
19+ has_series ,
20+ initial_conditions_from_dataframe ,
21+ )
22+ from lcm .params import MappingLeaf , SequenceLeaf
1523from lcm .params .processing import (
1624 process_params ,
1725)
@@ -195,6 +203,7 @@ def solve(
195203 internal_params = _maybe_convert_series (
196204 internal_params , model = self , derived_categoricals = derived_categoricals
197205 )
206+ _validate_param_types (internal_params )
198207 validate_regime_transitions_all_periods (
199208 internal_regimes = self .internal_regimes ,
200209 internal_params = internal_params ,
@@ -283,6 +292,7 @@ def simulate(
283292 internal_params = _maybe_convert_series (
284293 internal_params , model = self , derived_categoricals = derived_categoricals
285294 )
295+ _validate_param_types (internal_params )
286296 if check_initial_conditions :
287297 validate_initial_conditions (
288298 initial_conditions = initial_conditions ,
@@ -336,8 +346,6 @@ def _maybe_convert_series(
336346 | None ,
337347) -> InternalParams :
338348 """Convert pd.Series leaves in params to JAX arrays if any are present."""
339- from lcm .pandas_utils import convert_series_in_params , has_series # noqa: PLC0415
340-
341349 if derived_categoricals is not None or has_series (internal_params ):
342350 return convert_series_in_params (
343351 internal_params = internal_params ,
@@ -347,17 +355,51 @@ def _maybe_convert_series(
347355 return internal_params
348356
349357
358+ def _validate_param_types (internal_params : InternalParams ) -> None :
359+ """Raise if any param leaf is not a Python scalar or JAX array.
360+
361+ After processing, every leaf value (including inside MappingLeaf /
362+ SequenceLeaf containers) must be a Python scalar (float, int, bool) or a
363+ JAX array. Notably, numpy arrays and pandas Series are not accepted.
364+ """
365+ for regime_name , regime_params in internal_params .items ():
366+ for key , value in regime_params .items ():
367+ _check_leaf (value , f"{ regime_name } __{ key } " )
368+
369+
370+ def _check_leaf (value : object , path : str ) -> None :
371+ """Check a single leaf value, recursing into MappingLeaf/SequenceLeaf."""
372+ if isinstance (value , MappingLeaf ):
373+ for k , v in value .data .items ():
374+ _check_leaf (v , f"{ path } .{ k } " )
375+ return
376+ if isinstance (value , SequenceLeaf ):
377+ for i , v in enumerate (value .data ):
378+ _check_leaf (v , f"{ path } [{ i } ]" )
379+ return
380+ if isinstance (value , (float , int , bool )):
381+ return
382+ if hasattr (value , "dtype" ) and hasattr (value , "shape" ):
383+ if isinstance (value , Array ):
384+ return
385+ type_name = type (value ).__module__ + "." + type (value ).__name__
386+ msg = (
387+ f"Parameter '{ path } ' is a { type_name } (shape { value .shape } ). "
388+ f"Use jax.numpy.array() or pass a pd.Series with a named index."
389+ )
390+ raise InvalidParamsError (msg )
391+ type_name = type (value ).__module__ + "." + type (value ).__name__
392+ msg = f"Parameter '{ path } ' has unexpected type { type_name } ."
393+ raise InvalidParamsError (msg )
394+
395+
350396def _maybe_convert_dataframe (
351397 initial_conditions : Mapping [str , Array ],
352398 * ,
353399 model : Model ,
354400) -> Mapping [str , Array ]:
355401 """Convert a DataFrame to initial_conditions dict if needed."""
356- import pandas as pd # noqa: PLC0415
357-
358402 if isinstance (initial_conditions , pd .DataFrame ):
359- from lcm .pandas_utils import initial_conditions_from_dataframe # noqa: PLC0415
360-
361403 return initial_conditions_from_dataframe (df = initial_conditions , model = model )
362404 return initial_conditions
363405
0 commit comments