Skip to content

Commit dc06561

Browse files
committed
ruff
1 parent b4c6fc0 commit dc06561

4 files changed

Lines changed: 105 additions & 134 deletions

File tree

autotest/test_dfns.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111
Array,
1212
Double,
1313
FieldBase,
14-
FieldV2,
1514
Integer,
1615
Keyword,
1716
Record,
1817
String,
19-
Union,
2018
)
2119
from modflow_devtools.markers import requires_pkg
2220

autotest/test_dfns_schema.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,15 @@
77
from modflow_devtools.dfns.schema.v2 import (
88
Array,
99
Block,
10-
ComponentBase,
1110
DfnSpec,
1211
Double,
1312
Integer,
14-
Keyword,
1513
List,
1614
Model,
1715
Package,
1816
Record,
1917
Simulation,
2018
String,
21-
Union,
2219
_collect_explicit_dims,
2320
_known_dims_for,
2421
_names_in_expr,
@@ -28,7 +25,6 @@
2825
_validate_sum_call,
2926
)
3027

31-
3228
# ── Helpers ───────────────────────────────────────────────────────────────────
3329

3430

@@ -41,9 +37,7 @@ def _dim_block(*names: str) -> Block:
4137

4238

4339
def _pkg(name: str, blocks=None, derived_dims=None, parent=None, **kw) -> Package:
44-
return Package(
45-
name=name, blocks=blocks, derived_dims=derived_dims, parent=parent, **kw
46-
)
40+
return Package(name=name, blocks=blocks, derived_dims=derived_dims, parent=parent, **kw)
4741

4842

4943
# ── _collect_explicit_dims ────────────────────────────────────────────────────
@@ -417,9 +411,7 @@ def test_dfnspec_children_of():
417411
chd = _pkg("gwf-chd", parent="gwf-nam")
418412
rch = _pkg("gwf-rch", parent="gwf-nam")
419413
sim = Simulation(name="sim-nam", blocks=None)
420-
spec = DfnSpec(
421-
components={"sim-nam": sim, "gwf-nam": gwf, "gwf-chd": chd, "gwf-rch": rch}
422-
)
414+
spec = DfnSpec(components={"sim-nam": sim, "gwf-nam": gwf, "gwf-chd": chd, "gwf-rch": rch})
423415
children = spec.children_of("gwf-nam")
424416
assert set(children) == {"gwf-chd", "gwf-rch"}
425417

@@ -491,8 +483,8 @@ def test_known_dims_includes_grid_dims():
491483
chd = _pkg("gwf-chd", parent="gwf-nam")
492484
spec2 = DfnSpec(components=dict(spec.components) | {"gwf-chd": chd})
493485
known = _known_dims_for(spec2, "gwf-chd")
494-
assert "nodes" in known # GRID_DIM_NAMESPACE
495-
assert "nlay" in known # from gwf-dis (sibling dis package)
486+
assert "nodes" in known # GRID_DIM_NAMESPACE
487+
assert "nlay" in known # from gwf-dis (sibling dis package)
496488

497489

498490
# ── _validate_shape_element: dim reference ────────────────────────────────────
@@ -521,9 +513,7 @@ def test_shape_element_valid_grid_dim():
521513

522514

523515
def test_shape_element_valid_derived_dim():
524-
arr, pkg, known = _make_ctx(
525-
{"nlay", "nrow", "ncol"}, derived={"nodes": "nlay * nrow * ncol"}
526-
)
516+
arr, pkg, known = _make_ctx({"nlay", "nrow", "ncol"}, derived={"nodes": "nlay * nrow * ncol"})
527517
_validate_shape_element("nodes", arr, pkg, None, known)
528518

529519

@@ -585,7 +575,7 @@ def test_shape_element_valid_row_level_lookup():
585575

586576

587577
def test_shape_element_lookup_on_top_level_array_raises():
588-
arr, enc, pkg, known = _lookup_ctx()
578+
arr, _enc, pkg, known = _lookup_ctx()
589579
with pytest.raises(ValueError, match="not inside a record"):
590580
_validate_shape_element("packagedata.nlakeconn(lakeno)", arr, pkg, None, known)
591581

