Skip to content

Commit 323b46b

Browse files
authored
Fix importing PetabImporter without jax and sciml extras (#3141)
* Fix importing PetabImporter without jax extras Closes #3140. * local h5py import
1 parent 4e223a0 commit 323b46b

4 files changed

Lines changed: 21 additions & 12 deletions

File tree

CHANGELOG.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
44

55
## v1.X Series
66

7-
### v1.0.0 (unreleased)
7+
### v1.0.1
8+
9+
**Fixes**
10+
11+
* Fixed an issue that resulted in failure to import the `PetabImporter` if
12+
the jax-dependencies weren't installed.
13+
14+
15+
### v1.0.0
816

917
**BREAKING CHANGES**
1018

@@ -39,7 +47,7 @@ The following functionality has been removed without replacement:
3947
fixed parameters as "fixed parameters", "constant parameters",
4048
or "constants". This has now been harmonized to "free" and "fixed" across the
4149
API. E.g., `Model.setParameters()` is now `Model.set_free_parameters()`.
42-
* `ReturnDataView.posteq_numsteps` and `ReturnDataView.posteq_numsteps` now
50+
* `ReturnDataView.posteq_numsteps` and `ReturnDataView.preeq_numsteps` now
4351
return a one-dimensional array of shape `(num_timepoints,)` instead of a
4452
two-dimensional array of shape `(1, num_timepoints)`.
4553
* `ReturnDataView.posteq_status` and `ReturnDataView.preeq_status` now

python/sdist/amici/importers/petab/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,4 @@
2121
passing a :class:`petab.v1.Problem` instance to the PEtab v2 import functions.
2222
"""
2323

24-
# FIXME: for some tests (petab-sciml, maybe petab-v1-pysb) we still rely on an
25-
# old PEtab version on which the petab v2 import does not work.
26-
# Once those tests are updated, we can remove this try-except block.
27-
try:
28-
from ._petab_importer import * # noqa: F403, F401
29-
except ImportError:
30-
pass
24+
from ._petab_importer import * # noqa: F403, F401

python/sdist/amici/importers/petab/_petab_importer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from amici._symbolic import DEModel, Event
2525
from amici.importers.utils import MeasurementChannel, amici_time_symbol
2626
from amici.logging import get_logger
27-
from amici.sim.jax.petab import JAXProblem
2827

2928
from .v1._sbml_import import _add_global_parameter
3029

@@ -594,7 +593,7 @@ def create_model(self) -> amici.sim.sundials.Model:
594593

595594
def create_simulator(
596595
self, force_import: bool = False
597-
) -> amici.sim.sundials.petab.PetabSimulator:
596+
) -> amici.sim.sundials.petab.PetabSimulator | amici.sim.jax.JAXProblem:
598597
"""
599598
Create a PEtab simulator for the imported model.
600599
@@ -607,6 +606,9 @@ def create_simulator(
607606
if self._jax:
608607
model_module = self.import_module(force_import=force_import)
609608
model = model_module.Model()
609+
610+
from amici.sim.jax.petab import JAXProblem
611+
610612
return JAXProblem(model, self.petab_problem)
611613

612614
model = self.import_module(force_import=force_import).get_model()

python/sdist/amici/sim/jax/petab.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import diffrax
1212
import equinox as eqx
13-
import h5py
1413
import jax.lax
1514
import jax.numpy as jnp
1615
import jaxtyping as jt
@@ -593,6 +592,9 @@ def _load_parameter_arrays_from_files(self) -> dict:
593592
"array_files", []
594593
)
595594

595+
import h5py
596+
597+
# TODO(performance): Avoid opening each file multiple times
596598
return {
597599
file_spec.split("_")[0]: h5py.File(file_spec, "r")["parameters"][
598600
file_spec.split("_")[0]
@@ -615,6 +617,9 @@ def _load_input_arrays_from_files(self) -> dict:
615617
"array_files", []
616618
)
617619

620+
import h5py
621+
622+
# TODO(performance): Avoid opening each file multiple times
618623
return {
619624
file_spec.split("_")[0]: h5py.File(file_spec, "r")["inputs"]
620625
for file_spec in array_files

0 commit comments

Comments
 (0)