Skip to content

Commit 1ef9529

Browse files
committed
Use a utility to find test data files
1 parent 85f5c90 commit 1ef9529

5 files changed

Lines changed: 45 additions & 12 deletions

File tree

conda_package/tests/test_conversion.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,26 @@
55
from mpas_tools.io import write_netcdf
66
from mpas_tools.mesh.conversion import convert, cull, mask
77

8+
from .util import get_test_data_file
9+
810
matplotlib.use('Agg')
911
import xarray
1012
from geometric_features import read_feature_collection
1113

1214

1315
def test_conversion():
1416
dsMesh = xarray.open_dataset(
15-
'mesh_tools/mesh_conversion_tools/test/mesh.QU.1920km.151026.nc'
17+
get_test_data_file('mesh.QU.1920km.151026.nc')
1618
)
1719
dsMesh = convert(dsIn=dsMesh)
1820
write_netcdf(dsMesh, 'mesh.nc')
1921

20-
dsMask = xarray.open_dataset(
21-
'mesh_tools/mesh_conversion_tools/test/land_mask_final.nc'
22-
)
22+
dsMask = xarray.open_dataset(get_test_data_file('land_mask_final.nc'))
2323
dsCulled = cull(dsIn=dsMesh, dsMask=dsMask)
2424
write_netcdf(dsCulled, 'culled_mesh.nc')
2525

2626
fcMask = read_feature_collection(
27-
'mesh_tools/mesh_conversion_tools/test/Arctic_Ocean.geojson'
27+
get_test_data_file('Arctic_Ocean.geojson')
2828
)
2929
dsMask = mask(dsMesh=dsMesh, fcMask=fcMask)
3030
write_netcdf(dsMask, 'antarctic_mask.nc')

conda_package/tests/test_depth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
write_time_varying_zmid,
1414
)
1515

16+
from .util import get_test_data_file
17+
1618

1719
def create_3d_mesh():
1820
outFileName = 'test_depth_mesh.nc'
1921
if os.path.exists(outFileName):
2022
dsMesh = xarray.open_dataset(outFileName)
2123
else:
2224
dsMesh = xarray.open_dataset(
23-
'mesh_tools/mesh_conversion_tools/test/mesh.QU.1920km.151026.nc'
25+
get_test_data_file('mesh.QU.1920km.151026.nc')
2426
)
2527
nCells = dsMesh.sizes['nCells']
2628
nVertLevels = 10

conda_package/tests/test_mesh_mask.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
compute_projection_grid_region_masks,
1919
)
2020

21+
from .util import get_test_data_file
22+
2123

2224
def test_compute_mpas_region_masks():
2325
ds_mesh, _ = _get_mesh()
@@ -205,9 +207,7 @@ def _get_pool():
205207

206208

207209
def _get_mesh():
208-
ds_mesh = xr.open_dataset(
209-
'mesh_tools/mesh_conversion_tools/test/mesh.QU.1920km.151026.nc'
210-
)
210+
ds_mesh = xr.open_dataset(get_test_data_file('mesh.QU.1920km.151026.nc'))
211211
earth_radius = constants['SHR_CONST_REARTH']
212212
ds_mesh.attrs['sphere_radius'] = earth_radius
213213
for coord in [

conda_package/tests/test_viz_transects.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
make_triangle_tree,
1313
)
1414

15+
from .util import get_test_data_file
16+
1517

1618
def test_mesh_to_triangles():
1719
_, ds_tris = _get_triangles()
@@ -129,9 +131,7 @@ def test_find_planar_transect_cells_and_weights():
129131

130132

131133
def _get_triangles():
132-
ds_mesh = xr.open_dataset(
133-
'mesh_tools/mesh_conversion_tools/test/mesh.QU.1920km.151026.nc'
134-
)
134+
ds_mesh = xr.open_dataset(get_test_data_file('mesh.QU.1920km.151026.nc'))
135135
earth_radius = constants['SHR_CONST_REARTH']
136136
ds_mesh.attrs['sphere_radius'] = earth_radius
137137
for coord in [

conda_package/tests/util.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
3+
4+
def get_test_data_file(filename):
5+
"""
6+
Get the full path to a data file in the tests/data directory.
7+
8+
Parameters
9+
----------
10+
filename : str
11+
The name of the data file.
12+
13+
Returns
14+
-------
15+
str
16+
The full relative path to the data file.
17+
"""
18+
19+
local_path = os.path.join(
20+
'mesh_tools', 'mesh_conversion_tools', 'test', filename
21+
)
22+
repo_path = os.path.join('..', '..', local_path)
23+
if os.path.exists(local_path):
24+
return local_path
25+
elif os.path.exists(repo_path):
26+
return repo_path
27+
28+
raise FileNotFoundError(
29+
f"Data file '{filename}' not found in expected locations: "
30+
f'{local_path} or {repo_path}'
31+
)

0 commit comments

Comments
 (0)