Skip to content

Commit 01837b6

Browse files
committed
wip
1 parent 79e78ad commit 01837b6

4 files changed

Lines changed: 120 additions & 117 deletions

File tree

autotest/test_dfn.py

Lines changed: 53 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from modflow_devtools.dfn import _load_common, load, load_all, load_tree
5+
from modflow_devtools.dfn import _load_common, load, load_all
66
from modflow_devtools.dfn.fetch import fetch_dfns
77
from modflow_devtools.dfn2toml import convert
88
from modflow_devtools.markers import requires_pkg
@@ -29,14 +29,11 @@ def pytest_generate_tests(metafunc):
2929

3030
if "toml_name" in metafunc.fixturenames:
3131
convert(DFN_DIR, TOML_DIR)
32-
dfn_paths = list(DFN_DIR.glob("*.dfn"))
3332
expected_toml_paths = [
34-
TOML_DIR / f"{dfn.stem.replace('-nam', '')}.toml"
35-
for dfn in dfn_paths
36-
if "common" not in dfn.stem
33+
dfn for dfn in DFN_DIR.glob("*.dfn") if "common" not in dfn.stem
3734
]
38-
toml_names = [toml.stem for toml in TOML_DIR.glob("*.toml")]
3935
assert all(toml_path.exists() for toml_path in expected_toml_paths)
36+
toml_names = [toml.stem for toml in TOML_DIR.glob("*.toml")]
4037
metafunc.parametrize("toml_name", toml_names, ids=toml_names)
4138

4239

@@ -47,7 +44,7 @@ def test_load_v1(dfn_name):
4744
(DFN_DIR / f"{dfn_name}.dfn").open() as dfn_file,
4845
):
4946
common, _ = _load_common(common_file)
50-
dfn = load(dfn_file, name=dfn_name, common=common)
47+
dfn = load(dfn_file, name=dfn_name, format="dfn", common=common)
5148
assert any(dfn.fields)
5249

5350

@@ -61,65 +58,57 @@ def test_load_v2(toml_name):
6158
@requires_pkg("boltons")
6259
@pytest.mark.parametrize("schema_version", list(SPEC_DIRS.keys()))
6360
def test_load_all(schema_version):
64-
path = SPEC_DIRS[schema_version]
65-
dfns = load_all(path)
61+
dfns = load_all(path=SPEC_DIRS[schema_version])
6662
assert all(any(dfn.fields) for dfn in dfns.values())
6763

6864

69-
@requires_pkg("boltons")
70-
def test_load_tree():
71-
import tempfile
72-
65+
@requires_pkg("boltons", "tomli")
66+
def test_convert(function_tmpdir):
7367
import tomli
7468

75-
with tempfile.TemporaryDirectory() as tmp_dir:
76-
tmp_path = Path(tmp_dir)
77-
convert(DFN_DIR, tmp_path)
78-
79-
# Test file conversion and naming
80-
assert (tmp_path / "sim.toml").exists()
81-
assert (tmp_path / "gwf.toml").exists()
82-
assert not (tmp_path / "sim-nam.toml").exists()
83-
84-
# Test parent relationships in files
85-
with (tmp_path / "sim.toml").open("rb") as f:
86-
sim_data = tomli.load(f)
87-
assert sim_data["name"] == "sim"
88-
assert "parent" not in sim_data
89-
90-
with (tmp_path / "gwf.toml").open("rb") as f:
91-
gwf_data = tomli.load(f)
92-
assert gwf_data["name"] == "gwf"
93-
assert gwf_data["parent"] == "sim"
94-
95-
dfns = load_all(tmp_path)
96-
root = load_tree(tmp_path)
97-
roots = []
98-
for dfn in dfns.values():
99-
if dfn.parent:
100-
assert dfn.parent in dfns
101-
else:
102-
roots.append(dfn.name)
103-
assert len(roots) == 1
104-
assert root.name == "sim"
105-
assert root == roots[0]
106-
107-
model_types = ["gwf", "gwt", "gwe"]
108-
models = root.children or {}
109-
for model_type in model_types:
110-
if model_type in models:
111-
assert models[model_type].name == model_type
112-
assert models[model_type].parent == "sim"
113-
114-
if "gwf" in models:
115-
pkgs = models["gwf"].children or {}
116-
gwf_packages = [
117-
k for k in pkgs if k.startswith("gwf-") and isinstance(pkgs[k], dict)
118-
]
119-
assert len(gwf_packages) > 0
120-
121-
if dis := pkgs.get("gwf-dis", None):
122-
assert dis.name == "gwf-dis"
123-
assert dis.parent == "gwf"
124-
assert "options" in (dis.blocks or {})
125-
assert "dimensions" in (dis.blocks or {})
69+
convert(DFN_DIR, function_tmpdir)
70+
71+
assert (function_tmpdir / "sim-nam.toml").exists()
72+
assert (function_tmpdir / "gwf-nam.toml").exists()
73+
74+
with (function_tmpdir / "sim-nam.toml").open("rb") as f:
75+
sim_data = tomli.load(f)
76+
assert sim_data["name"] == "sim-nam"
77+
assert sim_data["schema_version"] == "2"
78+
assert "parent" not in sim_data
79+
80+
with (function_tmpdir / "gwf-nam.toml").open("rb") as f:
81+
gwf_data = tomli.load(f)
82+
assert gwf_data["name"] == "gwf-nam"
83+
assert gwf_data["parent"] == "sim-nam"
84+
assert gwf_data["schema_version"] == "2"
85+
86+
dfns = load_all(function_tmpdir)
87+
roots = []
88+
for dfn in dfns.values():
89+
if dfn.parent:
90+
assert dfn.parent in dfns
91+
else:
92+
roots.append(dfn.name)
93+
assert len(roots) == 1
94+
root = dfns[roots[0]]
95+
assert root.name == "sim-nam"
96+
97+
models = root.children or {}
98+
for mdl in models:
99+
assert models[mdl].name == mdl
100+
assert models[mdl].parent == "sim-nam"
101+
102+
if gwf := models.get("gwf-nam", None):
103+
pkgs = gwf.children or {}
104+
pkgs = {
105+
k: v
106+
for k, v in pkgs.items()
107+
if k.startswith("gwf-") and isinstance(v, dict)
108+
}
109+
assert len(pkgs) > 0
110+
if dis := pkgs.get("gwf-dis", None):
111+
assert dis.name == "gwf-dis"
112+
assert dis.parent == "gwf"
113+
assert "options" in (dis.blocks or {})
114+
assert "dimensions" in (dis.blocks or {})

