Skip to content

Commit 8d8097a

Browse files
authored
feat: optional replacement semantics for stress period data (#2664)
Add an optional parameter replace=False to the .set_data() methods on the stress period data model. Setting this True resets the data before setting provided keys, i.e. clears all preexisting entries rather than leaving them in place. Since stress period configuration is not stored solely in the underlying data model but also in package block headers, a trap is necessary in MFBlock.write() to sync block headers to the keys existing in the period data, to avoid writing empty/spurious period blocks. Suggested in #2663 (comment)
1 parent cd8ba83 commit 8d8097a

5 files changed

Lines changed: 369 additions & 12 deletions

File tree

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
"""
2+
Test set_data() replace parameter (issue #2663). This parameter
3+
toggles whether .set_data() has update or replacement semantics.
4+
"""
5+
6+
from pathlib import Path
7+
8+
import numpy as np
9+
import pytest
10+
11+
import flopy
12+
13+
pytestmark = pytest.mark.mf6
14+
15+
16+
def count_stress_periods(file_path):
17+
"""Count the number of 'BEGIN period' statements in an input file."""
18+
with open(file_path, "r") as f:
19+
return sum(1 for line in f if line.strip().upper().startswith("BEGIN PERIOD"))
20+
21+
22+
@pytest.mark.parametrize("replace", [False, True], ids=["replace", "no_replace"])
23+
@pytest.mark.parametrize("use_pandas", [False, True], ids=["use_pandas", "no_pandas"])
24+
def test_set_data_replace_array_based_pkg(function_tmpdir, replace, use_pandas):
25+
name = "array_based"
26+
og_ws = Path(function_tmpdir) / "original"
27+
og_ws.mkdir(exist_ok=True)
28+
29+
nlay, nrow, ncol = 1, 10, 10
30+
nper_original = 48
31+
nper_new = 12
32+
33+
sim = flopy.mf6.MFSimulation(
34+
sim_name=name,
35+
sim_ws=str(og_ws),
36+
exe_name="mf6",
37+
use_pandas=use_pandas,
38+
)
39+
tdis = flopy.mf6.ModflowTdis(
40+
sim,
41+
nper=nper_original,
42+
perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)],
43+
)
44+
ims = flopy.mf6.ModflowIms(sim)
45+
gwf = flopy.mf6.ModflowGwf(sim, modelname=name)
46+
dis = flopy.mf6.ModflowGwfdis(
47+
gwf,
48+
nlay=nlay,
49+
nrow=nrow,
50+
ncol=ncol,
51+
delr=100.0,
52+
delc=100.0,
53+
top=100.0,
54+
botm=0.0,
55+
)
56+
ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0)
57+
npf = flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0)
58+
oc = flopy.mf6.ModflowGwfoc(
59+
gwf,
60+
budget_filerecord=f"{name}.cbc",
61+
head_filerecord=f"{name}.hds",
62+
saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")],
63+
)
64+
rch_data = {kper: 0.001 + kper * 0.0001 for kper in range(nper_original)}
65+
rcha = flopy.mf6.ModflowGwfrcha(gwf, recharge=rch_data)
66+
67+
sim.write_simulation()
68+
69+
original_rch_file = og_ws / f"{name}.rcha"
70+
original_sp_count = count_stress_periods(original_rch_file)
71+
assert original_sp_count == nper_original
72+
73+
# Update RCH
74+
new_rch_data = {kper: 0.002 + kper * 0.0002 for kper in range(nper_new)}
75+
rcha.recharge.set_data(new_rch_data, replace=replace)
76+
77+
# Update TDIS
78+
tdis.nper = nper_new
79+
tdis.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)]
80+
81+
mod_ws = Path(function_tmpdir) / f"modified_replace_{replace}"
82+
mod_ws.mkdir(exist_ok=True)
83+
sim.set_sim_path(str(mod_ws))
84+
sim.write_simulation()
85+
86+
modified_rch_file = mod_ws / f"{name}.rcha"
87+
modified_sp_count = count_stress_periods(modified_rch_file)
88+
89+
if replace:
90+
# With replace=True, should only have 12 stress periods
91+
assert modified_sp_count == nper_new, (
92+
f"Expected {nper_new} stress periods "
93+
f"with replace=True, got {modified_sp_count}"
94+
)
95+
else:
96+
# With replace=False (backwards compatible), all 48 periods remain
97+
assert modified_sp_count == nper_original, (
98+
f"Expected {nper_original} stress periods "
99+
f"with replace=False, got {modified_sp_count}"
100+
)
101+
102+
with open(modified_rch_file, "r") as f:
103+
content = f.read()
104+
assert "0.00200000" in content or "2.00000000E-03" in content
105+
assert "0.00420000" in content or "4.20000000E-03" in content
106+
107+
108+
@pytest.mark.parametrize("replace", [False, True], ids=["replace", "no_replace"])
109+
@pytest.mark.parametrize("use_pandas", [False, True], ids=["use_pandas", "no_pandas"])
110+
def test_set_data_replace_list_based_pkg(function_tmpdir, replace, use_pandas):
111+
name = "list_based"
112+
sim_ws = Path(function_tmpdir) / "wel_original"
113+
sim_ws.mkdir(exist_ok=True)
114+
115+
nlay, nrow, ncol = 1, 10, 10
116+
nper_original = 24
117+
nper_new = 6
118+
119+
sim = flopy.mf6.MFSimulation(
120+
sim_name=name, sim_ws=str(sim_ws), exe_name="mf6", use_pandas=use_pandas
121+
)
122+
tdis = flopy.mf6.ModflowTdis(
123+
sim,
124+
nper=nper_original,
125+
perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)],
126+
)
127+
ims = flopy.mf6.ModflowIms(sim)
128+
gwf = flopy.mf6.ModflowGwf(sim, modelname=name)
129+
dis = flopy.mf6.ModflowGwfdis(
130+
gwf,
131+
nlay=nlay,
132+
nrow=nrow,
133+
ncol=ncol,
134+
delr=100.0,
135+
delc=100.0,
136+
top=100.0,
137+
botm=0.0,
138+
)
139+
ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0)
140+
npf = flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0)
141+
oc = flopy.mf6.ModflowGwfoc(
142+
gwf,
143+
budget_filerecord=f"{name}.cbc",
144+
head_filerecord=f"{name}.hds",
145+
saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")],
146+
)
147+
wel_data = {
148+
kper: [[(0, 5, 5), -1000.0 - kper * 10.0]] for kper in range(nper_original)
149+
}
150+
wel = flopy.mf6.ModflowGwfwel(gwf, stress_period_data=wel_data)
151+
152+
sim.write_simulation()
153+
154+
original_wel_file = sim_ws / f"{name}.wel"
155+
original_sp_count = count_stress_periods(original_wel_file)
156+
assert original_sp_count == nper_original
157+
158+
# Update WEL
159+
new_wel_data = {
160+
kper: [[(0, 5, 5), -2000.0 - kper * 20.0]] for kper in range(nper_new)
161+
}
162+
wel.stress_period_data.set_data(new_wel_data, replace=replace)
163+
164+
# Update TDIS
165+
tdis.nper = nper_new
166+
tdis.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)]
167+
168+
mod_ws = Path(function_tmpdir) / f"wel_modified_replace_{replace}"
169+
mod_ws.mkdir(exist_ok=True)
170+
sim.set_sim_path(str(mod_ws))
171+
sim.write_simulation()
172+
173+
modified_wel_file = mod_ws / f"{name}.wel"
174+
modified_sp_count = count_stress_periods(modified_wel_file)
175+
176+
if replace:
177+
# With replace=True, should only have 6 stress periods
178+
assert modified_sp_count == nper_new, (
179+
f"Expected {nper_new} stress periods with "
180+
f"replace=True, got {modified_sp_count}"
181+
)
182+
else:
183+
# With replace=False, all 24 periods remain
184+
assert modified_sp_count == nper_original, (
185+
f"Expected {nper_original} stress periods with "
186+
f"replace=False, got {modified_sp_count}"
187+
)
188+
189+
190+
def test_set_data_update_array_based_pkg(function_tmpdir):
191+
name = "update_array_based"
192+
sim_ws = Path(function_tmpdir) / "compat"
193+
sim_ws.mkdir(exist_ok=True)
194+
195+
sim = flopy.mf6.MFSimulation(sim_name=name, sim_ws=str(sim_ws), exe_name="mf6")
196+
tdis = flopy.mf6.ModflowTdis(
197+
sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)]
198+
)
199+
ims = flopy.mf6.ModflowIms(sim)
200+
gwf = flopy.mf6.ModflowGwf(sim, modelname=name)
201+
dis = flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10)
202+
ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0)
203+
npf = flopy.mf6.ModflowGwfnpf(gwf, k=10.0)
204+
oc = flopy.mf6.ModflowGwfoc(gwf)
205+
206+
initial_data = dict.fromkeys(range(5), 0.001)
207+
rch = flopy.mf6.ModflowGwfrcha(gwf, recharge=initial_data)
208+
209+
additional_data = dict.fromkeys(range(5, 10), 0.002)
210+
rch.recharge.set_data(additional_data) # replace defaults to False
211+
212+
sim.write_simulation()
213+
214+
sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws))
215+
gwf2 = sim2.get_model(name)
216+
rch2 = gwf2.get_package("RCHA")
217+
218+
for kper in range(10):
219+
data = rch2.recharge.get_data(key=kper)
220+
assert np.allclose(data, 0.001 if kper < 5 else 0.002)
221+
222+
223+
def test_set_data_update_list_based_pkg(function_tmpdir):
224+
name = "update_list_based"
225+
sim_ws = Path(function_tmpdir) / "wel_update"
226+
sim_ws.mkdir(exist_ok=True)
227+
228+
sim = flopy.mf6.MFSimulation(sim_name=name, sim_ws=str(sim_ws), exe_name="mf6")
229+
tdis = flopy.mf6.ModflowTdis(
230+
sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)]
231+
)
232+
ims = flopy.mf6.ModflowIms(sim)
233+
gwf = flopy.mf6.ModflowGwf(sim, modelname=name)
234+
dis = flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10)
235+
ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0)
236+
npf = flopy.mf6.ModflowGwfnpf(gwf, k=10.0)
237+
oc = flopy.mf6.ModflowGwfoc(gwf)
238+
239+
initial_data = {kper: [[(0, 5, 5), -1000.0]] for kper in range(5)}
240+
wel = flopy.mf6.ModflowGwfwel(gwf, stress_period_data=initial_data)
241+
242+
additional_data = {kper: [[(0, 7, 7), -2000.0]] for kper in range(5, 10)}
243+
wel.stress_period_data.set_data(additional_data) # replace defaults to False
244+
245+
sim.write_simulation()
246+
247+
sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws))
248+
gwf2 = sim2.get_model(name)
249+
wel2 = gwf2.get_package("WEL")
250+
251+
for kper in range(10):
252+
data = wel2.stress_period_data.get_data(key=kper)
253+
assert data is not None, f"Period {kper} should have data"
254+
if kper < 5:
255+
# Original data should be at (0, 5, 5)
256+
assert len(data) == 1
257+
assert data[0]["cellid"] == (0, 5, 5)
258+
assert data[0]["q"] == -1000.0
259+
else:
260+
# Additional data should be at (0, 7, 7)
261+
assert len(data) == 1
262+
assert data[0]["cellid"] == (0, 7, 7)
263+
assert data[0]["q"] == -2000.0

