Skip to content

Commit 633d17b

Browse files
authored
Merge branch 'main' into bes/petabv2_jax
2 parents db6881c + 9ede529 commit 633d17b

2 files changed

Lines changed: 11 additions & 7 deletions

File tree

python/sdist/amici/jax/petab.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ def _get_measurements(
314314
for col in [petabv1.OBSERVABLE_PARAMETERS, petabv1.NOISE_PARAMETERS]:
315315
n_pars[col] = 0
316316
if col in self._petab_problem.measurement_df:
317-
if np.issubdtype(
318-
self._petab_problem.measurement_df[col].dtype, np.number
317+
if pd.api.types.is_numeric_dtype(
318+
self._petab_problem.measurement_df[col].dtype
319319
):
320320
n_pars[col] = 1 - int(
321321
self._petab_problem.measurement_df[col].isna().all()
@@ -416,7 +416,7 @@ def get_parameter_override(x):
416416
mat_numeric = jnp.ones((len(m), n_pars[col]))
417417
par_mask = np.zeros_like(mat_numeric, dtype=bool)
418418
par_index = np.zeros_like(mat_numeric, dtype=int)
419-
elif np.issubdtype(m[col].dtype, np.number):
419+
elif pd.api.types.is_numeric_dtype(m[col].dtype):
420420
mat_numeric = np.expand_dims(m[col].values, axis=1)
421421
par_mask = np.zeros_like(mat_numeric, dtype=bool)
422422
par_index = np.zeros_like(mat_numeric, dtype=int)

tests/benchmark_models/test_petab_benchmark_jax.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from functools import partial
33

4+
import diffrax
45
import equinox as eqx
56
import jax
67
import jax.numpy as jnp
@@ -23,8 +24,6 @@
2324
settings,
2425
)
2526

26-
import diffrax
27-
2827
jax.config.update("jax_enable_x64", True)
2928

3029

@@ -38,7 +37,12 @@ def test_jax_llh(benchmark_problem):
3837
benchmark_problem
3938
)
4039

41-
to_skip = ["Smith_BMCSystBiol2013", "Oliveira_NatCommun2021", "SalazarCavazos_MBoC2020"]
40+
to_skip = [
41+
"Liu_IFACPapersOnLine2025",
42+
"Oliveira_NatCommun2021",
43+
"SalazarCavazos_MBoC2020",
44+
"Smith_BMCSystBiol2013",
45+
]
4246
if problem_id in to_skip:
4347
pytest.skip(
4448
f"Skipping {problem_id} due to non-supported events in JAX."
@@ -133,7 +137,7 @@ def test_jax_llh(benchmark_problem):
133137
)
134138
else:
135139
llh_jax, _ = beartype(run_simulations)(jax_problem)
136-
140+
137141
np.testing.assert_allclose(
138142
llh_jax,
139143
llh_amici,

0 commit comments

Comments
 (0)