Skip to content

Commit ed1273c

Browse files
committed
updated get_all_data for xarray
1 parent 82facb0 commit ed1273c

1 file changed

Lines changed: 80 additions & 103 deletions

File tree

mhkit/river/io/d3d.py

Lines changed: 80 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def seconds_to_index(data: netCDF4.Dataset, seconds_run: Union[int, float]) -> i
9494
def _convert_time(
9595
data: netCDF4.Dataset or xr.Dataset,
9696
time_index: Optional[Union[int, float]] = None,
97-
seconds_run: Optional[Union[int, float]] = None,
97+
seconds_run: Optional[Union[int, float, np.datetime64]] = None,
9898
) -> Union[int, float]:
9999
"""
100100
Converts a time index to seconds or seconds to a time index. The user
@@ -393,8 +393,7 @@ def get_layer_data(
393393
if isinstance(data, netCDF4.Dataset):
394394
time = np.ma.getdata(data.variables["time"][time_index], False) * np.ones(len(x))
395395
elif isinstance(data, xr.Dataset):
396-
time= [data.time.values[time_index]] * len(x)
397-
396+
time=(data.time.values[time_index] - data.time.values[0]).astype('timedelta64[s]').astype(int) * np.ones(len(x))
398397

399398
index = np.arange(0, len(time))
400399
layer_data = xr.Dataset(
@@ -596,10 +595,8 @@ def variable_interpolation(
596595
f"If a string, points must be cells or faces. Got {points}"
597596
)
598597

599-
if not isinstance(data, (netCDF4.Dataset, xr.Dataset)):
600-
raise TypeError(
601-
f"data must be netCDF4 object or xarray Dataset. Got {type(data)}"
602-
)
598+
if not isinstance(data, netCDF4.Dataset):
599+
raise TypeError(f"data must be netCDF4 object. Got {type(data)}")
603600

604601
if not isinstance(to_pandas, bool):
605602
raise TypeError(f"to_pandas must be of type bool. Got: {type(to_pandas)}")
@@ -648,7 +645,7 @@ def variable_interpolation(
648645

649646

650647
def get_all_data_points(
651-
data: netCDF4.Dataset, variable: str, time_index: int = -1, to_pandas: bool = True
648+
data: (netCDF4.Dataset, xr.Dataset), variable: str, time_index: int = -1, to_pandas: bool = True
652649
) -> Union[pd.DataFrame, xr.Dataset]:
653650
"""
654651
Get data points for a passed variable for all layers at a specified time from
@@ -694,53 +691,83 @@ def get_all_data_points(
694691
raise ValueError(
695692
f"time_index must be less than the max time index {max_time_index}"
696693
)
694+
if isinstance(data, netCDF4.Dataset):
695+
if "mesh2d" in variable:
696+
cords_to_layers = {
697+
"mesh2d_face_x mesh2d_face_y": {
698+
"name": "mesh2d_nLayers",
699+
"coords": data.variables["mesh2d_layer_sigma"][:],
700+
},
701+
"mesh2d_edge_x mesh2d_edge_y": {
702+
"name": "mesh2d_nInterfaces",
703+
"coords": data.variables["mesh2d_interface_sigma"][:],
704+
},
705+
}
697706

698-
if "mesh2d" in variable:
699-
cords_to_layers = {
700-
"mesh2d_face_x mesh2d_face_y": {
701-
"name": "mesh2d_nLayers",
702-
"coords": data.variables["mesh2d_layer_sigma"][:],
703-
},
704-
"mesh2d_face_x mesh2d_face_y mesh2d_layer_sigma": {
705-
"name": "mesh2d_nLayers",
706-
"coords": data.variables["mesh2d_layer_sigma"][:],
707-
},
708-
"mesh2d_edge_x mesh2d_edge_y": {
709-
"name": "mesh2d_nInterfaces",
710-
"coords": data.variables["mesh2d_interface_sigma"][:],
711-
},
712-
}
713-
714-
elif str(data.variables[variable].coordinates) == "FlowElem_xcc FlowElem_ycc":
715-
cords_to_layers = {
716-
"FlowElem_xcc FlowElem_ycc": {
717-
"name": "laydim",
718-
"coords": data.variables["LayCoord_cc"][:],
719-
},
720-
"FlowLink_xu FlowLink_yu": {
721-
"name": "wdim",
722-
"coords": data.variables["LayCoord_w"][:],
723-
},
724-
}
725-
else:
726-
cords_to_layers = {
727-
"FlowElem_xcc FlowElem_ycc LayCoord_cc LayCoord_cc": {
728-
"name": "laydim",
729-
"coords": data.variables["LayCoord_cc"][:],
730-
},
731-
"FlowLink_xu FlowLink_yu": {
732-
"name": "wdim",
733-
"coords": data.variables["LayCoord_w"][:],
734-
},
735-
}
736-
737-
layer_dim = str(data.variables[variable].coordinates)
707+
elif str(data.variables[variable].coordinates) == "FlowElem_xcc FlowElem_ycc":
708+
cords_to_layers = {
709+
"FlowElem_xcc FlowElem_ycc": {
710+
"name": "laydim",
711+
"coords": data.variables["LayCoord_cc"][:],
712+
},
713+
"FlowLink_xu FlowLink_yu": {
714+
"name": "wdim",
715+
"coords": data.variables["LayCoord_w"][:],
716+
},
717+
}
718+
else:
719+
cords_to_layers = {
720+
"FlowElem_xcc FlowElem_ycc LayCoord_cc LayCoord_cc": {
721+
"name": "laydim",
722+
"coords": data.variables["LayCoord_cc"][:],
723+
},
724+
"FlowLink_xu FlowLink_yu": {
725+
"name": "wdim",
726+
"coords": data.variables["LayCoord_w"][:],
727+
},
728+
}
729+
730+
layer_dim = str(data.variables[variable].coordinates)
731+
elif isinstance(data, xr.Dataset):
732+
if "mesh2d" in variable:
733+
cords_to_layers = {
734+
"mesh2d_face_x mesh2d_face_y": {
735+
"name": "mesh2d_nLayers",
736+
"coords": data.variables["mesh2d_layer_sigma"][:],
737+
},
738+
"mesh2d_edge_x mesh2d_edge_y": {
739+
"name": "mesh2d_nInterfaces",
740+
"coords": data.variables["mesh2d_interface_sigma"][:],
741+
},
742+
}
743+
bottom_depth = data["mesh2d_waterdepth"].values[time_index, :]
744+
waterlevel = data["mesh2d_s1"].values[time_index, :]
745+
coords = list(data["waterdepth"].coords)
746+
elif str(list(data[variable].coords)) == "['FlowElem_xcc', 'FlowElem_ycc', 'time']":
747+
cords_to_layers = {
748+
"FlowElem_xcc FlowElem_ycc": {
749+
"name": "laydim",
750+
"coords": data.variables["LayCoord_cc"][:],
751+
},
752+
"FlowLink_xu FlowLink_yu": {
753+
"name": "wdim",
754+
"coords": data.variables["LayCoord_w"][:],
755+
},
756+
}
757+
bottom_depth = data["waterdepth"].values[time_index, :]
758+
waterlevel = data["s1"].values[time_index, :]
759+
coords = list(data["waterdepth"].coords)
760+
761+
layer_dim = " ".join(map(str, list(data[variable].coords)[0:2]))
738762

739763
try:
740764
cord_sys = cords_to_layers[layer_dim]["coords"]
741765
except KeyError as exc:
742766
raise ValueError("Coordinates not recognized.") from exc
743-
layer_percentages = np.ma.getdata(cord_sys, False)
767+
if isinstance(data, netCDF4.Dataset):
768+
layer_percentages = np.ma.getdata(cord_sys, False) # accumulative
769+
elif isinstance(data, xr.Dataset):
770+
layer_percentages= cord_sys.values # accumulative
744771

745772
x_all = []
746773
y_all = []
@@ -758,7 +785,7 @@ def get_all_data_points(
758785
depth_all = np.append(depth_all, layer_data.waterdepth)
759786
water_level_all = np.append(water_level_all, layer_data.waterlevel)
760787
v_all = np.append(v_all, layer_data.v)
761-
time_all = np.append(time_all, layer_data.time)
788+
time_all = np.append(time_all, np.asarray(layer_data.time))
762789

763790
index = np.arange(0, len(time_all))
764791
all_data = xr.Dataset(
@@ -852,8 +879,8 @@ def turbulent_intensity(
852879
f"value of the max time index {max_time_index}"
853880
)
854881

855-
if not isinstance(data, (netCDF4.Dataset, xr.Dataset)):
856-
raise TypeError("data must be netCDF4 object or xarray Dataset")
882+
if not isinstance(data, netCDF4.Dataset):
883+
raise TypeError("data must be netCDF4 object")
857884

858885
for variable in ["turkin1", "ucx", "ucy", "ucz"]:
859886
if variable not in data.variables.keys():
@@ -946,61 +973,11 @@ def list_variables(data: Union[netCDF4.Dataset, xr.Dataset, xr.DataArray]) -> Li
946973
>>> print(variables)
947974
['time', 'x', 'y', 'waterdepth', 'ucx', 'ucy', 'ucz', 'turkin1']
948975
"""
949-
if isinstance(
950-
data,
951-
netCDF4.Dataset,
952-
):
976+
if isinstance(data, netCDF4.Dataset):
953977
return list(data.variables.keys())
954978
if isinstance(data, (xr.Dataset, xr.DataArray)):
955979
return list(data.variables.keys())
956980
raise TypeError(
957981
"data must be a NetCDF4 Dataset, xarray Dataset, or "
958982
f"xarray DataArray. Got: {type(data)}"
959983
)
960-
961-
962-
def calculate_grid_convergence_index(
963-
fine_grid, coarse_grid, refinement_ratio, factor_of_safety=1.25, order=2
964-
):
965-
"""
966-
Calculate the Grid Convergence Index (GCI) between two grid sizes.
967-
968-
NASA. (n.d.). Examining spatial (grid) convergence. Accessed Febuary 3, 2026. NASA. https://www.grc.nasa.gov/WWW/wind/valid/tutorial/spatconv.html
969-
970-
Parameters
971-
----------
972-
fine_grid: numpy.ndarray
973-
Results from the finer grid.
974-
coarse_grid: numpy.ndarray
975-
Results from the coarser grid.
976-
refinement_ratio: float
977-
Refinement ratio between the grids.
978-
factor_of_safety: float
979-
Factor of safety (default is 1.25).
980-
order: int
981-
Order of accuracy (default is 2).
982-
983-
Returns
984-
-------
985-
gci: float
986-
Grid Convergence Index (GCI).
987-
"""
988-
989-
# Validate inputs
990-
if not (np.issubdtype(refinement_ratio.dtype, np.number)):
991-
raise TypeError("refinement_ratio must be a numeric values")
992-
if not (np.issubdtype(factor_of_safety.dtype, np.number)):
993-
raise TypeError("factor_of_safety must be a numeric values")
994-
if not (np.issubdtype(order.dtype, np.number)):
995-
raise TypeError("order must be a numeric values")
996-
if not (np.issubdtype(fine_grid.dtype, np.number) and np.issubdtype(coarse_grid.dtype, np.number)):
997-
raise TypeError("fine_grid and coarse_grid must contain numeric values")
998-
if fine_grid.shape != coarse_grid.shape:
999-
raise ValueError("fine_grid and coarse_grid must have the same shape")
1000-
1001-
# Calculate the approximate relative error
1002-
error = np.abs((fine_grid - coarse_grid) / fine_grid)
1003-
1004-
# Calculate the GCI
1005-
gci = (factor_of_safety * error) / (refinement_ratio**order - 1)
1006-
return gci

0 commit comments

Comments
 (0)