@@ -47,7 +47,9 @@ def change_directory(destination):
4747
4848def _reshape_flat_array (array_flat ):
4949 array_flat ["ix" ] = array_flat ["ix" ].astype (str )
50- ix_cols = [f"ix_{ i } " for i in range (len (array_flat ["ix" ].values [0 ].split (";" )))]
50+ ix_cols = [
51+ f"ix_{ i } " for i in range (len (array_flat ["ix" ].values [0 ].split (";" )))
52+ ]
5153 if len (ix_cols ) == 1 :
5254 array_flat [ix_cols [0 ]] = array_flat ["ix" ].apply (int )
5355 else :
@@ -61,7 +63,9 @@ def _reshape_flat_array(array_flat):
6163 return array
6264
6365
64- @pytest .mark .parametrize ("test" , sorted (d .stem for d in net_cases_dir .glob ("[0-9]*" )))
66+ @pytest .mark .parametrize (
67+ "test" , sorted (d .stem for d in net_cases_dir .glob ("[0-9]*" ))
68+ )
6569def test_ml_model_import (test ):
6670 test_dir = net_cases_dir / test
6771 with open (test_dir / "solutions.yaml" ) as f :
@@ -104,7 +108,9 @@ def test_ml_model_import(test):
104108 if test == "053" :
105109 input_files = [
106110 (i1 , i2 )
107- for i1 , i2 in zip (solutions ["net_input_arg0" ], solutions ["net_input_arg1" ])
111+ for i1 , i2 in zip (
112+ solutions ["net_input_arg0" ], solutions ["net_input_arg1" ]
113+ )
108114 ]
109115 else :
110116 input_files = solutions ["net_input" ]
@@ -117,13 +123,19 @@ def test_ml_model_import(test):
117123 if test == "053" :
118124 input = tuple (
119125 [
120- h5py .File (test_dir / in_file , "r" )["inputs" ]["input0" ]["data" ][:]
126+ h5py .File (test_dir / in_file , "r" )["inputs" ]["input0" ][
127+ "data"
128+ ][:]
121129 for in_file in input_file
122130 ]
123131 )
124132 else :
125- input = h5py .File (test_dir / input_file , "r" )["inputs" ]["input0" ]["data" ][:]
126- output = h5py .File (test_dir / output_file , "r" )["outputs" ]["output0" ]["data" ][:]
133+ input = h5py .File (test_dir / input_file , "r" )["inputs" ]["input0" ][
134+ "data"
135+ ][:]
136+ output = h5py .File (test_dir / output_file , "r" )["outputs" ]["output0" ][
137+ "data"
138+ ][:]
127139
128140 if "net_ps" in solutions :
129141 par = h5py .File (test_dir / par_file , "r" )
@@ -134,10 +146,14 @@ def test_ml_model_import(test):
134146 and hasattr (net .layers [layer ], "weight" )
135147 and net .layers [layer ].weight is not None
136148 ):
137- w = par ["parameters" ][ml_model .nn_model_id ][layer ]["weight" ][:]
149+ w = par ["parameters" ][ml_model .nn_model_id ][layer ][
150+ "weight"
151+ ][:]
138152 if isinstance (net .layers [layer ], eqx .nn .ConvTranspose ):
139153 # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
140- w = np .flip (w , axis = tuple (range (2 , w .ndim ))).swapaxes (0 , 1 )
154+ w = np .flip (w , axis = tuple (range (2 , w .ndim ))).swapaxes (
155+ 0 , 1
156+ )
141157 assert w .shape == net .layers [layer ].weight .shape
142158 net = eqx .tree_at (
143159 lambda x : x .layers [layer ].weight ,
@@ -149,7 +165,9 @@ def test_ml_model_import(test):
149165 and hasattr (net .layers [layer ], "bias" )
150166 and net .layers [layer ].bias is not None
151167 ):
152- b = par ["parameters" ][ml_model .nn_model_id ][layer ]["bias" ][:]
168+ b = par ["parameters" ][ml_model .nn_model_id ][layer ]["bias" ][
169+ :
170+ ]
153171 if isinstance (
154172 net .layers [layer ],
155173 eqx .nn .Conv | eqx .nn .ConvTranspose ,
@@ -182,7 +200,9 @@ def test_ml_model_import(test):
182200 )
183201
184202
185- @pytest .mark .parametrize ("test" , sorted ([d .stem for d in ude_cases_dir .glob ("[0-9]*" )]))
203+ @pytest .mark .parametrize (
204+ "test" , sorted ([d .stem for d in ude_cases_dir .glob ("[0-9]*" )])
205+ )
186206def test_sciml_problem_import (test ):
187207 test_dir = ude_cases_dir / test
188208
@@ -193,7 +213,9 @@ def test_sciml_problem_import(test):
193213
194214 with change_directory (test_dir / "petab" ):
195215 # HACK!! Again!! Around "array" in parameters table
196- petab_problem = _v2_sciml_problem_helper (petab_yaml , test_dir / "petab" )
216+ petab_problem = _v2_sciml_problem_helper (
217+ petab_yaml , test_dir / "petab"
218+ )
197219
198220 if test in ("003" ,):
199221 with pytest .raises (NotImplementedError ):
@@ -290,9 +312,13 @@ def test_sciml_problem_import(test):
290312 )
291313 else :
292314 expected = h5py .File (test_dir / file , "r" )
293- for layer_name , layer in jax_problem .model .nns [component ].layers .items ():
315+ for layer_name , layer in jax_problem .model .nns [
316+ component
317+ ].layers .items ():
294318 for attribute in dir (layer ):
295- if not isinstance (getattr (layer , attribute ), jax .numpy .ndarray ):
319+ if not isinstance (
320+ getattr (layer , attribute ), jax .numpy .ndarray
321+ ):
296322 continue
297323 actual = getattr (
298324 sllh .model .nns [component ].layers [layer_name ], attribute
@@ -307,7 +333,9 @@ def test_sciml_problem_import(test):
307333 )
308334 if (
309335 np .squeeze (
310- expected ["parameters" ][component ][layer_name ][attribute ][:]
336+ expected ["parameters" ][component ][layer_name ][
337+ attribute
338+ ][:]
311339 ).size
312340 == 0
313341 ):
@@ -333,7 +361,9 @@ def _v2_sciml_problem_helper(yaml_config, base_path):
333361 df = pd .read_csv (f , sep = "\t " )
334362 df .nominalValue = df .nominalValue .apply (_try_float )
335363 if "priorParameters" in df .columns :
336- df .priorParameters = df .priorParameters .apply (_process_prior_params )
364+ df .priorParameters = df .priorParameters .apply (
365+ _process_prior_params
366+ )
337367 parameters = [
338368 v2 .Parameter .model_construct (** row .to_dict ())
339369 for _ , row in df .reset_index ().iterrows ()
@@ -351,19 +381,28 @@ def _v2_sciml_problem_helper(yaml_config, base_path):
351381 ]
352382
353383 measurement_tables = (
354- [v2 .MeasurementTable .from_tsv (f , base_path ) for f in config .measurement_files ]
384+ [
385+ v2 .MeasurementTable .from_tsv (f , base_path )
386+ for f in config .measurement_files
387+ ]
355388 if config .measurement_files
356389 else None
357390 )
358391
359392 experiment_tables = (
360- [v2 .ExperimentTable .from_tsv (f , base_path ) for f in config .experiment_files ]
393+ [
394+ v2 .ExperimentTable .from_tsv (f , base_path )
395+ for f in config .experiment_files
396+ ]
361397 if config .experiment_files
362398 else None
363399 )
364400
365401 condition_tables = (
366- [v2 .ConditionTable .from_tsv (f , base_path ) for f in config .condition_files ]
402+ [
403+ v2 .ConditionTable .from_tsv (f , base_path )
404+ for f in config .condition_files
405+ ]
367406 if config .condition_files
368407 else None
369408 )
@@ -382,7 +421,10 @@ def _v2_sciml_problem_helper(yaml_config, base_path):
382421 ]
383422
384423 observable_tables = (
385- [v2 .ObservableTable .from_tsv (f , base_path ) for f in config .observable_files ]
424+ [
425+ v2 .ObservableTable .from_tsv (f , base_path )
426+ for f in config .observable_files
427+ ]
386428 if config .observable_files
387429 else None
388430 )
@@ -414,7 +456,9 @@ def _process_prior_params(prior_params):
414456
415457def _normal_logpdf (x : jnp .ndarray , mean : float , std : float ) -> jnp .ndarray :
416458 var = std ** 2
417- return jnp .sum (- 0.5 * jnp .log (2.0 * jnp .pi * var ) - 0.5 * ((x - mean ) ** 2 ) / var )
459+ return jnp .sum (
460+ - 0.5 * jnp .log (2.0 * jnp .pi * var ) - 0.5 * ((x - mean ) ** 2 ) / var
461+ )
418462
419463
420464def _uniform_logpdf (x : jnp .ndarray , low : float , high : float ) -> jnp .ndarray :
@@ -449,7 +493,9 @@ def _tree_array_loguniformprior(tree, low: float, high: float) -> jnp.ndarray:
449493 return total
450494
451495
452- def _model_logprior (model , layer1_bias_std = 1.0 , layer1_weight_std = 1.0 ) -> jnp .ndarray :
496+ def _model_logprior (
497+ model , layer1_bias_std = 1.0 , layer1_weight_std = 1.0
498+ ) -> jnp .ndarray :
453499 mech = model .parameters
454500 layer1_bias = model .model .nns ["net1" ].layers ["layer1" ].bias
455501 layer1_weight = model .model .nns ["net1" ].layers ["layer1" ].weight
@@ -460,6 +506,8 @@ def _model_logprior(model, layer1_bias_std=1.0, layer1_weight_std=1.0) -> jnp.nd
460506 return (
461507 _tree_array_loguniformprior (mech , low = 0.0 , high = 15.0 )
462508 + _tree_array_lognormprior (layer1_bias , mean = 0.0 , std = layer1_bias_std )
463- + _tree_array_lognormprior (layer1_weight , mean = 0.0 , std = layer1_weight_std )
509+ + _tree_array_lognormprior (
510+ layer1_weight , mean = 0.0 , std = layer1_weight_std
511+ )
464512 + _tree_array_lognormprior (rest , mean = 0.0 , std = 1.0 )
465513 )
0 commit comments