Skip to content

Commit 9508113

Browse files
committed
pre-commit run --all-files
1 parent c86a19a commit 9508113

1 file changed

Lines changed: 70 additions & 22 deletions

File tree

tests/sciml/test_sciml.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def change_directory(destination):
4747

4848
def _reshape_flat_array(array_flat):
4949
array_flat["ix"] = array_flat["ix"].astype(str)
50-
ix_cols = [f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";")))]
50+
ix_cols = [
51+
f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";")))
52+
]
5153
if len(ix_cols) == 1:
5254
array_flat[ix_cols[0]] = array_flat["ix"].apply(int)
5355
else:
@@ -61,7 +63,9 @@ def _reshape_flat_array(array_flat):
6163
return array
6264

6365

64-
@pytest.mark.parametrize("test", sorted(d.stem for d in net_cases_dir.glob("[0-9]*")))
66+
@pytest.mark.parametrize(
67+
"test", sorted(d.stem for d in net_cases_dir.glob("[0-9]*"))
68+
)
6569
def test_ml_model_import(test):
6670
test_dir = net_cases_dir / test
6771
with open(test_dir / "solutions.yaml") as f:
@@ -104,7 +108,9 @@ def test_ml_model_import(test):
104108
if test == "053":
105109
input_files = [
106110
(i1, i2)
107-
for i1, i2 in zip(solutions["net_input_arg0"], solutions["net_input_arg1"])
111+
for i1, i2 in zip(
112+
solutions["net_input_arg0"], solutions["net_input_arg1"]
113+
)
108114
]
109115
else:
110116
input_files = solutions["net_input"]
@@ -117,13 +123,19 @@ def test_ml_model_import(test):
117123
if test == "053":
118124
input = tuple(
119125
[
120-
h5py.File(test_dir / in_file, "r")["inputs"]["input0"]["data"][:]
126+
h5py.File(test_dir / in_file, "r")["inputs"]["input0"][
127+
"data"
128+
][:]
121129
for in_file in input_file
122130
]
123131
)
124132
else:
125-
input = h5py.File(test_dir / input_file, "r")["inputs"]["input0"]["data"][:]
126-
output = h5py.File(test_dir / output_file, "r")["outputs"]["output0"]["data"][:]
133+
input = h5py.File(test_dir / input_file, "r")["inputs"]["input0"][
134+
"data"
135+
][:]
136+
output = h5py.File(test_dir / output_file, "r")["outputs"]["output0"][
137+
"data"
138+
][:]
127139

128140
if "net_ps" in solutions:
129141
par = h5py.File(test_dir / par_file, "r")
@@ -134,10 +146,14 @@ def test_ml_model_import(test):
134146
and hasattr(net.layers[layer], "weight")
135147
and net.layers[layer].weight is not None
136148
):
137-
w = par["parameters"][ml_model.nn_model_id][layer]["weight"][:]
149+
w = par["parameters"][ml_model.nn_model_id][layer][
150+
"weight"
151+
][:]
138152
if isinstance(net.layers[layer], eqx.nn.ConvTranspose):
139153
# see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
140-
w = np.flip(w, axis=tuple(range(2, w.ndim))).swapaxes(0, 1)
154+
w = np.flip(w, axis=tuple(range(2, w.ndim))).swapaxes(
155+
0, 1
156+
)
141157
assert w.shape == net.layers[layer].weight.shape
142158
net = eqx.tree_at(
143159
lambda x: x.layers[layer].weight,
@@ -149,7 +165,9 @@ def test_ml_model_import(test):
149165
and hasattr(net.layers[layer], "bias")
150166
and net.layers[layer].bias is not None
151167
):
152-
b = par["parameters"][ml_model.nn_model_id][layer]["bias"][:]
168+
b = par["parameters"][ml_model.nn_model_id][layer]["bias"][
169+
:
170+
]
153171
if isinstance(
154172
net.layers[layer],
155173
eqx.nn.Conv | eqx.nn.ConvTranspose,
@@ -182,7 +200,9 @@ def test_ml_model_import(test):
182200
)
183201

184202

185-
@pytest.mark.parametrize("test", sorted([d.stem for d in ude_cases_dir.glob("[0-9]*")]))
203+
@pytest.mark.parametrize(
204+
"test", sorted([d.stem for d in ude_cases_dir.glob("[0-9]*")])
205+
)
186206
def test_sciml_problem_import(test):
187207
test_dir = ude_cases_dir / test
188208

@@ -193,7 +213,9 @@ def test_sciml_problem_import(test):
193213

194214
with change_directory(test_dir / "petab"):
195215
# HACK!! Again!! Around "array" in parameters table
196-
petab_problem = _v2_sciml_problem_helper(petab_yaml, test_dir / "petab")
216+
petab_problem = _v2_sciml_problem_helper(
217+
petab_yaml, test_dir / "petab"
218+
)
197219

