@@ -118,8 +118,8 @@ def _convert_time(
118118 provided, returns the closest matching time_index.
119119 """
120120
121- if not isinstance (data , netCDF4 .Dataset ):
122- raise TypeError ("data must be NetCDF4 object" )
121+ if not isinstance (data , ( netCDF4 .Dataset , xr . Dataset ) ):
122+ raise TypeError ("data must be NetCDF4 object or xarray Dataset " )
123123
124124 if not (time_index or seconds_run ):
125125 raise ValueError ("Input of time_index or seconds_run needed" )
@@ -199,8 +199,8 @@ def get_layer_data(
199199 if not isinstance (layer_index , int ):
200200 raise TypeError ("layer_index must be an int" )
201201
202- if not isinstance (data , netCDF4 .Dataset ):
203- raise TypeError ("data must be NetCDF4 object" )
202+ if not isinstance (data , ( netCDF4 .Dataset , xr . Dataset ) ):
203+ raise TypeError ("data must be NetCDF4 object or xarray Dataset " )
204204
205205 if variable not in data .variables .keys ():
206206 raise ValueError ("variable not recognized" )
@@ -538,8 +538,10 @@ def variable_interpolation(
538538 f"If a string, points must be cells or faces. Got { points } "
539539 )
540540
541- if not isinstance (data , netCDF4 .Dataset ):
542- raise TypeError (f"data must be netCDF4 object. Got { type (data )} " )
541+ if not isinstance (data , (netCDF4 .Dataset , xr .Dataset )):
542+ raise TypeError (
543+ f"data must be netCDF4 object or xarray Dataset. Got { type (data )} "
544+ )
543545
544546 if not isinstance (to_pandas , bool ):
545547 raise TypeError (f"to_pandas must be of type bool. Got: { type (to_pandas )} " )
@@ -620,8 +622,8 @@ def get_all_data_points(
620622 if not isinstance (time_index , int ):
621623 raise TypeError ("time_index must be an int" )
622624
623- if not isinstance (data , netCDF4 .Dataset ):
624- raise TypeError ("data must be NetCDF4 object" )
625+ if not isinstance (data , ( netCDF4 .Dataset , xr . Dataset ) ):
626+ raise TypeError ("data must be NetCDF4 object or xarray Dataset " )
625627
626628 if variable not in data .variables .keys ():
627629 raise ValueError ("variable not recognized" )
@@ -792,8 +794,8 @@ def turbulent_intensity(
792794 f"value of the max time index { max_time_index } "
793795 )
794796
795- if not isinstance (data , netCDF4 .Dataset ):
796- raise TypeError ("data must be netCDF4 object" )
797+ if not isinstance (data , ( netCDF4 .Dataset , xr . Dataset ) ):
798+ raise TypeError ("data must be netCDF4 object or xarray Dataset " )
797799
798800 for variable in ["turkin1" , "ucx" , "ucy" , "ucz" ]:
799801 if variable not in data .variables .keys ():
@@ -886,7 +888,10 @@ def list_variables(data: Union[netCDF4.Dataset, xr.Dataset, xr.DataArray]) -> Li
886888 >>> print(variables)
887889 ['time', 'x', 'y', 'waterdepth', 'ucx', 'ucy', 'ucz', 'turkin1']
888890 """
889- if isinstance (data , netCDF4 .Dataset ):
891+ if isinstance (
892+ data ,
893+ netCDF4 .Dataset ,
894+ ):
890895 return list (data .variables .keys ())
891896 if isinstance (data , (xr .Dataset , xr .DataArray )):
892897 return list (data .variables .keys ())
0 commit comments