Skip to content

Commit a306e66

Browse files
committed
Remove netcdf4 as a dependency
1 parent ba6cc37 commit a306e66

15 files changed

Lines changed: 39 additions & 42 deletions

File tree

fme/ace/data_loading/test_data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _save_netcdf(
7878
data_vars[f"ak_{i}"] = float(i)
7979
data_vars[f"bk_{i}"] = float(i + 1)
8080
ds = xr.Dataset(data_vars=data_vars, coords=coords)
81-
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4_CLASSIC")
81+
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4")
8282
return ds
8383

8484

fme/ace/data_loading/test_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _save_netcdf(
6060
data_vars[f"bk_{i}"] = float(i + 1)
6161

6262
ds = xr.Dataset(data_vars=data_vars, coords=coords)
63-
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4_CLASSIC")
63+
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4")
6464

6565

6666
@pytest.mark.parametrize("n_ensemble_members", [1, 2])

fme/ace/inference/data_writer/monthly.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import torch
99
import xarray as xr
10-
from netCDF4 import Dataset
10+
from h5netcdf.legacyapi import Dataset
1111

1212
from fme.ace.inference.data_writer.dataset_metadata import DatasetMetadata
1313
from fme.ace.inference.data_writer.utils import (
@@ -269,7 +269,7 @@ def append_batch(
269269
month_min = np.min(months)
270270
month_range = np.max(months) - month_min + 1
271271

272-
old_size = self.dataset.variables[LEAD_TIME_DIM].size
272+
old_size = self.dataset.variables[LEAD_TIME_DIM].shape[0]
273273
new_size = month_min + month_range
274274

275275
self._extend_lead_time(old_size, new_size)
@@ -303,7 +303,7 @@ def append_batch(
303303

304304
# Add the data to the variable totals
305305
# Have to extract the data and write it back as `.at` does not play nicely
306-
# with netCDF4
306+
# with h5netcdf
307307
# We pull just the month subset we need for speed reasons
308308
self._extend_variable(variable_name, old_size, new_size, initial_value=0.0)
309309
month_data = self.dataset.variables[variable_name][
@@ -320,7 +320,7 @@ def append_batch(
320320
] = month_data
321321
# counts must be added after data, as we use the base counts when updating means
322322
for i_sample in range(n_samples_data):
323-
self.dataset.variables[COUNTS][i_sample] += np.bincount(
323+
self.dataset.variables[COUNTS][i_sample : i_sample + 1] += np.bincount(
324324
months[i_sample], minlength=self.dataset.variables[COUNTS].shape[1]
325325
)
326326

fme/ace/inference/data_writer/raw.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy.typing as npt
1111
import torch
1212
import xarray as xr
13-
from netCDF4 import Dataset
13+
from h5netcdf.legacyapi import Dataset
1414

1515
from fme.ace.inference.data_writer.dataset_metadata import DatasetMetadata
1616
from fme.ace.inference.data_writer.utils import (
@@ -224,9 +224,6 @@ def append_batch(
224224

225225
if current_lead_time_size > 0:
226226
init_times_numeric: np.ndarray = self.dataset.variables[INIT_TIME][:]
227-
init_times_numeric = (
228-
init_times_numeric.filled()
229-
) # convert masked array to ndarray
230227
init_times: np.ndarray = cftime.num2date(
231228
init_times_numeric,
232229
units=self.dataset.variables[INIT_TIME].units,

fme/ace/inference/data_writer/test_data_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
import torch
88
import xarray as xr
9-
from netCDF4 import Dataset
9+
from h5netcdf.legacyapi import Dataset
1010
from xarray.coding.times import CFDatetimeCoder
1111

1212
from fme.ace.data_loading.batch_data import PairedData

fme/ace/inference/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class InitialConditionConfig:
6666
"""
6767

6868
path: str
69-
engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4"
69+
engine: Literal["h5netcdf", "zarr"] = "h5netcdf"
7070
start_indices: StartIndices | None = None
7171

7272
def get_dataset(self) -> xr.Dataset:
@@ -408,5 +408,5 @@ def run_segmented_inference(config: InferenceConfig, segments: int):
408408
with GlobalTimer():
409409
run_inference_from_config(config_copy)
410410
config_copy.initial_condition = InitialConditionConfig(
411-
path=restart_path, engine="netcdf4"
411+
path=restart_path, engine="h5netcdf"
412412
)

fme/ace/testing/fv3gfs_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def save_nd_netcdf(
9292
for name in variable_names:
9393
for i in range(dim_sizes.n_time):
9494
ds[name].isel(time=i).values[:] = time_varying_values[i]
95-
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4_CLASSIC")
95+
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4")
9696
if return_ds:
9797
return ds
9898
return None
@@ -103,7 +103,7 @@ def save_scalar_netcdf(
103103
variable_names: list[str],
104104
):
105105
ds = get_scalar_dataset(variable_names)
106-
ds.to_netcdf(filename, format="NETCDF4_CLASSIC")
106+
ds.to_netcdf(filename, format="NETCDF4")
107107

108108

109109
@dataclasses.dataclass
@@ -219,7 +219,7 @@ def __post_init__(self):
219219
months_list.append(xr.DataArray(months, dims=["time"]))
220220
ds = xr.concat(member_datasets, dim="sample")
221221
ds.coords["valid_time"] = xr.concat(months_list, dim="sample")
222-
ds.to_netcdf(self.data_filename, format="NETCDF4_CLASSIC")
222+
ds.to_netcdf(self.data_filename, format="NETCDF4")
223223
self.start_time = cftime.DatetimeProlepticGregorian(2000, 1, 1)
224224

225225
@property

fme/core/dataset/test_xarray.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _get_data(
192192
filenames.append(filename)
193193

194194
initial_condition_names = ()
195-
start_indices = _get_cumulative_timesteps(_get_raw_times(filenames, "netcdf4"))
195+
start_indices = _get_cumulative_timesteps(_get_raw_times(filenames, "h5netcdf"))
196196
if write_extra_vars:
197197
variable_names = VariableNames(
198198
time_dependent_names=(*var_names, "varying_scalar_var"),
@@ -289,7 +289,7 @@ def mock_monthly_zarr_ensemble_dim(
289289
)
290290

291291

292-
def load_files_without_dask(files, engine="netcdf4") -> xr.Dataset:
292+
def load_files_without_dask(files, engine="h5netcdf") -> xr.Dataset:
293293
"""Load a sequence of files without dask, concatenating along the time dimension.
294294
295295
We load the data from the files into memory to ensure Datasets are properly closed,
@@ -407,7 +407,7 @@ def xarray_dataset_constructor(
407407
@pytest.mark.parametrize(
408408
"mock_data_fixture, engine, file_pattern, labels",
409409
[
410-
("mock_monthly_netcdfs", "netcdf4", "*.nc", set()),
410+
("mock_monthly_netcdfs", "h5netcdf", "*.nc", set()),
411411
("mock_monthly_zarr", "zarr", "*.zarr", {"foo_label"}),
412412
],
413413
)
@@ -777,7 +777,7 @@ def test_dataset_config_dtype_raises():
777777
@pytest.mark.parametrize(
778778
"mock_data_fixture, engine, file_pattern",
779779
[
780-
("mock_monthly_netcdfs_with_nans", "netcdf4", "*.nc"),
780+
("mock_monthly_netcdfs_with_nans", "h5netcdf", "*.nc"),
781781
("mock_monthly_zarr_with_nans", "zarr", "*.zarr"),
782782
],
783783
)
@@ -1049,7 +1049,7 @@ def test_invalid_config_field_raises_error(kwargs):
10491049
@pytest.mark.parametrize(
10501050
"mock_data_fixture, engine, file_pattern",
10511051
[
1052-
("mock_monthly_netcdfs_ensemble_dim", "netcdf4", "*.nc"),
1052+
("mock_monthly_netcdfs_ensemble_dim", "h5netcdf", "*.nc"),
10531053
("mock_monthly_zarr_ensemble_dim", "zarr", "*.zarr"),
10541054
],
10551055
)
@@ -1074,7 +1074,7 @@ def test_dataset_with_nonspacetime_dim(
10741074
@pytest.mark.parametrize(
10751075
"mock_data_fixture, engine, file_pattern",
10761076
[
1077-
("mock_monthly_netcdfs_ensemble_dim", "netcdf4", "*.nc"),
1077+
("mock_monthly_netcdfs_ensemble_dim", "h5netcdf", "*.nc"),
10781078
("mock_monthly_zarr_ensemble_dim", "zarr", "*.zarr"),
10791079
],
10801080
)
@@ -1098,7 +1098,7 @@ def test_dataset_raise_error_on_dim_mismatch(
10981098
@pytest.mark.parametrize(
10991099
"mock_data_fixture, engine, file_pattern",
11001100
[
1101-
("mock_monthly_netcdfs_ensemble_dim", "netcdf4", "*.nc"),
1101+
("mock_monthly_netcdfs_ensemble_dim", "h5netcdf", "*.nc"),
11021102
("mock_monthly_zarr_ensemble_dim", "zarr", "*.zarr"),
11031103
],
11041104
)
@@ -1191,7 +1191,7 @@ def test_parallel__get_raw_times(tmpdir):
11911191
ds.isel(time=time_slice).to_netcdf(path)
11921192
paths.append(path)
11931193

1194-
result = np.concatenate(_get_raw_times(paths, engine="netcdf4"))
1194+
result = np.concatenate(_get_raw_times(paths, engine="h5netcdf"))
11951195
np.testing.assert_equal(result, times)
11961196

11971197

fme/core/dataset/xarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ class XarrayDataConfig(DatasetConfigABC):
448448
data_path: str
449449
file_pattern: str = "*.nc"
450450
n_repeats: int = 1
451-
engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4"
451+
engine: Literal["h5netcdf", "zarr"] = "h5netcdf"
452452
spatial_dimensions: Literal["healpix", "latlon"] = "latlon"
453453
subset: Slice | TimeSlice | RepeatedInterval = dataclasses.field(
454454
default_factory=Slice

fme/coupled/data_loading/test_data_loader.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _save_netcdf(
113113
data_vars[name] = data_vars[name].where(mask == 1, float("nan"))
114114
data_vars[f"idepth_{i}"] = float(i)
115115
ds = xr.Dataset(data_vars=data_vars, coords=coords)
116-
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4_CLASSIC")
116+
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4")
117117
return ds
118118

119119

@@ -451,14 +451,14 @@ def test_zarr_engine_used_true():
451451
config = CoupledDataLoaderConfig(
452452
dataset=[
453453
CoupledDatasetConfig(
454-
ocean=XarrayDataConfig(data_path="ocean", engine="netcdf4"),
454+
ocean=XarrayDataConfig(data_path="ocean", engine="h5netcdf"),
455455
atmosphere=XarrayDataConfig(
456456
data_path="atmos", file_pattern="data.zarr", engine="zarr"
457457
),
458458
),
459459
CoupledDatasetConfig(
460-
ocean=XarrayDataConfig(data_path="ocean", engine="netcdf4"),
461-
atmosphere=XarrayDataConfig(data_path="atmos", engine="netcdf4"),
460+
ocean=XarrayDataConfig(data_path="ocean", engine="h5netcdf"),
461+
atmosphere=XarrayDataConfig(data_path="atmos", engine="h5netcdf"),
462462
),
463463
],
464464
batch_size=1,
@@ -470,8 +470,8 @@ def test_zarr_engine_used_false():
470470
config = CoupledDataLoaderConfig(
471471
dataset=[
472472
CoupledDatasetConfig(
473-
ocean=XarrayDataConfig(data_path="ocean", engine="netcdf4"),
474-
atmosphere=XarrayDataConfig(data_path="atmos", engine="netcdf4"),
473+
ocean=XarrayDataConfig(data_path="ocean", engine="h5netcdf"),
474+
atmosphere=XarrayDataConfig(data_path="atmos", engine="h5netcdf"),
475475
)
476476
],
477477
batch_size=1,
@@ -482,7 +482,7 @@ def test_zarr_engine_used_false():
482482
def test_zarr_engine_used_true_inference():
483483
config = InferenceDataLoaderConfig(
484484
dataset=CoupledDatasetWithOptionalOceanConfig(
485-
ocean=XarrayDataConfig(data_path="ocean", engine="netcdf4"),
485+
ocean=XarrayDataConfig(data_path="ocean", engine="h5netcdf"),
486486
atmosphere=XarrayDataConfig(
487487
data_path="atmos", file_pattern="data.zarr", engine="zarr"
488488
),
@@ -495,8 +495,8 @@ def test_zarr_engine_used_true_inference():
495495
def test_zarr_engine_used_false_inference():
496496
config = InferenceDataLoaderConfig(
497497
dataset=CoupledDatasetWithOptionalOceanConfig(
498-
ocean=XarrayDataConfig(data_path="ocean", engine="netcdf4"),
499-
atmosphere=XarrayDataConfig(data_path="atmos", engine="netcdf4"),
498+
ocean=XarrayDataConfig(data_path="ocean", engine="h5netcdf"),
499+
atmosphere=XarrayDataConfig(data_path="atmos", engine="h5netcdf"),
500500
),
501501
start_indices=ExplicitIndices([0]),
502502
)

0 commit comments

Comments
 (0)