Skip to content

Commit f6cc9bb

Browse files
authored
fix(models): exclude output files from registries (#274)
Exclude some known output file extensions during model registry creation. Also fix the locking mechanism around the cache.
1 parent 5dce2af commit f6cc9bb

2 files changed

Lines changed: 95 additions & 37 deletions

File tree

autotest/test_models.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -272,7 +272,7 @@ class TestSync:
272272

273273
def test_sync_single_source_single_ref(self):
274274
"""Test syncing a single source/ref."""
275-
_DEFAULT_CACHE.clear()
275+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
276276

277277
source = ModelSourceRepo(
278278
repo=TEST_REPO,
@@ -287,7 +287,7 @@ def test_sync_single_source_single_ref(self):
287287

288288
def test_sync_creates_cache(self):
289289
"""Test that sync creates cached registry."""
290-
_DEFAULT_CACHE.clear()
290+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
291291
assert not _DEFAULT_CACHE.has(TEST_SOURCE_NAME, TEST_REF)
292292

293293
source = ModelSourceRepo(
@@ -301,7 +301,7 @@ def test_sync_creates_cache(self):
301301

302302
def test_sync_skip_cached(self):
303303
"""Test that sync skips already-cached registries."""
304-
_DEFAULT_CACHE.clear()
304+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
305305

306306
source = ModelSourceRepo(
307307
repo=TEST_REPO,
@@ -320,7 +320,7 @@ def test_sync_skip_cached(self):
320320

321321
def test_sync_force(self):
322322
"""Test that force flag re-syncs cached registries."""
323-
_DEFAULT_CACHE.clear()
323+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
324324

325325
source = ModelSourceRepo(
326326
repo=TEST_REPO,
@@ -339,7 +339,7 @@ def test_sync_force(self):
339339

340340
def test_sync_via_source_method(self):
341341
"""Test syncing via ModelSourceRepo.sync() method."""
342-
_DEFAULT_CACHE.clear()
342+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
343343

344344
# Create source with test repo override
345345
source = ModelSourceRepo(
@@ -356,7 +356,7 @@ def test_sync_via_source_method(self):
356356

357357
def test_source_is_synced_method(self):
358358
"""Test ModelSourceRepo.is_synced() method."""
359-
_DEFAULT_CACHE.clear()
359+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
360360

361361
source = ModelSourceRepo(
362362
repo=TEST_REPO,
@@ -375,7 +375,7 @@ def test_source_is_synced_method(self):
375375

376376
def test_source_list_synced_refs_method(self):
377377
"""Test ModelSourceRepo.list_synced_refs() method."""
378-
_DEFAULT_CACHE.clear()
378+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
379379

380380
source = ModelSourceRepo(
381381
repo=TEST_REPO,
@@ -400,7 +400,7 @@ class TestRegistry:
400400
@pytest.fixture(scope="class")
401401
def synced_registry(self):
402402
"""Fixture that syncs and loads a registry once for all tests."""
403-
_DEFAULT_CACHE.clear()
403+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
404404
source = ModelSourceRepo(
405405
repo=TEST_REPO,
406406
name=TEST_SOURCE_NAME,
@@ -474,7 +474,7 @@ def test_cli_list_empty(self, capsys):
474474

475475
def test_cli_list_with_cache(self, capsys):
476476
"""Test 'list' command with cached registries."""
477-
_DEFAULT_CACHE.clear()
477+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
478478
source = ModelSourceRepo(
479479
repo=TEST_REPO,
480480
name=TEST_SOURCE_NAME,
@@ -506,7 +506,7 @@ class TestIntegration:
506506
def test_full_workflow(self):
507507
"""Test complete workflow: discover -> cache -> load."""
508508
# Clear cache
509-
_DEFAULT_CACHE.clear()
509+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
510510

511511
# Create test source
512512
source = ModelSourceRepo(
@@ -530,7 +530,7 @@ def test_full_workflow(self):
530530

531531
def test_sync_and_list_models(self):
532532
"""Test syncing and listing available models."""
533-
_DEFAULT_CACHE.clear()
533+
_DEFAULT_CACHE.clear(source=TEST_SOURCE_NAME, ref=TEST_REF)
534534

535535
# Sync
536536
source = ModelSourceRepo(

modflow_devtools/models/__init__.py

Lines changed: 84 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,51 @@
4141
_DEFAULT_REGISTRY_FILE_NAME = "registry.toml"
4242
"""The default registry file name"""
4343

44+
_EXCLUDED_PATTERNS = [".DS_Store", "compare"]
45+
"""Filename patterns to exclude from registry (substring match)"""
46+
47+
_OUTPUT_FILE_EXTENSIONS = [
48+
".lst", # list file
49+
".hds", # head file
50+
".hed", # head file
51+
".cbb", # budget file
52+
".cbc", # budget file
53+
".bud", # budget file
54+
".ddn", # drawdown file
55+
".ucn", # concentration file
56+
".obs", # observation file
57+
".glo", # global listing file
58+
]
59+
"""Output file extensions to exclude from model input registry"""
60+
61+
62+
def _should_exclude_file(path: Path) -> bool:
63+
"""
64+
Check if a file should be excluded from the registry.
65+
66+
Excludes files matching patterns in _EXCLUDED_PATTERNS (substring match)
67+
or with extensions in _OUTPUT_FILE_EXTENSIONS.
68+
69+
Parameters
70+
----------
71+
path : Path
72+
File path to check
73+
74+
Returns
75+
-------
76+
bool
77+
True if file should be excluded, False otherwise
78+
"""
79+
# Check filename patterns (substring match)
80+
if any(pattern in path.name for pattern in _EXCLUDED_PATTERNS):
81+
return True
82+
83+
# Check output file extensions (exact suffix match)
84+
if path.suffix.lower() in _OUTPUT_FILE_EXTENSIONS:
85+
return True
86+
87+
return False
88+
4489

4590
class ModelInputFile(BaseModel):
4691
"""
@@ -208,16 +253,25 @@ def save(self, registry: ModelRegistry, source: str, ref: str) -> Path:
208253
Path to cached registry file
209254
"""
210255
cache_dir = self.get_registry_cache_dir(source, ref)
211-
cache_dir.mkdir(parents=True, exist_ok=True)
212-
213256
registry_file = cache_dir / _DEFAULT_REGISTRY_FILE_NAME
214257

215-
# Convert registry to dict and clean None/empty values before serializing to TOML
216-
registry_dict = registry.model_dump(mode="json", by_alias=True, exclude_none=True)
217-
registry_dict = remap(registry_dict, visit=drop_none_or_empty)
258+
# Use a global lock to prevent race conditions with parallel tests/clear()
259+
lock_file = self.root / ".cache_operation.lock"
260+
lock_file.parent.mkdir(parents=True, exist_ok=True)
218261

219-
with registry_file.open("wb") as f:
220-
tomli_w.dump(registry_dict, f)
262+
with FileLock(str(lock_file), timeout=30):
263+
cache_dir.mkdir(parents=True, exist_ok=True)
264+
265+
# Convert registry to dict and clean None/empty values before serializing to TOML
266+
registry_dict = registry.model_dump(mode="json", by_alias=True, exclude_none=True)
267+
268+
# Use remap to recursively filter out None and empty values
269+
# This is essential for TOML serialization which cannot handle None
270+
registry_dict = remap(registry_dict, visit=drop_none_or_empty)
271+
272+
# Write to file
273+
with registry_file.open("wb") as f:
274+
tomli_w.dump(registry_dict, f)
221275

222276
return registry_file
223277

@@ -300,21 +354,26 @@ def _rmtree_with_retry(path, max_retries=5, delay=0.5):
300354
else:
301355
raise
302356

303-
if source and ref:
304-
# Clear specific source/ref
305-
cache_dir = self.get_registry_cache_dir(source, ref)
306-
if cache_dir.exists():
307-
_rmtree_with_retry(cache_dir)
308-
elif source:
309-
# Clear all refs for a source
310-
source_dir = self.root / "registries" / source
311-
if source_dir.exists():
312-
_rmtree_with_retry(source_dir)
313-
else:
314-
# Clear all registries
315-
registries_dir = self.root / "registries"
316-
if registries_dir.exists():
317-
_rmtree_with_retry(registries_dir)
357+
# Use a global lock to prevent race conditions with parallel tests/save()
358+
lock_file = self.root / ".cache_operation.lock"
359+
lock_file.parent.mkdir(parents=True, exist_ok=True)
360+
361+
with FileLock(str(lock_file), timeout=30):
362+
if source and ref:
363+
# Clear specific source/ref
364+
cache_dir = self.get_registry_cache_dir(source, ref)
365+
if cache_dir.exists():
366+
_rmtree_with_retry(cache_dir)
367+
elif source:
368+
# Clear all refs for a source
369+
source_dir = self.root / "registries" / source
370+
if source_dir.exists():
371+
_rmtree_with_retry(source_dir)
372+
else:
373+
# Clear all registries
374+
registries_dir = self.root / "registries"
375+
if registries_dir.exists():
376+
_rmtree_with_retry(registries_dir)
318377

319378
def list(self) -> list[tuple[str, str]]:
320379
"""
@@ -803,7 +862,7 @@ class LocalRegistry(ModelRegistry):
803862
presence of a namefile) and registers corresponding input files.
804863
"""
805864

806-
exclude: ClassVar = [".DS_Store", "compare"]
865+
exclude: ClassVar = _EXCLUDED_PATTERNS # For backwards compatibility
807866

808867
# Non-Pydantic instance variable for tracking indexed paths
809868
_paths: set[Path]
@@ -870,7 +929,7 @@ def index(
870929
self.examples[name] = []
871930
self.examples[name].append(model_name)
872931
for p in model_path.rglob("*"):
873-
if not p.is_file() or any(e in p.name for e in LocalRegistry.exclude):
932+
if not p.is_file() or _should_exclude_file(p):
874933
continue
875934
relpath = p.expanduser().absolute().relative_to(path)
876935
name = "/".join(relpath.parts)
@@ -1133,7 +1192,6 @@ def index(
11331192
files: dict[str, dict[str, str | None]] = {}
11341193
models: dict[str, list[str]] = {}
11351194
examples: dict[str, list[str]] = {}
1136-
exclude = [".DS_Store", "compare"]
11371195
is_zip = url.endswith((".zip", ".tar")) if url else False
11381196

11391197
model_paths = get_model_paths(path, namefile=namefile)
@@ -1149,7 +1207,7 @@ def index(
11491207
examples[name] = []
11501208
examples[name].append(model_name)
11511209
for p in model_path.rglob("*"):
1152-
if not p.is_file() or any(e in p.name for e in exclude):
1210+
if not p.is_file() or _should_exclude_file(p):
11531211
continue
11541212
relpath = p.expanduser().resolve().absolute().relative_to(path)
11551213
name = "/".join(relpath.parts)

0 commit comments

Comments
 (0)