modflow_devtools/dfns/__init__.py

Lines changed: 65 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
cast,
1717
)
1818

19-
_IDENT_RE = re.compile(r"^[A-Za-z_]\w*$")
20-
2119
import tomli
2220
from boltons.dictutils import OMD
2321
from packaging.version import Version
@@ -45,13 +43,17 @@
4543
Integer,
4644
Keyword,
4745
List,
48-
Path as PathField,
4946
Record,
5047
String,
5148
Union,
5249
)
50+
from modflow_devtools.dfns.schema.v2 import (
51+
Path as PathField,
52+
)
5353
from modflow_devtools.misc import try_literal_eval
5454

55+
_IDENT_RE = re.compile(r"^[A-Za-z_]\w*$")
56+
5557
# Experimental API warning
5658
warnings.warn(
5759
"The modflow_devtools.dfns API is experimental and may change or be "
@@ -112,9 +114,7 @@
112114

113115
class _VersionAnnotation:
114116
@classmethod
115-
def __get_pydantic_core_schema__(
116-
cls, source: Any, handler: GetCoreSchemaHandler
117-
) -> Any:
117+
def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> Any:
118118
return core_schema.no_info_plain_validator_function(
119119
lambda v: Version(str(v)) if not isinstance(v, Version) else v,
120120
serialization=core_schema.to_string_ser_schema(),
@@ -209,9 +209,7 @@ def from_dict(cls, d: dict, strict: bool = False) -> "Dfn":
209209
for field_name, field_data in block_data.items():
210210
if isinstance(field_data, dict):
211211
if is_v1:
212-
block_fields[field_name] = FieldV1.from_dict(
213-
field_data, strict=strict
214-
)
212+
block_fields[field_name] = FieldV1.from_dict(field_data, strict=strict)
215213
else:
216214
block_fields[field_name] = FieldBase.from_dict(
217215
field_data, strict=strict
@@ -267,9 +265,7 @@ def map_period_block(dfn: "Dfn", block: dict) -> dict:
267265
list_field: List = fields_list[0]
268266
block.pop(list_field.name)
269267
item = list_field.item
270-
columns: dict = dict(
271-
item.fields if isinstance(item, Record) else item.arms
272-
)
268+
columns: dict = dict(item.fields if isinstance(item, Record) else item.arms)
273269
else:
274270
columns = dict(block)
275271

@@ -290,6 +286,7 @@ def map_period_block(dfn: "Dfn", block: dict) -> dict:
290286
continue
291287

292288
from modflow_devtools.dfns.schema.v2 import GRID_DIM_NAMESPACE
289+
293290
old_dims = list(column.shape) if isinstance(column, Array) else []
294291
new_dims = ["nper"]
295292
if cellid:
@@ -347,19 +344,17 @@ def _map_field(f: FieldV1) -> FieldBase:
347344
time_series: bool = _to_bool(fd.get("time_series"), False)
348345
valid = fd.get("valid")
349346
default = (
350-
try_literal_eval(fd.get("default"))
351-
if _type != "string"
352-
else fd.get("default")
347+
try_literal_eval(fd.get("default")) if _type != "string" else fd.get("default")
353348
)
354349

355-
common = dict(
356-
name=_name,
357-
longname=longname,
358-
description=description,
359-
optional=optional,
360-
default=default,
361-
developmode=developmode,
362-
)
350+
common = {
351+
"name": _name,
352+
"longname": longname,
353+
"description": description,
354+
"optional": optional,
355+
"default": default,
356+
"developmode": developmode,
357+
}
363358

364359
_COL_FK_RE = re.compile(r"^([A-Za-z_]\w*)\(([A-Za-z_]\w*)\)$")
365360

@@ -375,7 +370,9 @@ def _parse_shape(s: str) -> list[str]:
375370
# v1 discretization-conditional (e.g. "ncol*nrow; ncpl")
376371
# → canonical per-layer count; DIS derives ncpl = nrow*ncol.
377372
result.append("ncpl")
378-
elif elem in ("any1d", "unknown") or elem.startswith("<") or elem.startswith(">"):
373+
elif (
374+
elem in ("any1d", "unknown") or elem.startswith("<") or elem.startswith(">")
375+
):
379376
# v1 pseudo-elements with no v2 shape equivalent:
380377
# any1d — inline array of runtime-determined length
381378
# (read to end of record); dtype-agnostic.
@@ -387,10 +384,11 @@ def _parse_shape(s: str) -> list[str]:
387384
# Resolve the block by searching for the integer field.
388385
col_name = m.group(1)
389386
block_name = next(
390-
(fi.block for fi in fields.values(multi=True)
391-
if fi.name == col_name
392-
and fi.type == "integer"
393-
and fi.in_record),
387+
(
388+
fi.block
389+
for fi in fields.values(multi=True)
390+
if fi.name == col_name and fi.type == "integer" and fi.in_record
391+
),
394392
None,
395393
)
396394
if block_name:
@@ -402,9 +400,12 @@ def _parse_shape(s: str) -> list[str]:
402400
# string array's name so _mark_string_dim_arrays can mark
403401
# it dimension="component" and validation resolves it.
404402
provider = next(
405-
(fi.name for fi in fields.values(multi=True)
406-
if fi.type == "string"
407-
and (fi.shape or "").strip() in (f"({elem})", elem)),
403+
(
404+
fi.name
405+
for fi in fields.values(multi=True)
406+
if fi.type == "string"
407+
and (fi.shape or "").strip() in (f"({elem})", elem)
408+
),
408409
None,
409410
)
410411
result.append(provider if provider else elem)
@@ -425,6 +426,7 @@ def _to_scalar() -> FieldBase:
425426
)
426427
if _type == "integer":
427428
from modflow_devtools.dfns.schema.v2 import GRID_DIM_NAMESPACE
429+
428430
v = [int(x) for x in valid] if valid else None
429431
if fd.get("block") == "dimensions":
430432
if _name in GRID_DIM_NAMESPACE:
@@ -444,9 +446,7 @@ def _to_scalar() -> FieldBase:
444446
dimension=_dim_scope,
445447
)
446448
if _type in ("double", "double precision"):
447-
return Double(
448-
**common, netcdf=netcdf, tagged=tagged, time_series=time_series
449-
)
449+
return Double(**common, netcdf=netcdf, tagged=tagged, time_series=time_series)
450450
raise TypeError(f"Unsupported scalar type: {_type!r}")
451451

452452
def _row_field() -> "Record | Union":
@@ -461,13 +461,15 @@ def _row_field() -> "Record | Union":
461461
]
462462

