Skip to content

Commit 916f605

Browse files
committed
shape fixes, etc
1 parent 84b7a8b commit 916f605

4 files changed

Lines changed: 223 additions & 37 deletions

File tree

autotest/dfns/test_mapper.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,14 @@ def test_map_recarray_conversion():
179179
dfn = _v1_dfn(
180180
name="test-pkg",
181181
blocks={
182+
"dimensions": {
183+
"maxbound": _v1_field(
184+
name="maxbound",
185+
type="integer",
186+
block="dimensions",
187+
in_record=False,
188+
),
189+
},
182190
"period": {
183191
"stress_period_data": _v1_field(
184192
name="stress_period_data",
@@ -199,13 +207,105 @@ def test_map_recarray_conversion():
199207
block="period",
200208
in_record=True,
201209
),
202-
}
210+
},
203211
},
204212
)
205213
component = map_v2(dfn)
206214
period_fields = component.blocks["period"].fields
207215
spd = period_fields["stress_period_data"]
208216
assert isinstance(spd, List)
217+
assert spd.shape == ["maxbound"]
209218
assert isinstance(spd.item, Record)
210219
assert "cellid" in spd.item.fields
211220
assert "q" in spd.item.fields
221+
222+
223+
def test_map_recarray_missing_shape_inferred_from_maxbound():
224+
"""Period list with empty shape gets shape=["maxbound"] when maxbound dim exists."""
225+
dfn = _v1_dfn(
226+
name="utl-spc",
227+
blocks={
228+
"dimensions": {
229+
"maxbound": _v1_field(
230+
name="maxbound",
231+
type="integer",
232+
block="dimensions",
233+
in_record=False,
234+
),
235+
},
236+
"period": {
237+
"spd": _v1_field(
238+
name="spd",
239+
type="recarray bndno spcsetting",
240+
block="period",
241+
shape="", # empty in v1
242+
),
243+
"bndno": _v1_field(
244+
name="bndno",
245+
type="integer",
246+
block="period",
247+
in_record=True,
248+
),
249+
"spcsetting": _v1_field(
250+
name="spcsetting",
251+
type="keystring concentration",
252+
block="period",
253+
in_record=True,
254+
),
255+
"concentration": _v1_field(
256+
name="concentration",
257+
type="double precision",
258+
block="period",
259+
tagged=True,
260+
in_record=True,
261+
),
262+
},
263+
},
264+
)
265+
component = map_v2(dfn)
266+
period_fields = component.blocks["period"].fields
267+
spd = period_fields["spd"]
268+
assert isinstance(spd, List)
269+
assert spd.shape == ["maxbound"]
270+
271+
272+
def test_map_recarray_no_shape_no_maxbound():
273+
"""Period list with no shape and no maxbound dim keeps shape=[]."""
274+
dfn = _v1_dfn(
275+
name="gwf-sfr",
276+
advanced=True,
277+
blocks={
278+
"period": {
279+
"perioddata": _v1_field(
280+
name="perioddata",
281+
type="recarray ifno sfrsetting",
282+
block="period",
283+
shape="",
284+
),
285+
"ifno": _v1_field(
286+
name="ifno",
287+
type="integer",
288+
block="period",
289+
in_record=True,
290+
),
291+
"sfrsetting": _v1_field(
292+
name="sfrsetting",
293+
type="keystring status",
294+
block="period",
295+
in_record=True,
296+
),
297+
"status": _v1_field(
298+
name="status",
299+
type="string",
300+
block="period",
301+
tagged=True,
302+
in_record=True,
303+
),
304+
},
305+
},
306+
)
307+
component = map_v2(dfn)
308+
period_fields = component.blocks["period"].fields
309+
lst = period_fields["perioddata"]
310+
assert isinstance(lst, List)
311+
assert lst.shape == []

autotest/test_dfnmap.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,6 @@ def dfn_dir(module_tmpdir):
3636
return path
3737

3838

39-
@pytest.fixture(scope="module", params=FORMATS)
40-
def converted_v1(request, dfn_dir, module_tmpdir):
41-
fmt = request.param
42-
out = module_tmpdir / f"v1-{fmt}"
43-
migrate(dfn_dir, out, schema_version="1", fmt=fmt)
44-
return out, fmt
45-
46-
4739
@pytest.fixture(scope="module", params=FORMATS)
4840
def converted_v1_1(request, dfn_dir, module_tmpdir):
4941
fmt = request.param
@@ -60,17 +52,6 @@ def converted_v2(request, dfn_dir, module_tmpdir):
6052
return out, fmt
6153

6254

63-
@requires_pkg("boltons")
64-
def test_convert_v1(converted_v1):
65-
out, fmt = converted_v1
66-
files = list(out.glob(f"*.{fmt}"))
67-
assert files
68-
for p in files:
69-
data = _load(p, fmt)
70-
assert data["name"] == p.stem
71-
assert data["schema_version"] == "1"
72-
73-
7455
@requires_pkg("boltons")
7556
def test_convert_v1_1(converted_v1_1):
7657
out, fmt = converted_v1_1

