Skip to content

Commit 6c88ede

Browse files
API: Ensure mas/pot3d/quantities fetchers return argument ordered filepaths
1 parent c4c1f11 commit 6c88ede

2 files changed

Lines changed: 30 additions & 5 deletions

File tree

psi_data/_static_assets.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,14 @@ def fetch_mas_data(*, domains: Optional[Iterable] = 'cor',
209209
PosixPath('.../cor/mhd/br002.h5')
210210
"""
211211
if domains is None:
212-
domains = {"cor", "hel"}
212+
domains = ("cor", "hel")
213213
else:
214-
domains = set(domains.replace(" ", "").lower().split(",") if isinstance(domains, str) else domains)
214+
domains = tuple(domains.replace(" ", "").lower().split(",") if isinstance(domains, str) else domains)
215215

216216
if variables is None:
217217
variables = set.intersection(*(DOM_VAR_MAP[dom] for dom in domains))
218218
else:
219-
variables = set(variables.replace(" ", "").lower().split(",") if isinstance(variables, str) else variables)
219+
variables = tuple(variables.replace(" ", "").lower().split(",") if isinstance(variables, str) else variables)
220220
req_pairs = product(domains, variables)
221221

222222
ext = HDF_EXT.get(hdf)
@@ -275,7 +275,7 @@ def fetch_mas_quantities(*, quantities: Optional[Iterable] = 'ch_pm', hdf: int =
275275
if quantities is None:
276276
quantities = DOM_VAR_MAP["quantities"]
277277
else:
278-
quantities = set(quantities.replace(" ", "").lower().split(",") if isinstance(quantities, str) else quantities)
278+
quantities = tuple(quantities.replace(" ", "").lower().split(",") if isinstance(quantities, str) else quantities)
279279

280280
ext = HDF_EXT.get(hdf)
281281
filepaths = {
@@ -333,7 +333,7 @@ def fetch_pot3d_data(*, variables: Optional[Iterable] = 'br', hdf: int = 5) -> o
333333
if variables is None:
334334
variables = DOM_VAR_MAP["pot3d"]
335335
else:
336-
variables = set(variables.replace(" ", "").lower().split(",") if isinstance(variables, str) else variables)
336+
variables = tuple(variables.replace(" ", "").lower().split(",") if isinstance(variables, str) else variables)
337337

338338
ext = HDF_EXT.get(hdf)
339339
filepaths = {

tests/test_static_assets.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,23 @@ def test_fetch_mas_data_keys_exist_in_registry(fake_fetch):
106106
assert key in sa.FETCHER.registry, f"{key} not in registry"
107107

108108

109+
def test_fetch_mas_returns_ordered_namedtuple(fake_fetch):
110+
paths = sa.fetch_mas_data(domains="cor", variables="br,bt,bp", hdf=5)
111+
assert paths._fields == ("cor_br", "cor_bt", "cor_bp")
112+
assert paths[0] == Path("H5CR2309_hmi_mast_mas_std_0201/cor/mhd/br002.h5")
113+
assert paths[1] == Path("H5CR2309_hmi_mast_mas_std_0201/cor/mhd/bt002.h5")
114+
assert paths[2] == Path("H5CR2309_hmi_mast_mas_std_0201/cor/mhd/bp002.h5")
115+
116+
117+
def test_fetch_mas_returns_ordered_namedtuple_with_multiple_domains(fake_fetch):
118+
paths = sa.fetch_mas_data(domains="cor,hel", variables="br,bt", hdf=5)
119+
assert paths._fields == ("cor_br", "cor_bt", "hel_br", "hel_bt")
120+
assert paths[0] == Path("H5CR2309_hmi_mast_mas_std_0201/cor/mhd/br002.h5")
121+
assert paths[1] == Path("H5CR2309_hmi_mast_mas_std_0201/cor/mhd/bt002.h5")
122+
assert paths[2] == Path("H5CR2309_hmi_mast_mas_std_0201/hel/mhd/br002.h5")
123+
assert paths[3] == Path("H5CR2309_hmi_mast_mas_std_0201/hel/mhd/bt002.h5")
124+
125+
109126
# --- fetch_pot3d_data --------------------------------------------------------
110127

111128
def test_fetch_pot3d_data_default(fake_fetch):
@@ -119,6 +136,14 @@ def test_fetch_pot3d_data_none_fetches_all_components(fake_fetch):
119136
assert set(paths._fields) == {"br", "bt", "bp"}
120137

121138

139+
def test_fetch_pot3d_returns_ordered_namedtuple(fake_fetch):
140+
paths = sa.fetch_pot3d_data(variables="br,bt,bp", hdf=5)
141+
assert paths._fields == ("br", "bt", "bp")
142+
assert paths[0] == Path("H5CR2309_hmi_mast_mas_std_0201/cor/pfss/br.h5")
143+
assert paths[1] == Path("H5CR2309_hmi_mast_mas_std_0201/cor/pfss/bt.h5")
144+
assert paths[2] == Path("H5CR2309_hmi_mast_mas_std_0201/cor/pfss/bp.h5")
145+
146+
122147
# --- fetch_mas_quantities ----------------------------------------------------
123148

124149
def test_fetch_mas_quantities_default(fake_fetch):

0 commit comments

Comments
 (0)