@@ -44,6 +44,8 @@ def change_directory(destination):
4444# pip install git+https://github.com/sebapersson/petab_sciml@add_standard#egg=petab_sciml\&subdirectory=src/python
4545
4646cases_dir = Path (__file__ ).parent / "testsuite" / "test_cases"
47+ net_cases_dir = cases_dir / "net_import"
48+ ude_cases_dir = cases_dir / "hybrid"
4749
4850
4951def _reshape_flat_array (array_flat ):
@@ -65,10 +67,10 @@ def _reshape_flat_array(array_flat):
6567
6668
6769@pytest .mark .parametrize (
68- "test" , sorted ([d .stem for d in cases_dir .glob ("net_ [0-9]*" )])
70+ "test" , sorted ([d .stem for d in net_cases_dir .glob ("[0-9]*" )])
6971)
7072def test_net (test ):
71- test_dir = cases_dir / test
73+ test_dir = net_cases_dir / test
7274 with open (test_dir / "solutions.yaml" ) as f :
7375 solutions = safe_load (f )
7476
@@ -83,20 +85,20 @@ def test_net(test):
8385 for ml_model in ml_models .models :
8486 module_dir = outdir / f"{ ml_model .mlmodel_id } .py"
8587 if test in (
86- "net_002 " ,
87- "net_009 " ,
88- "net_018 " ,
89- "net_019 " ,
90- "net_020 " ,
91- "net_021 " ,
92- "net_022 " ,
93- "net_042 " ,
94- "net_043 " ,
95- "net_044 " ,
96- "net_045 " ,
97- "net_046 " ,
98- "net_047 " ,
99- "net_048 " ,
88+ "002 " ,
89+ "009 " ,
90+ "018 " ,
91+ "019 " ,
92+ "020 " ,
93+ "021 " ,
94+ "022 " ,
95+ "042 " ,
96+ "043 " ,
97+ "044 " ,
98+ "045 " ,
99+ "046 " ,
100+ "047 " ,
101+ "048 " ,
100102 ):
101103 with pytest .raises (NotImplementedError ):
102104 generate_equinox (ml_model , module_dir )
@@ -111,34 +113,20 @@ def test_net(test):
111113 solutions .get ("net_ps" , solutions ["net_input" ]),
112114 solutions ["net_output" ],
113115 ):
114- input_flat = pd .read_csv (test_dir / input_file , sep = "\t " )
115- input = _reshape_flat_array (input_flat )
116-
117- output_flat = pd .read_csv (test_dir / output_file , sep = "\t " )
118- output = _reshape_flat_array (output_flat )
116+ input = h5py .File (test_dir / input_file , "r" )["input" ][:]
117+ output = h5py .File (test_dir / output_file , "r" )["output" ][:]
119118
120119 if "net_ps" in solutions :
121- par = pd . read_csv (test_dir / par_file , sep = " \t " )
120+ par = h5py . File (test_dir / par_file , "r " )
122121 for ml_model in ml_models .models :
123122 net = nets [ml_model .mlmodel_id ](jr .PRNGKey (0 ))
124123 for layer in net .layers .keys ():
125- layer_prefix = f"net_{ layer } "
126124 if (
127125 isinstance (net .layers [layer ], eqx .Module )
128126 and hasattr (net .layers [layer ], "weight" )
129127 and net .layers [layer ].weight is not None
130128 ):
131- prefix = layer_prefix + "_weight"
132- df = par [
133- par [petab .PARAMETER_ID ].str .startswith (prefix )
134- ]
135- df ["ix" ] = (
136- df [petab .PARAMETER_ID ]
137- .str .split ("_" )
138- .str [3 :]
139- .apply (lambda x : ";" .join (x ))
140- )
141- w = _reshape_flat_array (df )
129+ w = par [layer ]["weight" ][:]
142130 if isinstance (net .layers [layer ], eqx .nn .ConvTranspose ):
143131 # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
144132 w = np .flip (
@@ -155,17 +143,7 @@ def test_net(test):
155143 and hasattr (net .layers [layer ], "bias" )
156144 and net .layers [layer ].bias is not None
157145 ):
158- prefix = layer_prefix + "_bias"
159- df = par [
160- par [petab .PARAMETER_ID ].str .startswith (prefix )
161- ]
162- df ["ix" ] = (
163- df [petab .PARAMETER_ID ]
164- .str .split ("_" )
165- .str [3 :]
166- .apply (lambda x : ";" .join (x ))
167- )
168- b = _reshape_flat_array (df )
146+ b = par [layer ]["bias" ][:]
169147 if isinstance (
170148 net .layers [layer ],
171149 eqx .nn .Conv | eqx .nn .ConvTranspose ,
@@ -199,75 +177,69 @@ def test_net(test):
199177
200178
201179@pytest .mark .parametrize (
202- "test" , sorted ([d .stem for d in cases_dir .glob ("[0-9]*" )])
180+ "test" , sorted ([d .stem for d in ude_cases_dir .glob ("[0-9]*" )])
203181)
204182def test_ude (test ):
205- test_dir = cases_dir / test
206- with open (test_dir / "petab" / "problem_ude .yaml" ) as f :
183+ test_dir = ude_cases_dir / test
184+ with open (test_dir / "petab" / "problem .yaml" ) as f :
207185 petab_yaml = safe_load (f )
208186 with open (test_dir / "solutions.yaml" ) as f :
209187 solutions = safe_load (f )
210188
211189 with change_directory (test_dir / "petab" ):
212190 from petab .v2 import Problem
213191
214- petab_yaml ["format_version" ] = "2.0.0"
215- for problem in petab_yaml ["problems" ]:
216- problem ["model_files" ] = {
217- problem ["model_files" ]["location" ].split ("." )[0 ]: problem [
218- "model_files"
219- ]
220- }
221- for mapping_file in problem ["mapping_files" ]:
222- df = pd .read_csv (
223- mapping_file ,
224- sep = "\t " ,
225- )
226- if df [petab .PETAB_ENTITY_ID ].str .startswith ("net" ).any ():
227- df .rename (
228- columns = {
229- petab .PETAB_ENTITY_ID : petab .MODEL_ENTITY_ID ,
230- petab .MODEL_ENTITY_ID : petab .PETAB_ENTITY_ID ,
231- }
232- ).to_csv (mapping_file , sep = "\t " , index = False )
233-
192+ petab_yaml ["format_version" ] = "2.0.0" # TODO: fixme
234193 petab_problem = Problem .from_yaml (petab_yaml )
235194 jax_model = import_petab_problem (
236195 petab_problem ,
237196 model_output_dir = Path (__file__ ).parent / "models" / test ,
238197 compile_ = True ,
239198 jax = True ,
240199 )
241- jax_problem = JAXProblem ( jax_model , petab_problem )
242- for net , net_config in petab_problem .extensions_config [
243- "petab_sciml"
244- ]. items ():
245- pars = h5py . File (
246- net_config [ "parameters" ]. replace ( ".h5" , ".hf5" ), "r"
247- )
248- for layer_name , layer in jax_problem . model . nns [ net ]. layers . items () :
249- for attribute in dir ( layer ):
250- if not isinstance (
251- getattr ( layer , attribute ), jax . numpy . ndarray
252- ):
253- continue
254- value = jnp . array ( pars [ layer_name ][ attribute ] )
200+ # non_numeric = pd.to_numeric(petab_problem.parameter_df[petab.NOMINAL_VALUE], errors='coerce').isna( )
201+ # par_files = petab_problem.parameter_df.loc[non_numeric, petab.NOMINAL_VALUE].unique()
202+ # par_values = {
203+ # par_file: h5py.File(par_file, "r")
204+ # for par_file in par_files
205+ # }
206+ # for par_id, row in petab_problem.parameter_df.iterrows():
207+ # if not non_numeric[par_id] :
208+ # continue
209+ # petab_problem.parameter_df.loc[par_id, petab.NOMINAL_VALUE] = \
210+ # (par_values[row[petab.NOMINAL_VALUE]],)
211+ # petab_problem.parameter_df.loc[np.logical_not(non_numeric), petab.NOMINAL_VALUE] = pd.to_numeric(
212+ # petab_problem.parameter_df.loc[np.logical_not(non_numeric), petab.NOMINAL_VALUE]
213+ # )
255214
256- if (
257- isinstance (layer , eqx .nn .ConvTranspose )
258- and attribute == "weight"
259- ):
260- # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
261- value = jnp .flip (
262- value , axis = tuple (range (2 , value .ndim ))
263- ).swapaxes (0 , 1 )
264- jax_problem = eqx .tree_at (
265- lambda x : getattr (
266- x .model .nns [net ].layers [layer_name ], attribute
267- ),
268- jax_problem ,
269- value ,
270- )
215+ jax_problem = JAXProblem (jax_model , petab_problem )
216+ # for net, net_config in petab_problem.extensions_config.items(): # TODO: FIXME (https://github.com/sebapersson/petab_sciml_testsuite/issues/1)
217+ # pars = h5py.File(
218+ # net_config["net1_ps_file"]['location'], "r" # TODO: check format and actually use propoer petab nominal parameter infrastructure
219+ # )
220+ # for layer_name, layer in jax_problem.model.nns[net].layers.items():
221+ # for attribute in dir(layer):
222+ # if not isinstance(
223+ # getattr(layer, attribute), jax.numpy.ndarray
224+ # ):
225+ # continue
226+ # value = jnp.array(pars[layer_name][attribute])
227+ #
228+ # if (
229+ # isinstance(layer, eqx.nn.ConvTranspose)
230+ # and attribute == "weight"
231+ # ):
232+ # # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
233+ # value = jnp.flip(
234+ # value, axis=tuple(range(2, value.ndim))
235+ # ).swapaxes(0, 1)
236+ # jax_problem = eqx.tree_at(
237+ # lambda x: getattr(
238+ # x.model.nns[net].layers[layer_name], attribute
239+ # ),
240+ # jax_problem,
241+ # value,
242+ # )
271243
272244 # llh
273245 if test in (
0 commit comments