Skip to content

Commit 1e5fabd

Browse files
authored
fix(fixtures): fix model-loading fixtures and utilities (MODFLOW-ORG#12)
* rename modflow_devtools.misc.get_models to get_model_paths * fix get_model_paths filtering by package * sort paths returned by get_model_paths * refactor get_packages function * expand tests
1 parent 3c63aaa commit 1e5fabd

5 files changed

Lines changed: 125 additions & 61 deletions

File tree

modflow_devtools/fixtures.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Dict, List, Optional
77

88
import pytest
9-
from modflow_devtools.misc import get_mf6_ftypes, get_models
9+
from modflow_devtools.misc import get_model_paths, get_packages
1010

1111
# temporary directory fixtures
1212

@@ -174,7 +174,7 @@ def pytest_generate_tests(metafunc):
174174
key = "test_model_mf6"
175175
if key in metafunc.fixturenames:
176176
models = (
177-
get_models(
177+
get_model_paths(
178178
Path(repos_path) / "modflow6-testmodels" / "mf6",
179179
prefix="test",
180180
excluded=["test205_gwtbuy-henrytidal"],
@@ -189,7 +189,7 @@ def pytest_generate_tests(metafunc):
189189
key = "test_model_mf5to6"
190190
if key in metafunc.fixturenames:
191191
models = (
192-
get_models(
192+
get_model_paths(
193193
Path(repos_path) / "modflow6-testmodels" / "mf5to6",
194194
prefix="test",
195195
namefile="*.nam",
@@ -205,7 +205,7 @@ def pytest_generate_tests(metafunc):
205205
key = "large_test_model"
206206
if key in metafunc.fixturenames:
207207
models = (
208-
get_models(
208+
get_model_paths(
209209
Path(repos_path) / "modflow6-largetestmodels",
210210
prefix="test",
211211
namefile="*.nam",
@@ -292,7 +292,7 @@ def get_examples():
292292
for name, namefiles in examples.items():
293293
ftypes = []
294294
for namefile in namefiles:
295-
ftype = get_mf6_ftypes(namefile, packages_selected)
295+
ftype = get_packages(namefile, packages_selected)
296296
if ftype not in ftypes:
297297
ftypes += ftype
298298
if len(ftypes) > 0:

modflow_devtools/misc.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,17 @@ def get_current_branch() -> str:
101101
raise ValueError(f"Could not determine current branch: {stderr}")
102102

103103

104-
def get_mf6_ftypes(namefile_path: PathLike, ftypekeys: List[str]) -> List[str]:
104+
def get_packages(namefile_path: PathLike) -> List[str]:
105105
"""
106-
Return a list of FTYPES that are in the name file and in ftypekeys.
106+
Return a list of packages used by the model defined in the given namefile.
107107
108108
Parameters
109109
----------
110-
namefile_path : str
111-
path to a MODFLOW 6 name file
112-
ftypekeys : list
113-
list of desired FTYPEs
110+
namefile_path : PathLike
111+
path to MODFLOW 6 name file
114112
Returns
115113
-------
116-
ftypes : list
117-
list of FTYPES that match ftypekeys in namefile
114+
list of package types
118115
"""
119116
with open(namefile_path, "r") as f:
120117
lines = f.readlines()
@@ -126,22 +123,29 @@ def get_mf6_ftypes(namefile_path: PathLike, ftypekeys: List[str]) -> List[str]:
126123
if len(ll) < 2:
127124
continue
128125

129-
if ll[0] in ["#", "!"]:
126+
l = ll[0].lower()
127+
if any(l.startswith(c) for c in ["#", "!", "data", "list"]) or l in [
128+
"begin",
129+
"end",
130+
"memory_print_option",
131+
]:
130132
continue
131133

132-
for key in ftypekeys:
133-
if key.lower() in ll[0].lower():
134-
ftypes.append(ll[0])
134+
# strip "6" from package name
135+
l = l.replace("6", "")
135136

136-
return ftypes
137+
ftypes.append(l.lower())
137138

139+
return list(set(ftypes))
138140

139-
def has_packages(namefile_path: PathLike, packages: List[str]) -> bool:
140-
ftypes = [item.upper() for item in get_mf6_ftypes(namefile_path, packages)]
141-
return len(ftypes) > 0
142141

142+
def has_package(namefile_path: PathLike, package: str) -> bool:
143+
"""Determines whether the model with the given namefile contains the selected package"""
144+
packages = get_packages(namefile_path)
145+
return package.lower in packages
143146

144-
def get_models(
147+
148+
def get_model_paths(
145149
path: PathLike,
146150
prefix: str = None,
147151
namefile: str = "mfsim.nam",
@@ -150,7 +154,12 @@ def get_models(
150154
packages=None,
151155
) -> List[Path]:
152156
"""
153-
Find models in the given filesystem location.
157+
Find models recursively in the given location.
158+
Models can be filtered or excluded by pattern,
159+
filtered by packages used or naming convention
160+
for namefiles, or by parent folder name prefix.
161+
The path to the model folder (i.e., the folder
162+
containing the model's namefile) is returned.
154163
"""
155164

156165
# if path doesn't exist, return empty list
@@ -161,7 +170,7 @@ def get_models(
161170
namfile_paths = [
162171
p
163172
for p in Path(path).rglob(
164-
f"{prefix}*/{namefile}" if prefix else namefile
173+
f"{prefix}*/**/{namefile}" if prefix else namefile
165174
)
166175
]
167176

@@ -172,37 +181,36 @@ def get_models(
172181
if (not excluded or not any(e in str(p) for e in excluded))
173182
]
174183

175-
# filter by package (optional)
184+
# filter by package
176185
if packages:
177-
namfile_paths = [
178-
p
179-
for p in namfile_paths
180-
if (has_packages(p, packages) if packages else True)
181-
]
182-
183-
# get model dir paths
186+
filtered = []
187+
for nfp in namfile_paths:
188+
nf_pkgs = get_packages(nfp)
189+
shared = set(nf_pkgs).intersection(
190+
set([p.lower() for p in packages])
191+
)
192+
if any(shared):
193+
filtered.append(nfp)
194+
namfile_paths = filtered
195+
196+
# get model folder paths
184197
model_paths = [p.parent for p in namfile_paths]
185198

186-
# filter by model name (optional)
199+
# filter by model name
187200
if selected:
188201
model_paths = [
189202
model
190203
for model in model_paths
191204
if any(s in model.name for s in selected)
192205
]
193206

194-
# exclude dev examples on master or release branches
195-
branch = get_current_branch()
196-
if "master" in branch.lower() or "release" in branch.lower():
197-
model_paths = [
198-
model for model in model_paths if "_dev" not in model.name.lower()
199-
]
200-
201-
return model_paths
207+
return sorted(model_paths)
202208

203209

204210
def is_connected(hostname):
205-
"""See https://stackoverflow.com/a/20913928/ to test hostname."""
211+
"""
212+
Tests whether the given URL is accessible.
213+
See https://stackoverflow.com/a/20913928/."""
206214
try:
207215
host = socket.gethostbyname(hostname)
208216
s = socket.create_connection((host, 80), 2)
@@ -214,6 +222,8 @@ def is_connected(hostname):
214222

215223

216224
def is_in_ci():
225+
"""Determines whether the current process is running GitHub Actions CI"""
226+
217227
# if running in GitHub Actions CI, "CI" variable always set to true
218228
# https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables
219229
return bool(environ.get("CI", None))
@@ -222,7 +232,7 @@ def is_in_ci():
222232
def is_github_rate_limited() -> Optional[bool]:
223233
"""
224234
Determines if a GitHub API rate limit is applied to the current IP.
225-
Note that running this function will consume an API request!
235+
Running this function will consume an API request!
226236
227237
Returns
228238
-------

modflow_devtools/test/test_executables.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,12 @@
77
from modflow_devtools.executables import Executables
88
from modflow_devtools.misc import add_sys_path, get_suffixes
99

10-
_bin_path = Path(environ.get("BIN_PATH")).expanduser()
10+
_bin_path = Path(environ.get("BIN_PATH")).expanduser().absolute()
1111
_ext, _ = get_suffixes(sys.platform)
1212

1313

14-
@pytest.fixture
15-
def bin_path(module_tmpdir) -> Path:
16-
return _bin_path.absolute()
17-
18-
1914
@pytest.mark.skipif(not _bin_path.is_dir(), reason="bin directory not found")
20-
def test_get_path(bin_path):
15+
def test_get_path():
2116
with add_sys_path(str(_bin_path)):
2217
ext, _ = get_suffixes(sys.platform)
2318
assert (
@@ -26,22 +21,22 @@ def test_get_path(bin_path):
2621
)
2722

2823

29-
def test_get_version(bin_path):
30-
with add_sys_path(str(bin_path)):
31-
ver_str = Executables.get_version("mf6", path=bin_path).partition(" ")
24+
def test_get_version():
25+
with add_sys_path(str(_bin_path)):
26+
ver_str = Executables.get_version("mf6", path=_bin_path).partition(" ")
3227
print(ver_str)
3328
version = int(ver_str[0].split(".")[0])
3429
assert version >= 6
3530

3631

3732
@pytest.fixture
38-
def exes(bin_path):
39-
return Executables(mf6=bin_path / f"mf6{_ext}")
33+
def exes():
34+
return Executables(mf6=_bin_path / f"mf6{_ext}")
4035

4136

42-
def test_executables_mapping(bin_path, exes):
37+
def test_executables_mapping(exes):
4338
print(exes.mf6)
44-
assert exes.mf6 == bin_path / f"mf6{_ext}"
39+
assert exes.mf6 == _bin_path / f"mf6{_ext}"
4540

4641

4742
def test_executables_usage(exes):

modflow_devtools/test/test_fixtures.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,11 @@ def test_test_model_mf6(test_model_mf6):
254254

255255
def test_test_model_mf5to6(test_model_mf5to6):
256256
assert isinstance(test_model_mf5to6, Path)
257-
assert len(list(test_model_mf5to6.glob("*.nam"))) >= 1
257+
assert any(list(test_model_mf5to6.glob("*.nam")))
258258

259259

260260
def test_large_test_model(large_test_model):
261261
assert isinstance(large_test_model, Path)
262-
assert (large_test_model / "mfsim.nam").is_file()
262+
assert (large_test_model / "mfsim.nam").is_file() or any(
263+
list(large_test_model.glob("*.nam"))
264+
)

modflow_devtools/test/test_misc.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,59 @@
1-
def test_set_dir():
2-
pass
1+
import os
2+
from os import environ
3+
from pathlib import Path
4+
5+
import pytest
6+
from modflow_devtools.misc import get_model_paths, get_packages, set_dir
7+
8+
9+
def test_set_dir(tmp_path):
10+
assert Path(os.getcwd()) != tmp_path
11+
with set_dir(tmp_path):
12+
assert Path(os.getcwd()) == tmp_path
13+
assert Path(os.getcwd()) != tmp_path
14+
15+
16+
_repos_path = Path(environ.get("REPOS_PATH")).expanduser().absolute()
17+
_examples_repo_path = _repos_path / "modflow6-examples"
18+
_examples_path = _examples_repo_path / "examples"
19+
_example_paths = (
20+
sorted(list(_examples_path.glob("ex-*")))
21+
if _examples_path.is_dir()
22+
else []
23+
)
24+
25+
26+
@pytest.mark.skipif(not any(_example_paths), reason="examples not found")
27+
def test_has_packages():
28+
example_path = _example_paths[0]
29+
packages = get_packages(example_path / "mfsim.nam")
30+
assert set(packages) == {"tdis", "gwf", "ims"}
31+
32+
33+
@pytest.mark.skipif(not any(_example_paths), reason="examples not found")
34+
def test_get_model_paths():
35+
paths = get_model_paths(_examples_path)
36+
assert len(paths) == 127
37+
38+
paths = get_model_paths(_examples_path, namefile="*.nam")
39+
assert len(paths) == 339
40+
41+
42+
def test_get_model_paths_exclude_patterns():
43+
paths = get_model_paths(_examples_path, excluded=["gwt"])
44+
assert len(paths) == 63
45+
46+
47+
def test_get_model_paths_select_prefix():
48+
paths = get_model_paths(_examples_path, prefix="ex2")
49+
assert not any(paths)
50+
51+
52+
def test_get_model_paths_select_patterns():
53+
paths = get_model_paths(_examples_path, selected=["gwf"])
54+
assert len(paths) == 70
55+
56+
57+
def test_get_model_paths_select_packages():
58+
paths = get_model_paths(_examples_path, namefile="*.nam", packages=["wel"])
59+
assert len(paths) == 64

0 commit comments

Comments
 (0)