Skip to content

Commit 4c6947e

Browse files
committed
update testsuite, some fixes
1 parent eaf3e09 commit 4c6947e

5 files changed

Lines changed: 105 additions & 116 deletions

File tree

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[submodule "tests/sciml/testsuite"]
22
path = tests/sciml/testsuite
3-
url = https://github.com/sebapersson/petab_sciml
3+
url = https://github.com/sebapersson/petab_sciml_testsuite

python/sdist/amici/de_model.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2381,10 +2381,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
23812381
"""
23822382
added_expressions = False
23832383
for net_id, net in hybridization.items():
2384-
if not (
2385-
net["hybridization"]["output"] == "ode"
2386-
or net["hybridization"]["input"] == "ode"
2387-
):
2384+
if net["static"]:
23882385
continue # do not integrate into ODEs, handle in amici.jax.petab
23892386
inputs = [
23902387
comp

python/sdist/amici/jax/petab.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import pandas as pd
1818
import petab.v1 as petab
19+
import h5py
1920

2021
from amici import _module_from_path
2122
from amici.petab.parameter_mapping import (
@@ -128,8 +129,6 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem):
128129
self._np_indices,
129130
) = self._get_measurements(scs)
130131

131-
self.parameters = self._get_nominal_parameter_values()
132-
133132
def save(self, directory: Path):
134133
"""
135134
Save the problem to a directory.
@@ -518,6 +517,13 @@ def _get_nominal_parameter_values(
518517
if (net := pname.split(".")[0]) in model.nns:
519518
to_set = []
520519
nn = model_pars[net]
520+
scalar = True
521+
try:
522+
value = float(row[petab.NOMINAL_VALUE])
523+
except ValueError:
524+
value = h5py.File(row[petab.NOMINAL_VALUE], "r")
525+
scalar = False
526+
521527
if len(pname.split(".")) > 1:
522528
layer = nn[pname.split(".")[1]]
523529
if len(pname.split(".")) > 2:
@@ -541,9 +547,12 @@ def _get_nominal_parameter_values(
541547
)
542548

543549
for layer, attribute in to_set:
544-
nn[layer][attribute] = row[
545-
petab.NOMINAL_VALUE
546-
] * jnp.ones_like(nn[layer][attribute])
550+
if scalar:
551+
nn[layer][attribute] = value * jnp.ones_like(
552+
nn[layer][attribute]
553+
)
554+
else:
555+
nn[layer][attribute] = value[layer][attribute]
547556

548557
# set values in model
549558
for net_id in model_pars:
@@ -559,9 +568,11 @@ def _get_nominal_parameter_values(
559568
return jnp.array(
560569
[
561570
petab.scale(
562-
self._petab_problem.parameter_df.loc[
563-
pval, petab.NOMINAL_VALUE
564-
],
571+
float(
572+
self._petab_problem.parameter_df.loc[
573+
pval, petab.NOMINAL_VALUE
574+
]
575+
),
565576
self._petab_problem.parameter_df.loc[
566577
pval, petab.PARAMETER_SCALE
567578
],
@@ -604,7 +615,12 @@ def parameter_ids(self) -> list[str]:
604615
PEtab parameter ids
605616
"""
606617
return self._petab_problem.parameter_df[
607-
self._petab_problem.parameter_df[petab.ESTIMATE] == 1
618+
self._petab_problem.parameter_df[petab.ESTIMATE]
619+
== 1
620+
& pd.to_numeric(
621+
self._petab_problem.parameter_df[petab.NOMINAL_VALUE],
622+
errors="coerce",
623+
).notna()
608624
].index.tolist()
609625

610626
@property
@@ -886,7 +902,9 @@ def _prepare_conditions(
886902
Tuple of parameter arrays, reinitialisation masks and reinitialisation values, observable parameters and
887903
noise parameters.
888904
"""
889-
p_array = jnp.stack([self.load_parameters(sc) for sc in conditions])
905+
p_array = jnp.stack(
906+
[self.load_model_parameters(sc) for sc in conditions]
907+
)
890908
unscaled_parameters = jnp.stack(
891909
[
892910
jax_unscale(

python/sdist/amici/petab/petab_import.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,16 @@ def import_petab_problem(
146146

147147
logger.info(f"Compiling model {model_name} to {model_output_dir}.")
148148

149-
if "petab_sciml" in petab_problem.extensions_config:
149+
if "neural_nets" in petab_problem.extensions_config: # TODO: fixme
150150
from petab_sciml import PetabScimlStandard
151151

152-
config = petab_problem.extensions_config["petab_sciml"]
152+
config = petab_problem.extensions_config
153+
# TODO: only accept YAML format for now
154+
# TODO: edit petab library to load hybridization table and map input and output vars here
153155
hybridization = {
154156
net_id: {
155157
"model": PetabScimlStandard.load_data(
156-
Path() / net_config["file"]
158+
Path() / net_config["location"]
157159
).models,
158160
"input_vars": [
159161
petab_id
@@ -183,7 +185,7 @@ def import_petab_problem(
183185
],
184186
**net_config,
185187
}
186-
for net_id, net_config in config.items()
188+
for net_id, net_config in config["neural_nets"].items()
187189
}
188190
if not jax or petab_problem.model.type_id == MODEL_TYPE_PYSB:
189191
raise NotImplementedError(

tests/sciml/test_sciml.py

Lines changed: 69 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def change_directory(destination):
4444
# pip install git+https://github.com/sebapersson/petab_sciml@add_standard#egg=petab_sciml\&subdirectory=src/python
4545

4646
cases_dir = Path(__file__).parent / "testsuite" / "test_cases"
47+
net_cases_dir = cases_dir / "net_import"
48+
ude_cases_dir = cases_dir / "hybrid"
4749

4850

4951
def _reshape_flat_array(array_flat):
@@ -65,10 +67,10 @@ def _reshape_flat_array(array_flat):
6567

6668

6769
@pytest.mark.parametrize(
68-
"test", sorted([d.stem for d in cases_dir.glob("net_[0-9]*")])
70+
"test", sorted([d.stem for d in net_cases_dir.glob("[0-9]*")])
6971
)
7072
def test_net(test):
71-
test_dir = cases_dir / test
73+
test_dir = net_cases_dir / test
7274
with open(test_dir / "solutions.yaml") as f:
7375
solutions = safe_load(f)
7476

@@ -83,20 +85,20 @@ def test_net(test):
8385
for ml_model in ml_models.models:
8486
module_dir = outdir / f"{ml_model.mlmodel_id}.py"
8587
if test in (
86-
"net_002",
87-
"net_009",
88-
"net_018",
89-
"net_019",
90-
"net_020",
91-
"net_021",
92-
"net_022",
93-
"net_042",
94-
"net_043",
95-
"net_044",
96-
"net_045",
97-
"net_046",
98-
"net_047",
99-
"net_048",
88+
"002",
89+
"009",
90+
"018",
91+
"019",
92+
"020",
93+
"021",
94+
"022",
95+
"042",
96+
"043",
97+
"044",
98+
"045",
99+
"046",
100+
"047",
101+
"048",
100102
):
101103
with pytest.raises(NotImplementedError):
102104
generate_equinox(ml_model, module_dir)
@@ -111,34 +113,20 @@ def test_net(test):
111113
solutions.get("net_ps", solutions["net_input"]),
112114
solutions["net_output"],
113115
):
114-
input_flat = pd.read_csv(test_dir / input_file, sep="\t")
115-
input = _reshape_flat_array(input_flat)
116-
117-
output_flat = pd.read_csv(test_dir / output_file, sep="\t")
118-
output = _reshape_flat_array(output_flat)
116+
input = h5py.File(test_dir / input_file, "r")["input"][:]
117+
output = h5py.File(test_dir / output_file, "r")["output"][:]
119118

120119
if "net_ps" in solutions:
121-
par = pd.read_csv(test_dir / par_file, sep="\t")
120+
par = h5py.File(test_dir / par_file, "r")
122121
for ml_model in ml_models.models:
123122
net = nets[ml_model.mlmodel_id](jr.PRNGKey(0))
124123
for layer in net.layers.keys():
125-
layer_prefix = f"net_{layer}"
126124
if (
127125
isinstance(net.layers[layer], eqx.Module)
128126
and hasattr(net.layers[layer], "weight")
129127
and net.layers[layer].weight is not None
130128
):
131-
prefix = layer_prefix + "_weight"
132-
df = par[
133-
par[petab.PARAMETER_ID].str.startswith(prefix)
134-
]
135-
df["ix"] = (
136-
df[petab.PARAMETER_ID]
137-
.str.split("_")
138-
.str[3:]
139-
.apply(lambda x: ";".join(x))
140-
)
141-
w = _reshape_flat_array(df)
129+
w = par[layer]["weight"][:]
142130
if isinstance(net.layers[layer], eqx.nn.ConvTranspose):
143131
# see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
144132
w = np.flip(
@@ -155,17 +143,7 @@ def test_net(test):
155143
and hasattr(net.layers[layer], "bias")
156144
and net.layers[layer].bias is not None
157145
):
158-
prefix = layer_prefix + "_bias"
159-
df = par[
160-
par[petab.PARAMETER_ID].str.startswith(prefix)
161-
]
162-
df["ix"] = (
163-
df[petab.PARAMETER_ID]
164-
.str.split("_")
165-
.str[3:]
166-
.apply(lambda x: ";".join(x))
167-
)
168-
b = _reshape_flat_array(df)
146+
b = par[layer]["bias"][:]
169147
if isinstance(
170148
net.layers[layer],
171149
eqx.nn.Conv | eqx.nn.ConvTranspose,
@@ -199,75 +177,69 @@ def test_net(test):
199177

200178

201179
@pytest.mark.parametrize(
202-
"test", sorted([d.stem for d in cases_dir.glob("[0-9]*")])
180+
"test", sorted([d.stem for d in ude_cases_dir.glob("[0-9]*")])
203181
)
204182
def test_ude(test):
205-
test_dir = cases_dir / test
206-
with open(test_dir / "petab" / "problem_ude.yaml") as f:
183+
test_dir = ude_cases_dir / test
184+
with open(test_dir / "petab" / "problem.yaml") as f:
207185
petab_yaml = safe_load(f)
208186
with open(test_dir / "solutions.yaml") as f:
209187
solutions = safe_load(f)
210188

211189
with change_directory(test_dir / "petab"):
212190
from petab.v2 import Problem
213191

214-
petab_yaml["format_version"] = "2.0.0"
215-
for problem in petab_yaml["problems"]:
216-
problem["model_files"] = {
217-
problem["model_files"]["location"].split(".")[0]: problem[
218-
"model_files"
219-
]
220-
}
221-
for mapping_file in problem["mapping_files"]:
222-
df = pd.read_csv(
223-
mapping_file,
224-
sep="\t",
225-
)
226-
if df[petab.PETAB_ENTITY_ID].str.startswith("net").any():
227-
df.rename(
228-
columns={
229-
petab.PETAB_ENTITY_ID: petab.MODEL_ENTITY_ID,
230-
petab.MODEL_ENTITY_ID: petab.PETAB_ENTITY_ID,
231-
}
232-
).to_csv(mapping_file, sep="\t", index=False)
233-
192+
petab_yaml["format_version"] = "2.0.0" # TODO: fixme
234193
petab_problem = Problem.from_yaml(petab_yaml)
235194
jax_model = import_petab_problem(
236195
petab_problem,
237196
model_output_dir=Path(__file__).parent / "models" / test,
238197
compile_=True,
239198
jax=True,
240199
)
241-
jax_problem = JAXProblem(jax_model, petab_problem)
242-
for net, net_config in petab_problem.extensions_config[
243-
"petab_sciml"
244-
].items():
245-
pars = h5py.File(
246-
net_config["parameters"].replace(".h5", ".hf5"), "r"
247-
)
248-
for layer_name, layer in jax_problem.model.nns[net].layers.items():
249-
for attribute in dir(layer):
250-
if not isinstance(
251-
getattr(layer, attribute), jax.numpy.ndarray
252-
):
253-
continue
254-
value = jnp.array(pars[layer_name][attribute])
200+
# non_numeric = pd.to_numeric(petab_problem.parameter_df[petab.NOMINAL_VALUE], errors='coerce').isna()
201+
# par_files = petab_problem.parameter_df.loc[non_numeric, petab.NOMINAL_VALUE].unique()
202+
# par_values = {
203+
# par_file: h5py.File(par_file, "r")
204+
# for par_file in par_files
205+
# }
206+
# for par_id, row in petab_problem.parameter_df.iterrows():
207+
# if not non_numeric[par_id]:
208+
# continue
209+
# petab_problem.parameter_df.loc[par_id, petab.NOMINAL_VALUE] = \
210+
# (par_values[row[petab.NOMINAL_VALUE]],)
211+
# petab_problem.parameter_df.loc[np.logical_not(non_numeric), petab.NOMINAL_VALUE] = pd.to_numeric(
212+
# petab_problem.parameter_df.loc[np.logical_not(non_numeric), petab.NOMINAL_VALUE]
213+
# )
255214

256-
if (
257-
isinstance(layer, eqx.nn.ConvTranspose)
258-
and attribute == "weight"
259-
):
260-
# see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
261-
value = jnp.flip(
262-
value, axis=tuple(range(2, value.ndim))
263-
).swapaxes(0, 1)
264-
jax_problem = eqx.tree_at(
265-
lambda x: getattr(
266-
x.model.nns[net].layers[layer_name], attribute
267-
),
268-
jax_problem,
269-
value,
270-
)
215+
jax_problem = JAXProblem(jax_model, petab_problem)
216+
# for net, net_config in petab_problem.extensions_config.items(): # TODO: FIXME (https://github.com/sebapersson/petab_sciml_testsuite/issues/1)
217+
# pars = h5py.File(
218+
# net_config["net1_ps_file"]['location'], "r" # TODO: check format and actually use propoer petab nominal parameter infrastructure
219+
# )
220+
# for layer_name, layer in jax_problem.model.nns[net].layers.items():
221+
# for attribute in dir(layer):
222+
# if not isinstance(
223+
# getattr(layer, attribute), jax.numpy.ndarray
224+
# ):
225+
# continue
226+
# value = jnp.array(pars[layer_name][attribute])
227+
#
228+
# if (
229+
# isinstance(layer, eqx.nn.ConvTranspose)
230+
# and attribute == "weight"
231+
# ):
232+
# # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
233+
# value = jnp.flip(
234+
# value, axis=tuple(range(2, value.ndim))
235+
# ).swapaxes(0, 1)
236+
# jax_problem = eqx.tree_at(
237+
# lambda x: getattr(
238+
# x.model.nns[net].layers[layer_name], attribute
239+
# ),
240+
# jax_problem,
241+
# value,
242+
# )
271243

272244
# llh
273245
if test in (

0 commit comments

Comments
 (0)