Skip to content

Commit 3d11517

Browse files
committed
add xarray input to error checks
1 parent d2155a8 commit 3d11517

1 file changed

Lines changed: 16 additions & 11 deletions

File tree

mhkit/river/io/d3d.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def _convert_time(
118118
provided, returns the closest matching time_index.
119119
"""
120120

121-
if not isinstance(data, netCDF4.Dataset):
122-
raise TypeError("data must be NetCDF4 object")
121+
if not isinstance(data, (netCDF4.Dataset, xr.Dataset)):
122+
raise TypeError("data must be NetCDF4 object or xarray Dataset")
123123

124124
if not (time_index or seconds_run):
125125
raise ValueError("Input of time_index or seconds_run needed")
@@ -199,8 +199,8 @@ def get_layer_data(
199199
if not isinstance(layer_index, int):
200200
raise TypeError("layer_index must be an int")
201201

202-
if not isinstance(data, netCDF4.Dataset):
203-
raise TypeError("data must be NetCDF4 object")
202+
if not isinstance(data, (netCDF4.Dataset, xr.Dataset)):
203+
raise TypeError("data must be NetCDF4 object or xarray Dataset")
204204

205205
if variable not in data.variables.keys():
206206
raise ValueError("variable not recognized")
@@ -538,8 +538,10 @@ def variable_interpolation(
538538
f"If a string, points must be cells or faces. Got {points}"
539539
)
540540

541-
if not isinstance(data, netCDF4.Dataset):
542-
raise TypeError(f"data must be netCDF4 object. Got {type(data)}")
541+
if not isinstance(data, (netCDF4.Dataset, xr.Dataset)):
542+
raise TypeError(
543+
f"data must be netCDF4 object or xarray Dataset. Got {type(data)}"
544+
)
543545

544546
if not isinstance(to_pandas, bool):
545547
raise TypeError(f"to_pandas must be of type bool. Got: {type(to_pandas)}")
@@ -620,8 +622,8 @@ def get_all_data_points(
620622
if not isinstance(time_index, int):
621623
raise TypeError("time_index must be an int")
622624

623-
if not isinstance(data, netCDF4.Dataset):
624-
raise TypeError("data must be NetCDF4 object")
625+
if not isinstance(data, (netCDF4.Dataset, xr.Dataset)):
626+
raise TypeError("data must be NetCDF4 object or xarray Dataset")
625627

626628
if variable not in data.variables.keys():
627629
raise ValueError("variable not recognized")
@@ -792,8 +794,8 @@ def turbulent_intensity(
792794
f"value of the max time index {max_time_index}"
793795
)
794796

795-
if not isinstance(data, netCDF4.Dataset):
796-
raise TypeError("data must be netCDF4 object")
797+
if not isinstance(data, (netCDF4.Dataset, xr.Dataset)):
798+
raise TypeError("data must be netCDF4 object or xarray Dataset")
797799

798800
for variable in ["turkin1", "ucx", "ucy", "ucz"]:
799801
if variable not in data.variables.keys():
@@ -886,7 +888,10 @@ def list_variables(data: Union[netCDF4.Dataset, xr.Dataset, xr.DataArray]) -> Li
886888
>>> print(variables)
887889
['time', 'x', 'y', 'waterdepth', 'ucx', 'ucy', 'ucz', 'turkin1']
888890
"""
889-
if isinstance(data, netCDF4.Dataset):
891+
if isinstance(
892+
data,
893+
netCDF4.Dataset,
894+
):
890895
return list(data.variables.keys())
891896
if isinstance(data, (xr.Dataset, xr.DataArray)):
892897
return list(data.variables.keys())

0 commit comments

Comments
 (0)