198220
if test in ("003",):
199221
with pytest.raises(NotImplementedError):
@@ -290,9 +312,13 @@ def test_sciml_problem_import(test):
290312
)
291313
else:
292314
expected = h5py.File(test_dir / file, "r")
293-
for layer_name, layer in jax_problem.model.nns[component].layers.items():
315+
for layer_name, layer in jax_problem.model.nns[
316+
component
317+
].layers.items():
294318
for attribute in dir(layer):
295-
if not isinstance(getattr(layer, attribute), jax.numpy.ndarray):
319+
if not isinstance(
320+
getattr(layer, attribute), jax.numpy.ndarray
321+
):
296322
continue
297323
actual = getattr(
298324
sllh.model.nns[component].layers[layer_name], attribute
@@ -307,7 +333,9 @@ def test_sciml_problem_import(test):
307333
)
308334
if (
309335
np.squeeze(
310-
expected["parameters"][component][layer_name][attribute][:]
336+
expected["parameters"][component][layer_name][
337+
attribute
338+
][:]
311339
).size
312340
== 0
313341
):
@@ -333,7 +361,9 @@ def _v2_sciml_problem_helper(yaml_config, base_path):
333361
df = pd.read_csv(f, sep="\t")
334362
df.nominalValue = df.nominalValue.apply(_try_float)
335363
if "priorParameters" in df.columns:
336-
df.priorParameters = df.priorParameters.apply(_process_prior_params)
364+
df.priorParameters = df.priorParameters.apply(
365+
_process_prior_params
366+
)
337367
parameters = [
338368
v2.Parameter.model_construct(**row.to_dict())
339369
for _, row in df.reset_index().iterrows()
@@ -351,19 +381,28 @@ def _v2_sciml_problem_helper(yaml_config, base_path):
351381
]
352382

353383
measurement_tables = (
354-
[v2.MeasurementTable.from_tsv(f, base_path) for f in config.measurement_files]
384+
[
385+
v2.MeasurementTable.from_tsv(f, base_path)
386+
for f in config.measurement_files
387+
]
355388
if config.measurement_files
356389
else None
357390
)
358391

359392
experiment_tables = (
360-
[v2.ExperimentTable.from_tsv(f, base_path) for f in config.experiment_files]
393+
[
394+
v2.ExperimentTable.from_tsv(f, base_path)
395+
for f in config.experiment_files
396+
]
361397
if config.experiment_files
362398
else None
363399
)
364400

365401
condition_tables = (
366-
[v2.ConditionTable.from_tsv(f, base_path) for f in config.condition_files]
402+
[
403+
v2.ConditionTable.from_tsv(f, base_path)
404+
for f in config.condition_files
405+
]
367406
if config.condition_files
368407
else None
369408
)
@@ -382,7 +421,10 @@ def _v2_sciml_problem_helper(yaml_config, base_path):
382421
]
383422

384423
observable_tables = (
385-
[v2.ObservableTable.from_tsv(f, base_path) for f in config.observable_files]
424+
[
425+
v2.ObservableTable.from_tsv(f, base_path)
426+
for f in config.observable_files
427+
]
386428
if config.observable_files
387429
else None
388430
)
@@ -414,7 +456,9 @@ def _process_prior_params(prior_params):
414456

415457
def _normal_logpdf(x: jnp.ndarray, mean: float, std: float) -> jnp.ndarray:
416458
var = std**2
417-
return jnp.sum(-0.5 * jnp.log(2.0 * jnp.pi * var) - 0.5 * ((x - mean) ** 2) / var)
459+
return jnp.sum(
460+
-0.5 * jnp.log(2.0 * jnp.pi * var) - 0.5 * ((x - mean) ** 2) / var
461+
)
418462

419463

420464
def _uniform_logpdf(x: jnp.ndarray, low: float, high: float) -> jnp.ndarray:
@@ -449,7 +493,9 @@ def _tree_array_loguniformprior(tree, low: float, high: float) -> jnp.ndarray:
449493
return total
450494

451495

452-
def _model_logprior(model, layer1_bias_std=1.0, layer1_weight_std=1.0) -> jnp.ndarray:
496+
def _model_logprior(
497+
model, layer1_bias_std=1.0, layer1_weight_std=1.0
498+
) -> jnp.ndarray:
453499
mech = model.parameters
454500
layer1_bias = model.model.nns["net1"].layers["layer1"].bias
455501
layer1_weight = model.model.nns["net1"].layers["layer1"].weight
@@ -460,6 +506,8 @@ def _model_logprior(model, layer1_bias_std=1.0, layer1_weight_std=1.0) -> jnp.nd
460506
return (
461507
_tree_array_loguniformprior(mech, low=0.0, high=15.0)
462508
+ _tree_array_lognormprior(layer1_bias, mean=0.0, std=layer1_bias_std)
463-
+ _tree_array_lognormprior(layer1_weight, mean=0.0, std=layer1_weight_std)
509+
+ _tree_array_lognormprior(
510+
layer1_weight, mean=0.0, std=layer1_weight_std
511+
)
464512
+ _tree_array_lognormprior(rest, mean=0.0, std=1.0)
465513
)

0 commit comments

Comments
 (0)