Skip to content

Commit 1557e3f

Browse files
authored
Merge pull request #635 from xylar/fix-xdmf
Fix index parsing in `mpas_to_xdmf` tool
2 parents 9d0fe85 + dbc12bc commit 1557e3f

2 files changed

Lines changed: 208 additions & 2 deletions

File tree

conda_package/mpas_tools/viz/mpas_to_xdmf/io.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,22 @@ def _parse_indices(index_string, dim_size):
300300
if not index_string:
301301
return []
302302
if ':' in index_string:
303-
parts = [int(p) if p else None for p in index_string.split(':')]
304-
return list(range(parts[0] or 0, parts[1] or dim_size, parts[2] or 1))
303+
# Support slice notation like ':', '0:10', '0:10:2', etc.
304+
parts = index_string.split(':')
305+
# Validate that parts has at most 3 elements
306+
if len(parts) > 3:
307+
raise ValueError(
308+
f"Invalid index string '{index_string}': too many colons. "
309+
'Expected at most two colons.'
310+
)
311+
# Pad parts to length 3 with empty strings if needed
312+
while len(parts) < 3:
313+
parts.append('')
314+
# Convert to int or None
315+
start = int(parts[0]) if parts[0] else 0
316+
stop = int(parts[1]) if parts[1] else dim_size
317+
step = int(parts[2]) if parts[2] else 1
318+
return list(range(start, stop, step))
305319
return [int(i) for i in index_string.split(',')]
306320

