11import modflow_devtools .models as models
2+ import pandas as pd
23import pytest
34
45from modflowapi import run_simulation
1213 "ex-prt-mp7-p04" ,
1314]
1415
16+ temporary_skip_state_count = set (
17+ [
18+ # Skip state count tests for these model for now.
19+ # There seems to be a different problem.
20+ "ex-gwf-fhb" ,
21+ "ex-gwf-nwt-p02a" ,
22+ "ex-gwf-nwt-p02b" ,
23+ "ex-gwe-geotherm" ,
24+ "ex-gwe-radial-slow" ,
25+ ]
26+ )
27+
28+
29+ class StateCollector :
30+ """
31+ Callback that collects the model state sequence.
32+
33+ THe attribute `states` is a list of tuples `(model_name, state_name)`,
34+ where `state_name` is `callback_step.name`.
35+ """
36+
37+ def __init__ (self ):
38+ self .states = []
39+
40+ def __call__ (self , sim , callback_step ):
41+ self .states .append ((sim .model_names [0 ], callback_step .name ))
42+
43+
44+ def count_states (states ):
45+ """
46+ Count states per model.
47+
48+ The numbers of `start_state` and `end_state` should match.
49+ """
50+ states = pd .DataFrame (states , columns = ["model" , "state" ])
51+ counts = states .groupby (["model" , "state" ])[["state" ]].agg ("count" )
52+ counts .columns = ["counts" ]
53+ return counts
54+
55+
56+ def compare_counts (counts ):
57+ matches = [
58+ ("iteration_start" , "iteration_end" ),
59+ ("stress_period_start" , "stress_period_end" ),
60+ ("timestep_start" , "timestep_end" ),
61+ ]
62+ for model_name in counts .index .levels [0 ]:
63+ model = counts .loc [model_name ]
64+ for state1 , state2 in matches :
65+ count1 , count2 = model .loc [state1 ].iloc [0 ], model .loc [state2 ].iloc [0 ]
66+ assert count1 == count2 , (
67+ f"{ state1 } : { int (count1 )} " ,
68+ f"{ state2 } : { int (count2 )} " ,
69+ )
70+
1571
1672@pytest .mark .parametrize ("example_name" , examples .keys ())
1773def test_example (function_tmpdir , example_name ):
@@ -21,4 +77,8 @@ def test_example(function_tmpdir, example_name):
2177 model_relpath = model_name .rpartition (example_name + "/" )[- 1 ]
2278 model_workspace = function_tmpdir / model_relpath
2379 models .copy_to (model_workspace , model_name , verbose = True )
24- run_simulation (dll , model_workspace , lambda sim , step : None , verbose = True )
80+ callback = StateCollector ()
81+ run_simulation (dll , model_workspace , callback = callback , verbose = True )
82+ if model_name in temporary_skip_state_count :
83+ counts = count_states (callback .states )
84+ compare_counts (counts )
0 commit comments