Skip to content

Commit fca6446

Browse files
authored
Add CorrectorConfigABC (ai2cm#965)
Refactors corrector config typing with `CorrectorConfigABC` with a `get_corrector` abstract method which accepts `DatasetInfo` and returns a `CorrectorABC`. This allows the corrector configs to build their corresponding correctors, removing that responsibility from the vertical coordinates. Changes: - `CorrectorSelector.registry` was previously unused but is now once again active via `CorrectorSelector.__post_init__`. - Adds properties `atmosphere_vertical_coordinate` and `ocean_vertical_coordinate` to `DatasetInfo`. - [x] Tests added
1 parent ae6fc19 commit fca6446

13 files changed

Lines changed: 149 additions & 165 deletions

File tree

fme/ace/step/fcn3.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,7 @@ def get_step(
314314
init_weights: Callable[[list[nn.Module]], None],
315315
) -> "FCN3Step":
316316
logging.info("Initializing stepper from provided config")
317-
corrector = dataset_info.vertical_coordinate.build_corrector(
318-
config=self.corrector,
319-
gridded_operations=dataset_info.gridded_operations,
320-
timestep=dataset_info.timestep,
321-
)
317+
corrector = self.corrector.get_corrector(dataset_info)
322318
normalizer = self.normalization.get_network_normalizer(self._normalize_names)
323319
return FCN3Step(
324320
config=self,

fme/core/atmosphere_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Mapping
2-
from typing import Protocol
2+
from typing import Protocol, runtime_checkable
33

44
import torch
55

@@ -40,6 +40,7 @@
4040
}
4141

4242

43+
@runtime_checkable
4344
class HasAtmosphereVerticalIntegral(Protocol):
4445
def vertical_integral(
4546
self,

fme/core/coordinates.py

Lines changed: 0 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,12 @@
1111

1212
from fme.core import metrics
1313
from fme.core.constants import EARTH_RADIUS, GRAVITY
14-
from fme.core.corrector.atmosphere import AtmosphereCorrector, AtmosphereCorrectorConfig
15-
from fme.core.corrector.ice import IceCorrector, IceCorrectorConfig
16-
from fme.core.corrector.ocean import OceanCorrector, OceanCorrectorConfig
17-
from fme.core.corrector.registry import CorrectorABC
1814
from fme.core.derived_variables import compute_derived_quantities
1915
from fme.core.device import get_device
2016
from fme.core.distributed import Distributed
2117
from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations
2218
from fme.core.mask_provider import MaskProvider, MaskProviderABC, NullMaskProvider
2319
from fme.core.ocean_derived_variables import compute_ocean_derived_quantities
24-
from fme.core.registry.corrector import CorrectorSelector
2520
from fme.core.typing_ import TensorDict, TensorMapping
2621
from fme.core.winds import lon_lat_to_xyz
2722

@@ -136,15 +131,6 @@ def __repr__(self) -> str:
136131
def __eq__(self, other) -> bool:
137132
pass
138133

139-
@abc.abstractmethod
140-
def build_corrector(
141-
self,
142-
config: AtmosphereCorrectorConfig | CorrectorSelector,
143-
gridded_operations: GriddedOperations,
144-
timestep: timedelta,
145-
) -> CorrectorABC:
146-
pass
147-
148134
@abc.abstractmethod
149135
def build_derive_function(
150136
self,
@@ -208,35 +194,6 @@ def __len__(self):
208194
"""The number of vertical layer interfaces."""
209195
return len(self.ak)
210196

211-
def build_corrector(
212-
self,
213-
config: AtmosphereCorrectorConfig | CorrectorSelector,
214-
gridded_operations: GriddedOperations,
215-
timestep: timedelta,
216-
) -> AtmosphereCorrector:
217-
if (
218-
isinstance(config, CorrectorSelector)
219-
and config.type != "atmosphere_corrector"
220-
):
221-
raise ValueError(
222-
f"Cannot build corrector for vertical coordinate {self} with "
223-
f"corrector selector {config}."
224-
)
225-
if isinstance(config, CorrectorSelector):
226-
config_instance = dacite.from_dict(
227-
data_class=AtmosphereCorrectorConfig,
228-
data=config.config,
229-
config=dacite.Config(strict=True),
230-
)
231-
else:
232-
config_instance = config
233-
return AtmosphereCorrector(
234-
config=config_instance,
235-
gridded_operations=gridded_operations,
236-
vertical_coordinate=self,
237-
timestep=timestep,
238-
)
239-
240197
def build_derive_function(
241198
self,
242199
timestep: timedelta,
@@ -392,30 +349,6 @@ def __len__(self):
392349
"""The number of vertical layer interfaces."""
393350
return len(self.idepth)
394351

395-
def build_corrector(
396-
self,
397-
config: AtmosphereCorrectorConfig | CorrectorSelector,
398-
gridded_operations: GriddedOperations,
399-
timestep: timedelta,
400-
) -> OceanCorrector:
401-
if isinstance(config, AtmosphereCorrectorConfig):
402-
raise ValueError(
403-
"Cannot build corrector for depth coordinate with an "
404-
"AtmosphereCorrectorConfig."
405-
)
406-
elif config.type != "ocean_corrector":
407-
raise ValueError(
408-
f"Cannot build corrector for vertical coordinate {self} with "
409-
f"corrector selector {config}."
410-
)
411-
config_instance = OceanCorrectorConfig.from_state(config.config)
412-
return OceanCorrector(
413-
config=config_instance,
414-
gridded_operations=gridded_operations,
415-
vertical_coordinate=self,
416-
timestep=timestep,
417-
)
418-
419352
def build_derive_function(
420353
self,
421354
timestep: timedelta,
@@ -521,56 +454,6 @@ def __repr__(self) -> str:
521454
def __len__(self) -> int:
522455
return 0
523456

524-
def build_corrector(
525-
self,
526-
config: AtmosphereCorrectorConfig | CorrectorSelector,
527-
gridded_operations: GriddedOperations,
528-
timestep: timedelta,
529-
) -> CorrectorABC:
530-
if isinstance(config, AtmosphereCorrectorConfig):
531-
return AtmosphereCorrector(
532-
config=config,
533-
gridded_operations=gridded_operations,
534-
vertical_coordinate=None,
535-
timestep=timestep,
536-
)
537-
if config.type == "atmosphere_corrector":
538-
config_instance = dacite.from_dict(
539-
data_class=AtmosphereCorrectorConfig,
540-
data=config.config,
541-
config=dacite.Config(strict=True),
542-
)
543-
return AtmosphereCorrector(
544-
config=config_instance,
545-
gridded_operations=gridded_operations,
546-
vertical_coordinate=None,
547-
timestep=timestep,
548-
)
549-
elif config.type == "ocean_corrector":
550-
config_instance = OceanCorrectorConfig.from_state(config.config)
551-
return OceanCorrector(
552-
config=config_instance,
553-
gridded_operations=gridded_operations,
554-
vertical_coordinate=None,
555-
timestep=timestep,
556-
)
557-
elif config.type == "ice_corrector":
558-
config_instance = dacite.from_dict(
559-
data_class=IceCorrectorConfig,
560-
data=config.config,
561-
config=dacite.Config(strict=True),
562-
)
563-
return IceCorrector(
564-
config=config_instance,
565-
gridded_operations=gridded_operations,
566-
timestep=timestep,
567-
)
568-
else:
569-
raise ValueError(
570-
f"Invalid corrector type: {config.type}. "
571-
"Must be either 'atmosphere_corrector' or 'ocean_corrector'."
572-
)
573-
574457
def build_derive_function(
575458
self,
576459
timestep: timedelta,

fme/core/corrector/atmosphere.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
compute_layer_thickness,
1414
)
1515
from fme.core.constants import GRAVITY, SPECIFIC_HEAT_OF_DRY_AIR_CONST_VOLUME
16-
from fme.core.corrector.registry import CorrectorABC
16+
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
1717
from fme.core.corrector.utils import force_positive
18+
from fme.core.dataset_info import DatasetInfo
1819
from fme.core.gridded_ops import GriddedOperations
1920
from fme.core.registry.corrector import CorrectorSelector
2021
from fme.core.typing_ import TensorDict, TensorMapping
@@ -40,7 +41,7 @@ class EnergyBudgetConfig:
4041

4142
@CorrectorSelector.register("atmosphere_corrector")
4243
@dataclasses.dataclass
43-
class AtmosphereCorrectorConfig:
44+
class AtmosphereCorrectorConfig(CorrectorConfigABC):
4445
r"""
4546
Configuration for the post-step state corrector.
4647
@@ -138,6 +139,17 @@ def from_state(cls, state: Mapping[str, Any]) -> "AtmosphereCorrectorConfig":
138139
data_class=cls, data=state, config=dacite.Config(strict=True)
139140
)
140141

142+
def get_corrector(
143+
self,
144+
dataset_info: DatasetInfo,
145+
) -> "AtmosphereCorrector":
146+
return AtmosphereCorrector(
147+
self,
148+
dataset_info.gridded_operations,
149+
dataset_info.atmosphere_vertical_coordinate,
150+
dataset_info.timestep,
151+
)
152+
141153

142154
class AtmosphereCorrector(CorrectorABC):
143155
def __init__(

fme/core/corrector/ice.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import dacite
77
import torch
88

9-
from fme.core.corrector.registry import CorrectorABC
9+
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
10+
from fme.core.dataset_info import DatasetInfo
1011
from fme.core.gridded_ops import GriddedOperations
1112
from fme.core.registry.corrector import CorrectorSelector
1213
from fme.core.typing_ import TensorDict, TensorMapping
@@ -185,8 +186,7 @@ def __call__(
185186

186187
@CorrectorSelector.register("ice_corrector")
187188
@dataclasses.dataclass
188-
class IceCorrectorConfig:
189-
# Correctors here. Can add more as needed
189+
class IceCorrectorConfig(CorrectorConfigABC):
190190
budget_correction: IceBudgetCorrectionConfig | None = None
191191

192192
@classmethod
@@ -195,6 +195,16 @@ def from_state(cls, state: Mapping[str, Any]) -> "IceCorrectorConfig":
195195
data_class=cls, data=state, config=dacite.Config(strict=True)
196196
)
197197

198+
def get_corrector(
199+
self,
200+
dataset_info: DatasetInfo,
201+
) -> "IceCorrector":
202+
return IceCorrector(
203+
self,
204+
dataset_info.gridded_operations,
205+
dataset_info.timestep,
206+
)
207+
198208

199209
class IceCorrector(CorrectorABC):
200210
"""

fme/core/corrector/ocean.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
LATENT_HEAT_OF_VAPORIZATION,
1313
SPECIFIC_HEAT_OF_SEA_WATER_CM4,
1414
)
15-
from fme.core.corrector.registry import CorrectorABC
15+
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
1616
from fme.core.corrector.utils import force_positive
17+
from fme.core.dataset_info import DatasetInfo
1718
from fme.core.gridded_ops import GriddedOperations
1819
from fme.core.ocean_data import HasOceanDepthIntegral, OceanData
1920
from fme.core.registry.corrector import CorrectorSelector
@@ -110,7 +111,7 @@ class SurfaceEnergyFluxCorrectionConfig:
110111

111112
@CorrectorSelector.register("ocean_corrector")
112113
@dataclasses.dataclass
113-
class OceanCorrectorConfig:
114+
class OceanCorrectorConfig(CorrectorConfigABC):
114115
force_positive_names: list[str] = dataclasses.field(default_factory=list)
115116
sea_ice_fraction_correction: SeaIceFractionConfig | None = None
116117
surface_energy_flux_correction: SurfaceEnergyFluxCorrectionConfig | None = None
@@ -147,6 +148,17 @@ def remove_deprecated_keys(cls, state: Mapping[str, Any]) -> dict[str, Any]:
147148
)
148149
return state_copy
149150

151+
def get_corrector(
152+
self,
153+
dataset_info: DatasetInfo,
154+
) -> "OceanCorrector":
155+
return OceanCorrector(
156+
self,
157+
dataset_info.gridded_operations,
158+
dataset_info.ocean_vertical_coordinate,
159+
dataset_info.timestep,
160+
)
161+
150162

151163
class OceanCorrector(CorrectorABC):
152164
def __init__(

fme/core/corrector/registry.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
import abc
22

3+
from fme.core.dataset_info import DatasetInfo
34
from fme.core.typing_ import TensorDict, TensorMapping
45

56

7+
class CorrectorConfigABC(abc.ABC):
8+
@abc.abstractmethod
9+
def get_corrector(
10+
self,
11+
dataset_info: DatasetInfo,
12+
) -> "CorrectorABC": ...
13+
14+
615
class CorrectorABC(abc.ABC):
716
@abc.abstractmethod
817
def __call__(

fme/core/dataset_info.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Mapping
44
from typing import Any
55

6+
from fme.core.atmosphere_data import HasAtmosphereVerticalIntegral
67
from fme.core.coordinates import (
78
HorizontalCoordinates,
89
NullVerticalCoordinate,
@@ -14,6 +15,7 @@
1415
from fme.core.dataset.utils import decode_timestep, encode_timestep
1516
from fme.core.gridded_ops import GriddedOperations
1617
from fme.core.mask_provider import MaskProvider, MaskProviderABC, NullMaskProvider
18+
from fme.core.ocean_data import HasOceanDepthIntegral
1719

1820

1921
class MissingDatasetInfo(ValueError):
@@ -181,6 +183,28 @@ def vertical_coordinate(self) -> VerticalCoordinate:
181183
raise MissingDatasetInfo("vertical_coordinate")
182184
return self._vertical_coordinate
183185

186+
@property
187+
def atmosphere_vertical_coordinate(self) -> HasAtmosphereVerticalIntegral | None:
188+
if isinstance(self._vertical_coordinate, HasAtmosphereVerticalIntegral):
189+
return self._vertical_coordinate
190+
elif isinstance(self._vertical_coordinate, NullVerticalCoordinate):
191+
return None
192+
raise RuntimeError(
193+
f"{self._vertical_coordinate} cannot be used as an atmosphere vertical "
194+
"coordinate."
195+
)
196+
197+
@property
198+
def ocean_vertical_coordinate(self) -> HasOceanDepthIntegral | None:
199+
if isinstance(self._vertical_coordinate, HasOceanDepthIntegral):
200+
return self._vertical_coordinate
201+
elif isinstance(self._vertical_coordinate, NullVerticalCoordinate):
202+
return None
203+
raise RuntimeError(
204+
f"{self._vertical_coordinate} cannot be used as an ocean vertical "
205+
"coordinate."
206+
)
207+
184208
@property
185209
def mask_provider(self) -> MaskProvider:
186210
if self._mask_provider is None:

fme/core/ocean_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
from collections.abc import Mapping
33
from types import MappingProxyType
4-
from typing import Protocol
4+
from typing import Protocol, runtime_checkable
55

66
import torch
77

@@ -30,6 +30,7 @@
3030
)
3131

3232

33+
@runtime_checkable
3334
class HasOceanDepthIntegral(Protocol):
3435
def depth_integral(
3536
self,

0 commit comments

Comments
 (0)