File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -314,8 +314,8 @@ def _get_measurements(
314314 for col in [petabv1 .OBSERVABLE_PARAMETERS , petabv1 .NOISE_PARAMETERS ]:
315315 n_pars [col ] = 0
316316 if col in self ._petab_problem .measurement_df :
317- if np . issubdtype (
318- self ._petab_problem .measurement_df [col ].dtype , np . number
317+ if pd . api . types . is_numeric_dtype (
318+ self ._petab_problem .measurement_df [col ].dtype
319319 ):
320320 n_pars [col ] = 1 - int (
321321 self ._petab_problem .measurement_df [col ].isna ().all ()
@@ -416,7 +416,7 @@ def get_parameter_override(x):
416416 mat_numeric = jnp .ones ((len (m ), n_pars [col ]))
417417 par_mask = np .zeros_like (mat_numeric , dtype = bool )
418418 par_index = np .zeros_like (mat_numeric , dtype = int )
419- elif np . issubdtype (m [col ].dtype , np . number ):
419+ elif pd . api . types . is_numeric_dtype (m [col ].dtype ):
420420 mat_numeric = np .expand_dims (m [col ].values , axis = 1 )
421421 par_mask = np .zeros_like (mat_numeric , dtype = bool )
422422 par_index = np .zeros_like (mat_numeric , dtype = int )
Original file line number Diff line number Diff line change 11import logging
22from functools import partial
33
4+ import diffrax
45import equinox as eqx
56import jax
67import jax .numpy as jnp
2324 settings ,
2425)
2526
26- import diffrax
27-
2827jax .config .update ("jax_enable_x64" , True )
2928
3029
@@ -38,7 +37,12 @@ def test_jax_llh(benchmark_problem):
3837 benchmark_problem
3938 )
4039
41- to_skip = ["Smith_BMCSystBiol2013" , "Oliveira_NatCommun2021" , "SalazarCavazos_MBoC2020" ]
40+ to_skip = [
41+ "Liu_IFACPapersOnLine2025" ,
42+ "Oliveira_NatCommun2021" ,
43+ "SalazarCavazos_MBoC2020" ,
44+ "Smith_BMCSystBiol2013" ,
45+ ]
4246 if problem_id in to_skip :
4347 pytest .skip (
4448 f"Skipping { problem_id } due to non-supported events in JAX."
@@ -133,7 +137,7 @@ def test_jax_llh(benchmark_problem):
133137 )
134138 else :
135139 llh_jax , _ = beartype (run_simulations )(jax_problem )
136-
140+
137141 np .testing .assert_allclose (
138142 llh_jax ,
139143 llh_amici ,
You can’t perform that action at this time.
0 commit comments