|
2 | 2 |
|
3 | 3 | from unittest.mock import Mock, patch |
4 | 4 |
|
| 5 | +import numpy as np |
5 | 6 | import pytest |
| 7 | +import xarray as xr |
6 | 8 |
|
7 | 9 | from eopf_geozarr.conversion.fs_utils import ( |
8 | 10 | get_s3_credentials_info, |
|
14 | 16 | path_exists, |
15 | 17 | read_json_metadata, |
16 | 18 | replace_json_invalid_floats, |
| 19 | + sanitize_dataset_attributes, |
17 | 20 | validate_s3_access, |
18 | 21 | write_json_metadata, |
19 | 22 | ) |
@@ -210,10 +213,64 @@ def test_replace_json_invalid_floats() -> None: |
210 | 213 | } |
211 | 214 | expected: dict[str, object] = { |
212 | 215 | "nan": "NaN", |
213 | | - "nested_nan": {"nan": "NaN", "inf": "Infinity", "-inf": "-Infinity"}, |
214 | | - "nan_in_list": ["NaN", "Infinity", "-Infinity"], |
| 216 | + "nested": {"nan": "NaN", "inf": "Infinity", "-inf": "-Infinity"}, |
| 217 | + "in_list": ["NaN", "Infinity", "-Infinity"], |
215 | 218 | "inf": "Infinity", |
216 | 219 | "-inf": "-Infinity", |
217 | 220 | } |
218 | 221 | observed = replace_json_invalid_floats(data) |
219 | 222 | assert observed == expected |
| 223 | + |
| 224 | + |
| 225 | +def test_sanitize_dataset_attributes() -> None: |
| 226 | + """ |
| 227 | + Check that sanitize_dataset_attributes removes invalid floats from the attributes of an |
| 228 | + xarray Dataset, as well as the attributes of the data variables and the coordinate variables. |
| 229 | + """ |
| 230 | + |
| 231 | + # Create a dataset with NaN and Infinity values in various attribute locations |
| 232 | + ds = xr.Dataset( |
| 233 | + data_vars={ |
| 234 | + "temperature": ( |
| 235 | + ["x", "y"], |
| 236 | + np.array([[1.0, 2.0], [3.0, 4.0]]), |
| 237 | + {"fill_value": float("nan"), "valid_max": float("inf")}, |
| 238 | + ), |
| 239 | + "pressure": ( |
| 240 | + ["x", "y"], |
| 241 | + np.array([[10.0, 20.0], [30.0, 40.0]]), |
| 242 | + {"fill_value": float("nan"), "valid_min": float("-inf")}, |
| 243 | + ), |
| 244 | + }, |
| 245 | + coords={ |
| 246 | + "x": (["x"], [0, 1], {"missing_value": float("nan")}), |
| 247 | + "y": (["y"], [0, 1], {"scale_factor": 1.0}), |
| 248 | + }, |
| 249 | + attrs={ |
| 250 | + "global_nan": float("nan"), |
| 251 | + "global_inf": float("inf"), |
| 252 | + "nested": {"inner_nan": float("nan")}, |
| 253 | + }, |
| 254 | + ) |
| 255 | + |
| 256 | + # Sanitize the dataset |
| 257 | + result = sanitize_dataset_attributes(ds) |
| 258 | + |
| 259 | + # Check dataset-level attributes |
| 260 | + assert result.attrs["global_nan"] == "NaN" |
| 261 | + assert result.attrs["global_inf"] == "Infinity" |
| 262 | + assert result.attrs["nested"]["inner_nan"] == "NaN" |
| 263 | + |
| 264 | + # Check data variable attributes |
| 265 | + assert result["temperature"].attrs["fill_value"] == "NaN" |
| 266 | + assert result["temperature"].attrs["valid_max"] == "Infinity" |
| 267 | + assert result["pressure"].attrs["fill_value"] == "NaN" |
| 268 | + assert result["pressure"].attrs["valid_min"] == "-Infinity" |
| 269 | + |
| 270 | + # Check coordinate attributes |
| 271 | + assert result.coords["x"].attrs["missing_value"] == "NaN" |
| 272 | + assert result.coords["y"].attrs["scale_factor"] == 1.0 # Normal float unchanged |
| 273 | + |
| 274 | + # Verify the original dataset was not modified |
| 275 | + assert ds.attrs["global_nan"] != ds.attrs["global_nan"] # NaN != NaN |
| 276 | + assert ds["temperature"].attrs["fill_value"] != ds["temperature"].attrs["fill_value"] |
0 commit comments