Skip to content

Commit 5679349

Browse files
authored
Merge pull request #73 from hydrocomputing/stress_period_end
Fix number of stress period end states for multi-model simulations
2 parents 7aa43ca + 17b5b46 commit 5679349

3 files changed

Lines changed: 133 additions & 4 deletions

File tree

autotest/test_mf6_examples.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import modflow_devtools.models as models
2+
import pandas as pd
23
import pytest
34

45
from modflowapi import run_simulation
@@ -12,6 +13,61 @@
1213
"ex-prt-mp7-p04",
1314
]
1415

16+
temporary_skip_state_count = set(
17+
[
18+
# Skip state count tests for these model for now.
19+
# There seems to be a different problem.
20+
"ex-gwf-fhb",
21+
"ex-gwf-nwt-p02a",
22+
"ex-gwf-nwt-p02b",
23+
"ex-gwe-geotherm",
24+
"ex-gwe-radial-slow",
25+
]
26+
)
27+
28+
29+
class StateCollector:
30+
"""
31+
Callback that collects the model state sequence.
32+
33+
THe attribute `states` is a list of tuples `(model_name, state_name)`,
34+
where `state_name` is `callback_step.name`.
35+
"""
36+
37+
def __init__(self):
38+
self.states = []
39+
40+
def __call__(self, sim, callback_step):
41+
self.states.append((sim.model_names[0], callback_step.name))
42+
43+
44+
def count_states(states):
45+
"""
46+
Count states per model.
47+
48+
The numbers of `start_state` and `end_state` should match.
49+
"""
50+
states = pd.DataFrame(states, columns=["model", "state"])
51+
counts = states.groupby(["model", "state"])[["state"]].agg("count")
52+
counts.columns = ["counts"]
53+
return counts
54+
55+
56+
def compare_counts(counts):
57+
matches = [
58+
("iteration_start", "iteration_end"),
59+
("stress_period_start", "stress_period_end"),
60+
("timestep_start", "timestep_end"),
61+
]
62+
for model_name in counts.index.levels[0]:
63+
model = counts.loc[model_name]
64+
for state1, state2 in matches:
65+
count1, count2 = model.loc[state1].iloc[0], model.loc[state2].iloc[0]
66+
assert count1 == count2, (
67+
f"{state1}: {int(count1)}",
68+
f"{state2}: {int(count2)}",
69+
)
70+
1571

1672
@pytest.mark.parametrize("example_name", examples.keys())
1773
def test_example(function_tmpdir, example_name):
@@ -21,4 +77,8 @@ def test_example(function_tmpdir, example_name):
2177
model_relpath = model_name.rpartition(example_name + "/")[-1]
2278
model_workspace = function_tmpdir / model_relpath
2379
models.copy_to(model_workspace, model_name, verbose=True)
24-
run_simulation(dll, model_workspace, lambda sim, step: None, verbose=True)
80+
callback = StateCollector()
81+
run_simulation(dll, model_workspace, callback=callback, verbose=True)
82+
if model_name in temporary_skip_state_count:
83+
counts = count_states(callback.states)
84+
compare_counts(counts)

manualtest/collect_states.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Test helper for sequence of model states."""
2+
3+
from collections import defaultdict
4+
5+
import pandas as pd
6+
7+
import modflowapi
8+
9+
10+
class StateCollector:
11+
"""
12+
Callback that collects the model state sequence.
13+
14+
THe attribute `states` is a list of tuples `(model_name, state_name)`,
15+
where `state_name` is `callback_step.name`.
16+
"""
17+
18+
def __init__(self):
19+
self.states = []
20+
21+
def __call__(self, sim, callback_step):
22+
self.states.append((sim.model_names[0], callback_step.name))
23+
24+
25+
def run(dll, sim_path):
26+
"""Run a model and collect its sat sequence."""
27+
callback = StateCollector()
28+
modflowapi.run_simulation(dll=dll, sim_path=sim_path, callback=callback)
29+
return callback.states
30+
31+
32+
def visualize_states(states, limit=None):
33+
"""
34+
Visualize the state sequence by model.
35+
36+
This prints one tree-like sequence of state names per model.
37+
Different state groups have different indentations to visualize the
38+
flow.
39+
"""
40+
indents = {
41+
"initialize": 0,
42+
"finalize": 0,
43+
"stress_period_start": 4,
44+
"stress_period_end": 4,
45+
"timestep_start": 8,
46+
"timestep_end": 8,
47+
"iteration_start": 12,
48+
"iteration_end": 12,
49+
}
50+
models = defaultdict(list)
51+
for model_name, state in states:
52+
models[model_name].append(state)
53+
for model_name, model_states in models.items():
54+
print(model_name)
55+
for count, state in enumerate(model_states):
56+
if limit and limit == count:
57+
break
58+
print(" " * indents[state], state)
59+
60+
61+
def count_states(states):
62+
"""
63+
Count states per model.
64+
65+
The numbers of `start_state` and `end_state` should match.
66+
"""
67+
states = pd.DataFrame(states, columns=["model", "state"])
68+
counts = states.groupby(["model", "state"])[["state"]].agg("count")
69+
counts.columns = ["counts"]
70+
return counts

modflowapi/extensions/runner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,15 @@ def run_simulation(dll, sim_path, callback, verbose=False, _develop=False):
108108

109109
callback(sim_grp, Callbacks.timestep_end)
110110
mf6.finalize_solve(sol_id)
111+
if sim_grp.nstp == sim_grp.kstp + 1:
112+
callback(sim_grp, Callbacks.stress_period_end)
111113

112114
mf6.finalize_time_step()
113115
current_time = mf6.get_current_time()
114116

115117
if not has_converged:
116118
print(f"Simulation group: {sim_grp} DID NOT CONVERGE")
117119

118-
if sim_grp.nstp == sim_grp.kstp + 1:
119-
callback(sim_grp, Callbacks.stress_period_end)
120-
121120
try:
122121
callback(sim, Callbacks.finalize)
123122
mf6.finalize()

0 commit comments

Comments
 (0)