Skip to content

Commit aafe9d0

Browse files
authored
feat(advanced package support): refactoring for advanced package support
2 parents 2954151 + 7cf336c commit aafe9d0

6 files changed

Lines changed: 817 additions & 315 deletions

File tree

modflowapi/extensions/advpaks.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import numpy as np
2+
3+
from .data import ListInput
4+
from .pakbase import AdvancedPackage
5+
6+
7+
class SfrPakage(AdvancedPackage):
8+
"""
9+
Container for SFR and SFR like packages
10+
11+
Parameters
12+
----------
13+
model : ApiModel
14+
modflowapi model object
15+
pkg_type : str
16+
package type. Ex. "SFR"
17+
pkg_name : str
18+
package name (in the mf6 variables)
19+
sim_package : bool
20+
boolean flag for simulation level packages. Ex. TDIS, IMS
21+
"""
22+
23+
def __init__(self, model, pkg_type, pkg_name, sim_package=False):
24+
super().__init__(model, pkg_type, pkg_name, sim_package)
25+
26+
self._diversion_var_arrs = []
27+
self._set_advanced_variable_addrs("diversions", "_diversion_var_addrs")
28+
self._diversion_vars = ListInput(self, self._diversion_var_arrs, spd=False)
29+
30+
@property
31+
def diversions(self):
32+
return self._diversion_vars
33+
34+
@diversions.setter
35+
def diversions(self, recarray):
36+
"""
37+
Setter object to update the diversions data
38+
39+
"""
40+
if isinstance(recarray, np.recarray):
41+
self._diversion_vars.values = recarray
42+
elif isinstance(recarray, ListInput):
43+
self._diversion_vars.values = recarray.values
44+
elif recarray is None:
45+
self._diversion_vars.values = recarray
46+
else:
47+
raise TypeError(f"{type(recarray)} is not a supported diversions type")
48+
49+
50+
class LakPackage(AdvancedPackage):
51+
"""
52+
Container for LAK and LAK like packages
53+
54+
Parameters
55+
----------
56+
model : ApiModel
57+
modflowapi model object
58+
pkg_type : str
59+
package type. Ex. "LAK"
60+
pkg_name : str
61+
package name (in the mf6 variables)
62+
sim_package : bool
63+
boolean flag for simulation level packages. Ex. TDIS, IMS
64+
"""
65+
66+
def __init__(self, model, pkg_type, pkg_name, sim_package=False):
67+
super().__init__(model, pkg_type, pkg_name, sim_package)
68+
69+
70+
class MawPackage(AdvancedPackage):
71+
"""
72+
Container for MAW and MAW like packages
73+
74+
Parameters
75+
----------
76+
model : ApiModel
77+
modflowapi model object
78+
pkg_type : str
79+
package type. Ex. "MAW"
80+
pkg_name : str
81+
package name (in the mf6 variables)
82+
sim_package : bool
83+
boolean flag for simulation level packages. Ex. TDIS, IMS
84+
"""
85+
86+
def __init__(self, model, pkg_type, pkg_name, sim_package=False):
87+
super().__init__(model, pkg_type, pkg_name, sim_package)
88+
89+
90+
class UzfPackage(AdvancedPackage):
91+
"""
92+
Container for UZF and UZF like packages
93+
94+
Parameters
95+
----------
96+
model : ApiModel
97+
modflowapi model object
98+
pkg_type : str
99+
package type. Ex. "UZF"
100+
pkg_name : str
101+
package name (in the mf6 variables)
102+
sim_package : bool
103+
boolean flag for simulation level packages. Ex. TDIS, IMS
104+
"""
105+
106+
def __init__(self, model, pkg_type, pkg_name, sim_package=False):
107+
super().__init__(model, pkg_type, pkg_name, sim_package)

modflowapi/extensions/apimodel.py

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import numpy as np
22

3+
from .datamodel import get_package_type, gridshape
34
from .pakbase import AdvancedPackage, ArrayPackage, ListPackage, package_factory
45

5-
gridshape = {"dis": ["nlay", "nrow", "ncol"], "disu": ["nlay", "ncpl"]}
6-
76

87
class ApiMbase:
98
"""
@@ -15,16 +14,16 @@ class ApiMbase:
1514
initialized ModflowApi object
1615
name : str
1716
modflow model name. ex. "GWF_1", "GWF-GWF_1"
18-
pkg_types : dict
19-
dictionary of package types and ApiPackage class types
17+
pkg_types : None, dict
18+
optional dictionary of package types and ApiPackage class types
2019
"""
2120

22-
def __init__(self, mf6, name, pkg_types):
21+
def __init__(self, mf6, name, pkg_types=None):
2322
self.mf6 = mf6
2423
self.name = name
2524
self._pkg_names = None
2625
self._pak_type = None
27-
self.pkg_types = pkg_types
26+
self._pkg_types = pkg_types
2827
self.package_dict = {}
2928
self._set_package_names()
3029
self._create_package_list()
@@ -73,10 +72,13 @@ def _create_package_list(self):
7372
"""
7473
for ix, pkg_name in enumerate(self._pkg_names):
7574
pkg_type = self._pak_type[ix].lower()
76-
if pkg_type in self.pkg_types:
77-
basepackage = self.pkg_types[pkg_type]
75+
if self._pkg_types is None:
76+
basepackage = get_package_type(pkg_type)
7877
else:
79-
basepackage = AdvancedPackage
78+
if pkg_type in self._pkg_types:
79+
basepackage = self._pkg_types[pkg_type]
80+
else:
81+
basepackage = AdvancedPackage
8082

8183
package = package_factory(pkg_type, basepackage)
8284
adj_pkg_name = "".join(pkg_type.split("-"))
@@ -135,41 +137,14 @@ def __init__(self, mf6, name):
135137
else:
136138
raise AssertionError(f"Unrecognized discretization type {grid_type}")
137139

138-
pkg_types = {
139-
"dis": ArrayPackage,
140-
"chd": ListPackage,
141-
"drn": ListPackage,
142-
"evt": ListPackage,
143-
"ghb": ListPackage,
144-
"ic": ArrayPackage,
145-
"npf": ArrayPackage,
146-
"rch": ListPackage,
147-
"riv": ListPackage,
148-
"sto": ArrayPackage,
149-
"wel": ListPackage,
150-
# gwt
151-
"dsp": ArrayPackage,
152-
"cnc": ListPackage,
153-
"ist": ArrayPackage,
154-
"mst": ArrayPackage,
155-
"src": ListPackage,
156-
# gwe
157-
"cnd": ArrayPackage,
158-
"est": ArrayPackage,
159-
"cpt": ListPackage,
160-
"esl": ListPackage,
161-
# prt
162-
"mip": ArrayPackage,
163-
}
164-
165140
self.allow_convergence = True
166141
self._shape = None
167142
self._size = None
168143
self._nodetouser = None
169144
self._usertonode = None
170145
self._iteration = 0
171146

172-
super().__init__(mf6, name, pkg_types)
147+
super().__init__(mf6, name)
173148

174149
def __repr__(self):
175150
s = f"{self.name}, "

modflowapi/extensions/apisimulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def load(mf6):
311311
i[:-1].lower() for ix, i in enumerate(mf6.get_value("__INPUT__/SIM/NAM/SLNTYPE")) if idp_names[ix]
312312
]
313313

314-
tmpmdl = ApiMbase(mf6, "", {})
314+
tmpmdl = ApiMbase(mf6, "")
315315
solution_names = list(set(solution_names))
316316
solution_dict = {}
317317
for name in solution_names:

0 commit comments

Comments
 (0)