Skip to content

Commit 53be60d

Browse files
committed
Move amici.jax to amici.exporters.jax and amici.sim.jax
Related to #3041.
1 parent 0d7e49e commit 53be60d

27 files changed

Lines changed: 239 additions & 241 deletions

doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
"source": [
6666
"## Simulation\n",
6767
"\n",
68-
"We can now run efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)."
68+
"We can now run efficient simulation using [amici.sim.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)."
6969
]
7070
},
7171
{
@@ -75,7 +75,7 @@
7575
"metadata": {},
7676
"outputs": [],
7777
"source": [
78-
"from amici.jax import run_simulations\n",
78+
"from amici.sim.jax import run_simulations\n",
7979
"\n",
8080
"# Run simulations and compute the log-likelihood\n",
8181
"llh, results = run_simulations(jax_problem)"
@@ -386,7 +386,7 @@
386386
"import diffrax\n",
387387
"import jax.numpy as jnp\n",
388388
"import optimistix\n",
389-
"from amici.jax import ReturnValue\n",
389+
"from amici.sim.jax import ReturnValue\n",
390390
"\n",
391391
"# Define the simulation condition\n",
392392
"experiment_condition = \"_petab_experiment_condition___default__\"\n",

doc/python_modules.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ AMICI Python API
1515
amici.importers.petab
1616
amici.importers.petab.v1
1717
amici.importers.utils
18-
amici.jax
18+
amici.sim.jax
1919
amici.sim.sundials
2020
amici.sim.sundials.plotting
2121
amici.sim.sundials.gradient_check

python/sdist/amici/_symbolic/de_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2499,7 +2499,8 @@ def _process_hybridization(self, hybridization: dict) -> None:
24992499
orig_obs = tuple([s.get_sym() for s in self._observables])
25002500
for net_id, net in hybridization.items():
25012501
if net["static"]:
2502-
continue # do not integrate into ODEs, handle in amici.jax.petab
2502+
# do not integrate into ODEs, handle in amici.sim.jax.petab
2503+
continue
25032504
inputs = [
25042505
comp
25052506
for comp in self._components
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
Code generation for JAX models for simulation with diffrax solvers.
3+
4+
This module provides an interface to generate and use AMICI models with JAX.
5+
Please note that this module is experimental, the API may substantially change
6+
in the future. Use at your own risk and do not expect backward compatibility.
7+
"""
8+
9+
from warnings import warn
10+
11+
from .nn import Flatten, cat, generate_equinox, tanhshrink
12+
13+
warn(
14+
"The JAX module is experimental and the API may change in the future.",
15+
ImportWarning,
16+
stacklevel=2,
17+
)
18+
19+
__all__ = [
20+
"Flatten",
21+
"generate_equinox",
22+
"tanhshrink",
23+
"cat",
24+
]

python/sdist/amici/jax/jax.template.py renamed to python/sdist/amici/exporters/jax/jax.template.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# ruff: noqa: F401, F821, F841
22
from pathlib import Path
33

4-
import equinox as eqx
4+
import equinox as eqx # noqa: F401
55
import jax.numpy as jnp
6-
import jax.random as jr
7-
import jaxtyping as jt
8-
from interpax import interp1d
9-
from jax.numpy import inf as oo
10-
from jax.numpy import nan as nan
11-
12-
from amici import _module_from_path
13-
from amici.jax.model import JAXModel, safe_div, safe_log
6+
import jax.random as jr # noqa: F401
7+
import jaxtyping as jt # noqa: F401
8+
from interpax import interp1d # noqa: F401
9+
from jax.numpy import inf as oo # noqa: F401
10+
from jax.numpy import nan as nan # noqa: F401
11+
12+
from amici import _module_from_path # noqa: F401
13+
from amici.sim.jax.model import JAXModel, safe_div, safe_log # noqa: F401
1414

1515
TPL_NET_IMPORTS
1616

File renamed without changes.
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import jax.numpy as jnp
55

66
from amici import amiciModulePath
7-
8-
from ..exporters.template import apply_template
7+
from amici.exporters.template import apply_template
98

109

1110
class Flatten(eqx.Module):
@@ -185,7 +184,7 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
185184
layer_map = {
186185
"Dropout1d": "eqx.nn.Dropout",
187186
"Dropout2d": "eqx.nn.Dropout",
188-
"Flatten": "amici.jax.Flatten",
187+
"Flatten": "amici.export.jax.Flatten",
189188
}
190189

191190
# mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations
@@ -321,9 +320,9 @@ def _process_activation_call(node: "Node") -> str: # noqa: F821
321320
"hardtanh": "jax.nn.hard_tanh",
322321
"hardsigmoid": "jax.nn.hard_sigmoid",
323322
"hardswish": "jax.nn.hard_swish",
324-
"tanhshrink": "amici.jax.tanhshrink",
323+
"tanhshrink": "amici.export.jax.tanhshrink",
325324
"softsign": "jax.nn.soft_sign",
326-
"cat": "amici.jax.cat",
325+
"cat": "amici.export.jax.cat",
327326
}
328327

329328
# Validate hardtanh parameters

python/sdist/amici/jax/nn.template.py renamed to python/sdist/amici/exporters/jax/nn.template.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
# ruff: noqa: F401, F821, F841
22
import equinox as eqx
3-
import jax
4-
import jax.nn
5-
import jax.numpy as jnp
63
import jax.random as jr
74

8-
import amici.jax.nn
9-
105

116
class TPL_MODEL_ID(eqx.Module):
127
layers: dict

python/sdist/amici/jax/ode_export.py renamed to python/sdist/amici/exporters/jax/ode_export.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,17 @@
1818

1919
import sympy as sp
2020

21-
from amici import (
22-
amiciModulePath,
23-
)
2421
from amici._symbolic.de_model import DEModel
2522
from amici._symbolic.sympy_utils import (
2623
_monkeypatch_sympy,
2724
)
25+
from amici.exporters.jax.nn import generate_equinox
2826
from amici.exporters.sundials.de_export import is_valid_identifier
2927
from amici.exporters.template import apply_template
30-
from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str
31-
from amici.jax.model import JAXModel
32-
from amici.jax.nn import generate_equinox
3328
from amici.logging import get_logger, log_execution_time, set_log_level
29+
from amici.sim.jax.model import JAXModel
30+
31+
from .jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str
3432

3533
#: python log manager
3634
logger = get_logger(__name__, logging.ERROR)
@@ -303,7 +301,7 @@ def _generate_jax_code(self) -> None:
303301
}
304302

305303
apply_template(
306-
Path(amiciModulePath) / "jax" / "jax.template.py",
304+
Path(__file__).parent / "jax.template.py",
307305
self.model_path / "__init__.py",
308306
tpl_data,
309307
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Code generation for models for simulation with SUNDIALS solvers."""

0 commit comments

Comments
 (0)