307321

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import os
2+
import sys
3+
4+
import numpy as np
5+
import pytest
6+
import xarray as xr
7+
8+
from mpas_tools.io import write_netcdf
9+
from mpas_tools.viz.mpas_to_xdmf.io import (
10+
_load_dataset,
11+
_parse_indices,
12+
_process_extra_dims,
13+
)
14+
from mpas_tools.viz.mpas_to_xdmf.mpas_to_xdmf import MpasToXdmf, main
15+
from mpas_tools.viz.mpas_to_xdmf.time import _set_time
16+
17+
from .util import get_test_data_file
18+
19+
TEST_MESH = get_test_data_file('mesh.QU.1920km.151026.nc')
20+
21+
22+
@pytest.mark.skipif(
23+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
24+
)
25+
def test_load_mesh_only():
26+
converter = MpasToXdmf()
27+
converter.load(mesh_filename=TEST_MESH)
28+
assert isinstance(converter.ds, xr.Dataset)
29+
assert isinstance(converter.ds_mesh, xr.Dataset)
30+
# Should have mesh dimensions
31+
assert 'nCells' in converter.ds.dims
32+
33+
34+
@pytest.mark.skipif(
35+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
36+
)
37+
def test_set_time_with_no_xtime():
38+
converter = MpasToXdmf()
39+
converter.load(mesh_filename=TEST_MESH)
40+
# Should create a 'Time' variable if 'Time' in dims
41+
if 'Time' in converter.ds.dims:
42+
assert 'Time' in converter.ds
43+
arr = converter.ds['Time'].values
44+
assert np.all(arr == np.arange(converter.ds.sizes['Time']))
45+
46+
47+
@pytest.mark.skipif(
48+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
49+
)
50+
def test_convert_to_xdmf(tmp_path):
51+
converter = MpasToXdmf()
52+
variables = ['xCell', 'areaCell', 'cellsOnCell']
53+
extra_dims = {'maxEdges': [0]}
54+
converter.load(mesh_filename=TEST_MESH, variables=variables)
55+
out_dir = tmp_path / 'out'
56+
converter.convert_to_xdmf(str(out_dir), extra_dims=extra_dims)
57+
# Check that output files exist for cells
58+
assert (out_dir / 'fieldsOnCells.h5').exists()
59+
assert (out_dir / 'fieldsOnCells.xdmf').exists()
60+
61+
62+
@pytest.mark.skipif(
63+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
64+
)
65+
def test_extra_dims(tmp_path):
66+
converter = MpasToXdmf()
67+
converter.load(mesh_filename=TEST_MESH)
68+
# Simulate an extra dimension if present
69+
extra_dims = {}
70+
for dim in converter.ds.dims:
71+
if dim not in ['Time', 'nCells', 'nEdges', 'nVertices']:
72+
extra_dims[dim] = [0]
73+
out_dir = tmp_path / 'out_extra'
74+
converter.convert_to_xdmf(str(out_dir), extra_dims=extra_dims)
75+
assert (out_dir / 'fieldsOnCells.h5').exists()
76+
77+
78+
@pytest.mark.skipif(
79+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
80+
)
81+
def test_load_with_time_series_and_variables(tmp_path):
82+
ts1 = tmp_path / 'ts1.nc'
83+
ts2 = tmp_path / 'ts2.nc'
84+
85+
# Simulate a time series by adding xtime and area variables
86+
ds = xr.open_dataset(TEST_MESH)
87+
ds['xtime'] = ('Time', ['0001-01-01_00:00:00'])
88+
ds['area'] = (('Time', 'nCells'), ds.areaCell.values[None, :])
89+
write_netcdf(ds, ts1)
90+
ds['xtime'] = ('Time', ['0001-01-02_00:00:00'])
91+
write_netcdf(ds, ts2)
92+
93+
variables = ['areaCell', 'area']
94+
95+
converter = MpasToXdmf()
96+
converter.load(
97+
mesh_filename=TEST_MESH,
98+
time_series_filenames=[str(ts1), str(ts2)],
99+
variables=variables,
100+
)
101+
print(converter.ds)
102+
for var in variables:
103+
assert var in converter.ds.data_vars, (
104+
f'Variable {var} not found in dataset'
105+
)
106+
assert converter.ds.sizes['Time'] == 2
107+
108+
109+
@pytest.mark.skipif(
110+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
111+
)
112+
def test_process_extra_dims_drop(tmp_path):
113+
converter = MpasToXdmf()
114+
converter.load(mesh_filename=TEST_MESH)
115+
116+
# drop all variables with extra dimensions
117+
extra_dims = {
118+
'maxEdges': [],
119+
'maxEdges2': [],
120+
'TWO': [],
121+
'vertexDegree': [],
122+
}
123+
124+
ds = _process_extra_dims(converter.ds, extra_dims=extra_dims)
125+
for dim in extra_dims:
126+
assert dim not in ds.dims, f'Dimension {dim} should be dropped'
127+
128+
129+
@pytest.mark.skipif(
130+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
131+
)
132+
def test_set_time_invalid_xtime(tmp_path):
133+
ts1 = tmp_path / 'ts1.nc'
134+
# Simulate a time-depndent variable and add xtime
135+
ds = xr.open_dataset(TEST_MESH)
136+
ds['xtime'] = ('Time', ['0001-01-01_00:00:00'])
137+
ds['area'] = (('Time', 'nCells'), ds.areaCell.values[None, :])
138+
write_netcdf(ds, ts1)
139+
140+
converter = MpasToXdmf()
141+
converter.load(mesh_filename=TEST_MESH, time_series_filenames=[str(ts1)])
142+
# Should raise ValueError if xtime_var is not present
143+
with pytest.raises(ValueError):
144+
_set_time(ds=converter.ds, xtime_var='not_a_var')
145+
146+
147+
def test_parse_indices_invalid_cases():
148+
# Should raise on mixed slice/list
149+
with pytest.raises(ValueError):
150+
_parse_indices('1:3,5', 5)
151+
# Should raise on invalid string
152+
with pytest.raises(ValueError):
153+
_parse_indices('foo', 5)
154+
155+
156+
def test_parse_indices_valid_cases():
157+
# Empty list
158+
assert _parse_indices('', 5) == []
159+
# Single index
160+
assert _parse_indices('0', 5) == [0]
161+
# Comma-separated list
162+
assert _parse_indices('1,2,3', 5) == [1, 2, 3]
163+
# Slice notation
164+
assert _parse_indices('0:3', 5) == [0, 1, 2]
165+
# Slice with stride
166+
assert _parse_indices('0:5:2', 5) == [0, 2, 4]
167+
# Full slice
168+
assert _parse_indices(':', 4) == [0, 1, 2, 3]
169+
170+
171+
def test_main_cli(monkeypatch, tmp_path):
172+
# Test CLI entry point with minimal arguments
173+
mesh = TEST_MESH
174+
if not os.path.exists(mesh):
175+
pytest.skip('Test mesh not available')
176+
out_dir = tmp_path / 'cli_out'
177+
sys_argv = ['prog', '-m', mesh, '-o', str(out_dir), '-v', 'areaCell']
178+
monkeypatch.setattr(sys, 'argv', sys_argv)
179+
# Patch input to always return blank (skip extra dims)
180+
monkeypatch.setattr('builtins.input', lambda _: '')
181+
main()
182+
assert (out_dir / 'fieldsOnCells.h5').exists()
183+
184+
185+
@pytest.mark.skipif(
186+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
187+
)
188+
def test_load_dataset_missing_variable():
189+
# Should not raise if variable is missing in mesh, but should raise if not
190+
# present at all
191+
with pytest.raises(KeyError):
192+
_load_dataset(TEST_MESH, None, ['not_a_var'], None)

0 commit comments

Comments
 (0)