modflow_devtools/dfn/__init__.py

Lines changed: 62 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from modflow_devtools.dfn.schema.ref import Ref
3030
from modflow_devtools.dfn.schema.v1 import FieldV1
3131
from modflow_devtools.dfn.schema.v2 import FieldV2
32-
from modflow_devtools.misc import try_literal_eval
32+
from modflow_devtools.misc import drop_none_or_empty, try_literal_eval
3333

3434
__all__ = [
3535
"SCALAR_TYPES",
@@ -41,6 +41,7 @@
4141
"FieldV1",
4242
"Fields",
4343
"Ref",
44+
"flatten",
4445
"infer_tree",
4546
"load",
4647
"load_all",
@@ -62,8 +63,8 @@ class Dfn:
6263
MODFLOW 6 input component definition.
6364
"""
6465

65-
name: str
6666
schema_version: Version
67+
name: str
6768
parent: str | None = None
6869
advanced: bool = False
6970
multi: bool = False
@@ -313,13 +314,15 @@ def remove_attrs(path, key, value):
313314
def map(self, dfn: Dfn) -> Dfn:
314315
if dfn.schema_version == (v2 := Version("2")):
315316
return dfn
317+
316318
return Dfn(
317319
name=dfn.name,
318320
advanced=dfn.advanced,
319321
multi=dfn.multi,
320322
ref=dfn.ref,
321323
blocks=MapV1To2.map_blocks(dfn),
322324
schema_version=v2,
325+
parent=dfn.parent,
323326
)
324327

325328

@@ -396,63 +399,73 @@ def infer_tree(dfns: Dfns) -> Dfn:
396399
Assumes all DFNs are of the same schema version.
397400
"""
398401

399-
def drop_none_or_empty(path, key, value):
400-
if value is None or value == "" or value == [] or value == {}:
401-
return False
402-
return True
403-
404-
def add_parent(dfn):
405-
dfn = dfn.copy()
406-
dfn_name = dfn["name"]
407-
if dfn_name == "sim-nam":
408-
# simulation name file
409-
dfn.name = "sim"
402+
def set_parent(dfn):
403+
dfn = asdict(dfn)
404+
if (dfn_name := dfn["name"]) == "sim-nam":
405+
pass
410406
elif dfn_name.endswith("-nam"):
411-
# model name files
412-
model_type = dfn_name[:-4] # Remove "-nam"
413-
dfn.name = model_type
414-
dfn.parent = "sim"
407+
dfn["parent"] = "sim-nam"
415408
elif (
416409
dfn_name.startswith("exg-")
417410
or dfn_name.startswith("sln-")
418411
or dfn_name.startswith("utl-")
419412
):
420-
# exchanges, solutions, standalone utilities
421-
dfn.parent = "sim"
413+
dfn["parent"] = "sim-nam"
422414
elif "-" in dfn_name:
423-
# packages
424-
model_type = dfn_name.split("-")[0]
425-
dfn.parent = model_type
426-
427-
return remap(dfn, visit=drop_none_or_empty)
428-
429-
dfns = {name: add_parent(dfn) for name, dfn in dfns.items()}
430-
first = next(iter(dfns.values()), None)
431-
schema_version = first.schema_version if first else Version("1")
432-
433-
if schema_version == Version("1"):
434-
raise NotImplementedError("Structure inference from v1 schema not implemented")
435-
elif schema_version == Version("2"):
436-
if (
437-
len(
438-
roots := [
439-
(name, dfn) for name, dfn in dfns.items() if dfn.parent is None
440-
]
415+
mdl = dfn_name.split("-")[0]
416+
dfn["parent"] = f"{mdl}-nam"
417+
418+
return Dfn(**remap(dfn, visit=drop_none_or_empty))
419+
420+
dfns = {name: set_parent(dfn) for name, dfn in dfns.items()}
421+
first_dfn = next(iter(dfns.values()), None)
422+
match schema_version := str(
423+
first_dfn.schema_version if first_dfn else Version("1")
424+
):
425+
case "1":
426+
raise NotImplementedError("Tree inference from v1 schema not implemented")
427+
case "2":
428+
if (
429+
nroots := len(
430+
roots := [
431+
(name, dfn) for name, dfn in dfns.items() if dfn.parent is None
432+
]
433+
)
434+
) != 1:
435+
raise ValueError(f"Expected one root component, found {nroots}")
436+
437+
def get_children(node_name: str) -> Dfns:
438+
return {
439+
name: dfn for name, dfn in dfns.items() if dfn.parent == node_name
440+
}
441+
442+
return Dfn(
443+
name=(root_name := roots[0][0]),
444+
blocks=dfns[root_name].blocks,
445+
children=get_children(root_name),
446+
schema_version=Version(schema_version),
447+
)
448+
case _:
449+
raise ValueError(
450+
f"Unsupported schema version: {schema_version}. Expected 1 or 2."
441451
)
442-
!= 1
443-
):
444-
raise ValueError(f"Expected one root component, found {len(roots)}")
445452

446-
def get_children(node_name: str) -> Dfns:
447-
return {name: dfn for name, dfn in dfns.items() if dfn.parent == node_name}
448453

449-
return Dfn(
450-
name=(root_name := roots[0][0]),
451-
blocks=dfns[root_name].blocks,
452-
children=get_children(root_name),
453-
schema_version=schema_version,
454-
)
455-
raise ValueError(f"Unsupported schema version: {schema_version}. Expected 1 or 2.")
454+
def flatten(dfn: Dfn) -> Dfns:
455+
"""
456+
Flatten a MODFLOW 6 input component hierarchy to a flat spec:
457+
unlinked DFNs, i.e. without `children` populated, only `parent`.
458+
459+
Returns a dictionary of all components in the tree.
460+
"""
461+
462+
def _flatten(dfn: Dfn) -> Dfns:
463+
dfns = {dfn.name: replace(dfn, children=None)}
464+
for child in (dfn.children or {}).values():
465+
dfns.update(_flatten(child))
466+
return dfns
467+
468+
return _flatten(dfn)
456469

457470

458471
def load_tree(dfndir: str | PathLike) -> Dfn:

modflow_devtools/dfn2toml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import tomli_w as tomli
99
from boltons.iterutils import remap
1010

11-
from modflow_devtools.dfn import load_all, map
11+
from modflow_devtools.dfn import flatten, load_tree, map
1212
from modflow_devtools.misc import drop_none_or_empty
1313

1414
# mypy: ignore-errors
@@ -18,7 +18,7 @@ def convert(indir: PathLike, outdir: PathLike, schema_version: str = "2") -> Non
1818
indir = Path(indir).expanduser().absolute()
1919
outdir = Path(outdir).expanduser().absolute()
2020
outdir.mkdir(exist_ok=True, parents=True)
21-
dfns_in = load_all(indir)
21+
dfns_in = flatten(load_tree(indir))
2222
dfns = {
2323
name: map(dfn, schema_version=schema_version) for name, dfn in dfns_in.items()
2424
}

modflow_devtools/misc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import traceback
55
from _warnings import warn
66
from ast import literal_eval
7+
from collections.abc import Iterable
78
from contextlib import contextmanager
89
from enum import Enum
910
from functools import wraps
@@ -585,9 +586,9 @@ def try_literal_eval(value: str) -> Any:
585586

586587
def drop_none_or_empty(path, key, value):
587588
"""
588-
Drop dictionary items with None or empty string values.
589+
Drop dictionary items with None or empty values.
589590
For use with `boltons.iterutils.remap`.
590591
"""
591-
if value is None or value == "":
592+
if value is None or (isinstance(value, Iterable) and not any(value)):
592593
return False
593594
return True

0 commit comments

Comments
 (0)