modflow_devtools/dfnmap.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
Format = Literal["yaml", "toml", "json"]
1818

19-
_EXT: dict[str, str] = {"yaml": ".yaml", "toml": ".toml", "json": ".json"}
20-
2119
# YAML 1.1 (PyYAML default) serializes booleans as yes/no; override to true/false (YAML 1.2).
2220
pyaml.add_representer(
2321
bool,
@@ -92,14 +90,14 @@ def migrate(
9290
outdir : str or PathLike
9391
Output directory.
9492
schema_version : str, optional
95-
Target schema version: "1", "1.1", or "2". Default "2".
93+
Target schema version: "1.1" or "2". Default "2".
9694
fmt : str, optional
9795
Output format: "yaml", "toml", or "json". Default "yaml".
9896
"""
9997
inpath = Path(inpath).expanduser().absolute()
10098
outdir = Path(outdir).expanduser().absolute()
10199
outdir.mkdir(exist_ok=True, parents=True)
102-
ext = _EXT[fmt]
100+
ext = f".{fmt}"
103101

104102
if inpath.is_file():
105103
if inpath.name == "common.dfn":
@@ -113,32 +111,24 @@ def migrate(
113111
with inpath.open() as f:
114112
dfn = v1.Dfn.load(f, name=inpath.stem, common=common)
115113

116-
if schema_version == "1":
117-
pass
118-
elif schema_version == "1.1":
114+
if schema_version == "1.1":
119115
dfn = map_v1_1(dfn)
120116
elif schema_version == "2":
121117
dfn = map_v2(dfn)
122118
else:
123-
raise ValueError(
124-
f"Got schema version {schema_version}, supported versions are: 1, 1.1, 2"
125-
)
119+
raise ValueError(f"Got schema version {schema_version}, supported versions are: 1.1, 2")
126120

127121
_write(_serialize_safe(dfn), outdir / f"{inpath.stem}{ext}", fmt)
128122
else:
129123
dfns = v1.load_all(inpath)
130124

131-
if schema_version == "1":
132-
pass
133-
elif schema_version == "1.1":
125+
if schema_version == "1.1":
134126
dfns = v1.to_flat(v1.to_tree(dfns))
135127
dfns = {name: map_v1_1(dfn) for name, dfn in dfns.items()}
136128
elif schema_version == "2":
137129
dfns = {name: map_v2(dfn) for name, dfn in dfns.items()}
138130
else:
139-
raise ValueError(
140-
f"Got schema version {schema_version}, supported versions are: 1, 1.1, 2"
141-
)
131+
raise ValueError(f"Got schema version {schema_version}, supported versions are: 1.1, 2")
142132

143133
for dfn_name, dfn in dfns.items():
144134
_write(_serialize_safe(dfn), outdir / f"{dfn_name}{ext}", fmt)
@@ -164,7 +154,7 @@ def migrate(
164154
"--schema-version",
165155
"-s",
166156
default="2",
167-
choices=["1", "1.1", "2"],
157+
choices=["1.1", "2"],
168158
help="Target schema version (default: 2).",
169159
)
170160
parser.add_argument(

modflow_devtools/dfns/mapper.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,59 @@ def _scope_for(
3131
return "component"
3232

3333

34+
def _raw_dim_names(blocks: dict[str, v2.Block]) -> set[str]:
35+
"""Names of all Integer fields in the dimensions block."""
36+
dim_block = blocks.get("dimensions")
37+
if not dim_block:
38+
return set()
39+
return {fname for fname, f in dim_block.fields.items() if isinstance(f, v2.Integer)}
40+
41+
42+
def _parse_list_shape(s: str) -> list[str]:
43+
"""
44+
Parse a v1 recarray shape string into a ``List.shape`` value.
45+
46+
Only a bare identifier is accepted — complex expressions such as
47+
``sum(nlakeconn)`` cannot be represented in ``List.shape`` and are dropped.
48+
"""
49+
if not s:
50+
return []
51+
s_clean = s.strip()
52+
if s_clean.startswith("(") and s_clean.endswith(")"):
53+
s_clean = s_clean[1:-1].strip()
54+
if _IDENT_RE.fullmatch(s_clean):
55+
return [s_clean]
56+
return []
57+
58+
59+
def _normalize_n_prefix_shapes(
60+
blocks: dict[str, v2.Block],
61+
raw_dim_names: set[str],
62+
) -> dict[str, v2.Block]:
63+
"""
64+
Fix List shapes that use ``nFoo`` where the actual dimension is ``maxFoo``.
65+
66+
Some v1 DFNs (e.g. ``gwf-mvr [packages]`` with ``shape (npackages)``) use
67+
an ``n``-prefixed name while the dimensions block defines the same quantity
68+
under a ``max``-prefixed name. Normalise before building explicit dims.
69+
"""
70+
result = {}
71+
for bname, block in blocks.items():
72+
new_fields = {}
73+
changed = False
74+
for fname, field in block.fields.items():
75+
if isinstance(field, v2.List) and field.shape:
76+
elem = field.shape[0]
77+
if elem not in raw_dim_names and elem.startswith("n") and len(elem) > 1:
78+
candidate = "max" + elem[1:]
79+
if candidate in raw_dim_names:
80+
field = field.model_copy(update={"shape": [candidate]})
81+
changed = True
82+
new_fields[fname] = field
83+
result[bname] = block.model_copy(update={"fields": new_fields}) if changed else block
84+
return result
85+
86+
3487
def _build_explicit_dims(
3588
parent: "str | list[str] | None",
3689
blocks: dict[str, v2.Block],
@@ -65,6 +118,33 @@ def _build_explicit_dims(
65118
return dims
66119

67120

121+
def _sanitize_list_shapes(
122+
blocks: dict[str, v2.Block],
123+
known_dims: set[str],
124+
) -> dict[str, v2.Block]:
125+
"""
126+
Clear the shape of any List whose shape element doesn't resolve to a known
127+
dim.
128+
129+
Advanced packages (LAK, SFR, GNC, transport packages, etc.) often carry
130+
``shape (maxbound)`` in their v1 DFNs as a convention even though
131+
``maxbound`` is not declared as a dimension. The structurally correct v2
132+
representation for such lists is ``shape=[]``.
133+
"""
134+
result = {}
135+
for bname, block in blocks.items():
136+
new_fields = {}
137+
changed = False
138+
for fname, field in block.fields.items():
139+
if isinstance(field, v2.List) and field.shape:
140+
if any(elem not in known_dims for elem in field.shape):
141+
field = field.model_copy(update={"shape": []})
142+
changed = True
143+
new_fields[fname] = field
144+
result[bname] = block.model_copy(update={"fields": new_fields}) if changed else block
145+
return result
146+
147+
68148
def _resolve_dimensions(
69149
blocks: dict[str, v2.Block],
70150
) -> tuple[dict[str, v2.Block], dict[str, v2.Dim]]:
@@ -176,6 +256,34 @@ def _resolve_record(record: v2.Record) -> v2.Record:
176256
}
177257

178258

259+
def _fill_period_list_shapes(
260+
blocks: dict[str, v2.Block],
261+
explicit_dims: dict[str, v2.Dim],
262+
) -> dict[str, v2.Block]:
263+
"""
264+
For period blocks whose List field has no shape expression, infer the shape
265+
from the component's explicit dims. Currently handles ``maxbound`` only:
266+
if the component defines a ``maxbound`` dimension but the period list omits
267+
it, add ``shape=["maxbound"]``.
268+
"""
269+
if "maxbound" not in explicit_dims:
270+
return blocks
271+
result = {}
272+
for bname, block in blocks.items():
273+
if "period" not in bname:
274+
result[bname] = block
275+
continue
276+
new_fields = {}
277+
changed = False
278+
for fname, field in block.fields.items():
279+
if isinstance(field, v2.List) and not field.shape:
280+
field = field.model_copy(update={"shape": ["maxbound"]})
281+
changed = True
282+
new_fields[fname] = field
283+
result[bname] = block.model_copy(update={"fields": new_fields}) if changed else block
284+
return result
285+
286+
179287
def map(dfn: v1.Dfn) -> v2.Component:
180288
"""Map a component definition from the v1 schema to v2."""
181289

@@ -401,6 +509,7 @@ def _record_fields() -> dict:
401509

402510
if _type.startswith("recarray"):
403511
item = _row_field()
512+
list_shape = _parse_list_shape(shape_str) if shape_str else []
404513
return v2.List(
405514
name=_name,
406515
longname=longname,
@@ -410,6 +519,7 @@ def _record_fields() -> dict:
410519
developmode=developmode,
411520
netcdf=netcdf,
412521
item=item,
522+
shape=list_shape,
413523
)
414524

415525
if _type.startswith("keystring"):
@@ -572,7 +682,12 @@ def _record_fields() -> dict:
572682

573683
blocks, array_dims = _resolve_dimensions(blocks)
574684
blocks = _resolve_relations(blocks)
685+
raw_dim_names = _raw_dim_names(blocks)
686+
blocks = _normalize_n_prefix_shapes(blocks, raw_dim_names)
575687
explicit_dims = _build_explicit_dims(dfn["parent"], blocks)
688+
known_dims = set(explicit_dims) | set(array_dims)
689+
blocks = _sanitize_list_shapes(blocks, known_dims)
690+
blocks = _fill_period_list_shapes(blocks, explicit_dims)
576691
dims = {**explicit_dims, **array_dims} or None
577692

578693
d: dict[str, Any] = {

0 commit comments

Comments
 (0)