flopy/mf6/data/mfdataarray.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,7 +1890,7 @@ def _build_period_data(
18901890
output[sp] = data
18911891
return output
18921892

1893-
def set_record(self, data_record):
1893+
def set_record(self, data_record, replace=False):
18941894
"""Sets data and metadata at layer `layer` and time `key` to
18951895
`data_record`. For unlayered data do not pass in `layer`.
18961896
@@ -1902,10 +1902,15 @@ def set_record(self, data_record):
19021902
and metadata (factor, iprn, filename, binary, data) for a given
19031903
stress period. How to define the dictionary of data and
19041904
metadata is described in the MFData class's set_record method.
1905+
replace : bool
1906+
Perform the operation with replacement semantics: all existing
1907+
stress period keys not present in the new dictionary will be
1908+
removed. If False, existing keys not in the new dictionary
1909+
will be preserved. Defaults False for backwards compatibility.
19051910
"""
1906-
self._set_data_record(data_record, is_record=True)
1911+
self._set_data_record(data_record, is_record=True, replace=replace)
19071912

1908-
def set_data(self, data, multiplier=None, layer=None, key=None):
1913+
def set_data(self, data, multiplier=None, layer=None, key=None, replace=False):
19091914
"""Sets the contents of the data at layer `layer` and time `key` to
19101915
`data` with multiplier `multiplier`. For unlayered data do not pass
19111916
in `layer`.
@@ -1926,15 +1931,30 @@ def set_data(self, data, multiplier=None, layer=None, key=None):
19261931
key : int
19271932
Zero based stress period to assign data too. Does not apply
19281933
if `data` is a dictionary.
1934+
replace : bool
1935+
If True and `data` is a dictionary, perform the operation
1936+
with replacement semantics: all existing stress period keys
1937+
not present in the new dictionary will be removed. If False,
1938+
existing keys not in the new dictionary will be preserved.
1939+
Defaults False for backwards compatibility.
19291940
"""
1930-
self._set_data_record(data, multiplier, layer, key)
1941+
self._set_data_record(data, multiplier, layer, key, replace=replace)
19311942

19321943
def _set_data_record(
1933-
self, data, multiplier=None, layer=None, key=None, is_record=False
1944+
self, data, multiplier=None, layer=None, key=None, is_record=False, replace=False
19341945
):
19351946
if isinstance(data, dict):
19361947
# each item in the dictionary is a list for one stress period
19371948
# the dictionary key is the stress period the list is for
1949+
1950+
# If replacing, remove keys not in the new data
1951+
if replace and self._data_storage:
1952+
keys_to_remove = set(self._data_storage.keys()) - set(data.keys())
1953+
for k in keys_to_remove:
1954+
self.remove_transient_key(k)
1955+
if k in self.empty_keys:
1956+
del self.empty_keys[k]
1957+
19381958
del_keys = []
19391959
for key, list_item in data.items():
19401960
if list_item is None:

flopy/mf6/data/mfdatalist.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,7 +1780,7 @@ def get_data(self, key=None, apply_mult=False, **kwargs):
17801780
else:
17811781
return None
17821782

1783-
def set_record(self, data_record, autofill=False, check_data=True):
1783+
def set_record(self, data_record, autofill=False, check_data=True, replace=False):
17841784
"""Sets the contents of the data based on the contents of
17851785
'data_record`.
17861786
@@ -1795,15 +1795,21 @@ def set_record(self, data_record, autofill=False, check_data=True):
17951795
Automatically correct data
17961796
check_data : bool
17971797
Whether to verify the data
1798+
replace : bool
1799+
Perform the operation with replacement semantics: all existing
1800+
stress period keys not present in the new dictionary will be
1801+
removed. If False, existing keys not in the new dictionary
1802+
will be preserved. Defaults False for backwards compatibility.
17981803
"""
17991804
self._set_data_record(
18001805
data_record,
18011806
autofill=autofill,
18021807
check_data=check_data,
18031808
is_record=True,
1809+
replace=replace,
18041810
)
18051811

1806-
def set_data(self, data, key=None, autofill=False):
1812+
def set_data(self, data, key=None, autofill=False, replace=False):
18071813
"""Sets the contents of the data at time `key` to `data`.
18081814
18091815
Parameters
@@ -1819,17 +1825,32 @@ def set_data(self, data, key=None, autofill=False):
18191825
if `data` is a dictionary.
18201826
autofill : bool
18211827
Automatically correct data.
1828+
replace : bool
1829+
If True and `data` is a dictionary, perform the operation
1830+
with replacement semantics: all existing stress period keys
1831+
not present in the new dictionary will be removed. If False,
1832+
existing keys not in the new dictionary will be preserved.
1833+
Defaults False for backwards compatibility.
18221834
"""
1823-
self._set_data_record(data, key, autofill)
1835+
self._set_data_record(data, key, autofill, replace=replace)
18241836

18251837
def _set_data_record(
1826-
self, data, key=None, autofill=False, check_data=False, is_record=False
1838+
self, data, key=None, autofill=False, check_data=False, is_record=False, replace=False
18271839
):
18281840
self._cache_model_grid = True
18291841
if isinstance(data, dict):
18301842
if "filename" not in data and "data" not in data:
18311843
# each item in the dictionary is a list for one stress period
18321844
# the dictionary key is the stress period the list is for
1845+
1846+
# If replacing, remove keys not in the new data
1847+
if replace and self._data_storage:
1848+
keys_to_remove = set(self._data_storage.keys()) - set(data.keys())
1849+
for k in keys_to_remove:
1850+
self.remove_transient_key(k)
1851+
if k in self.empty_keys:
1852+
del self.empty_keys[k]
1853+
18331854
del_keys = []
18341855
for key, list_item in data.items():
18351856
if list_item is None:

0 commit comments

Comments
 (0)