463463
# Single explicit record or keystring
464-
if len(item_names) == 1 and item_types and (
465-
(item_types[0] or "").startswith("record")
466-
or (item_types[0] or "").startswith("keystring")
467-
):
468-
mapped = MapV1To2.map_field(
469-
dfn, next(iter(fields.getlist(item_names[0])))
464+
if (
465+
len(item_names) == 1
466+
and item_types
467+
and (
468+
(item_types[0] or "").startswith("record")
469+
or (item_types[0] or "").startswith("keystring")
470470
)
471+
):
472+
mapped = MapV1To2.map_field(dfn, next(iter(fields.getlist(item_names[0]))))
471473
if isinstance(mapped, (Record, Union)):
472474
return mapped
473475
raise TypeError(
@@ -498,8 +500,7 @@ def _row_field() -> "Record | Union":
498500
return Record(
499501
name=_name,
500502
description=(
501-
(description or "").replace("is the list of", "is the record of")
502-
or None
503+
(description or "").replace("is the list of", "is the record of") or None
503504
),
504505
fields=children,
505506
)
@@ -659,9 +660,11 @@ def _mark(fields: dict) -> dict:
659660
new_fields = _mark(f.fields)
660661
if local_dims:
661662
new_fields = {
662-
fn: (sf.model_copy(update={"dimension": "record"})
663-
if fn in local_dims and isinstance(sf, Integer)
664-
else sf)
663+
fn: (
664+
sf.model_copy(update={"dimension": "record"})
665+
if fn in local_dims and isinstance(sf, Integer)
666+
else sf
667+
)
665668
for fn, sf in new_fields.items()
666669
}
667670
f = f.model_copy(update={"fields": new_fields})
@@ -674,9 +677,11 @@ def _mark(fields: dict) -> dict:
674677
new_item_fields = _mark(item.fields)
675678
if local_dims:
676679
new_item_fields = {
677-
fn: (sf.model_copy(update={"dimension": "record"})
678-
if fn in local_dims and isinstance(sf, Integer)
679-
else sf)
680+
fn: (
681+
sf.model_copy(update={"dimension": "record"})
682+
if fn in local_dims and isinstance(sf, Integer)
683+
else sf
684+
)
680685
for fn, sf in new_item_fields.items()
681686
}
682687
new_item = item.model_copy(update={"fields": new_item_fields})
@@ -840,27 +845,23 @@ def to_component(dfn: "Dfn") -> "Any":
840845
)
841846

842847
name = dfn.name
843-
blocks: "dict[str, Block] | None" = None
848+
blocks: dict[str, Block] | None = None
844849
if dfn.blocks:
845850
blocks = {
846851
block_name: Block(
847852
name=block_name,
848-
fields={
849-
k: v
850-
for k, v in block_fields.items()
851-
if isinstance(v, FieldBase)
852-
},
853+
fields={k: v for k, v in block_fields.items() if isinstance(v, FieldBase)},
853854
)
854855
for block_name, block_fields in dfn.blocks.items()
855856
if isinstance(block_fields, dict)
856857
}
857858

858-
common: dict[str, Any] = dict(
859-
name=name,
860-
blocks=blocks,
861-
parent=dfn.parent,
862-
schema_version=dfn.schema_version,
863-
)
859+
common: dict[str, Any] = {
860+
"name": name,
861+
"blocks": blocks,
862+
"parent": dfn.parent,
863+
"schema_version": dfn.schema_version,
864+
}
864865
if name == "sim-nam":
865866
return Simulation(**common)
866867
if name.endswith("-nam"):
@@ -911,10 +912,7 @@ def load(f, format: str = "dfn", **kwargs) -> Dfn:
911912
name = kwargs.pop("name")
912913
fields_parsed, meta = parse_dfn(f, **kwargs)
913914
blocks = {
914-
block_name: {
915-
field_dict["name"]: FieldV1.from_dict(field_dict)
916-
for field_dict in block
917-
}
915+
block_name: {field_dict["name"]: FieldV1.from_dict(field_dict) for field_dict in block}
918916
for block_name, block in groupby(
919917
fields_parsed.values(multi=True), lambda fd: fd["block"]
920918
)
@@ -945,9 +943,7 @@ def load(f, format: str = "dfn", **kwargs) -> Dfn:
945943

946944
if (expected_name := kwargs.pop("name", None)) is not None:
947945
if dfn_fields["name"] != expected_name:
948-
raise ValueError(
949-
f"DFN name mismatch: {expected_name} != {dfn_fields['name']}"
950-
)
946+
raise ValueError(f"DFN name mismatch: {expected_name} != {dfn_fields['name']}")
951947

952948
blocks = {}
953949
for section_name, section_data in data.items():
@@ -1052,11 +1048,7 @@ def to_tree(dfns: Dfns) -> Dfn:
10521048

10531049
def _build_tree(node_name: str) -> Dfn:
10541050
node = dfns[node_name]
1055-
children = {
1056-
name: dfn
1057-
for name, dfn in dfns.items()
1058-
if dfn.parent == node_name
1059-
}
1051+
children = {name: dfn for name, dfn in dfns.items() if dfn.parent == node_name}
10601052
if children:
10611053
node = node.model_copy(
10621054
update={"children": {name: _build_tree(name) for name in children}}
@@ -1065,9 +1057,7 @@ def _build_tree(node_name: str) -> Dfn:
10651057

10661058
return _build_tree(next(iter(roots.keys())))
10671059
case _:
1068-
raise ValueError(
1069-
f"Unsupported schema version: {schema_version}. Expected 1 or 2."
1070-
)
1060+
raise ValueError(f"Unsupported schema version: {schema_version}. Expected 1 or 2.")
10711061

10721062

10731063
def to_flat(dfn: Dfn) -> Dfns:

0 commit comments

Comments
 (0)