diff --git a/training/docs/modules/losses.rst b/training/docs/modules/losses.rst
index 72a1fd2d60..212af171e2 100644
--- a/training/docs/modules/losses.rst
+++ b/training/docs/modules/losses.rst
@@ -91,6 +91,60 @@ deterministic:
 .. _multiscale-loss-functions:
 
+*******************************
+ Time Aggregate Loss Functions
+*******************************
+
+These loss functions encourage the model to produce **temporally
+consistent** outputs, i.e. output sequences that are internally
+coherent over time, not just accurate at each individual step.
+
+:class:`~anemoi.training.losses.aggregate.TimeAggregateLossWrapper`
+addresses this by applying a base loss function to *time-aggregated*
+versions of the prediction and target, rather than step-by-step. The
+following aggregations are supported:
+
+.. list-table::
+   :widths: 15 85
+   :header-rows: 1
+
+   - - Aggregation
+     - Description
+
+   - - ``mean``
+     - Mean over the output time window — penalises bias in the
+       temporal average.
+
+   - - ``max``
+     - Maximum over the output time window — penalises errors in peak
+       values.
+
+   - - ``min``
+     - Minimum over the output time window — penalises errors in
+       minimum values.
+
+   - - ``diff``
+     - Consecutive step-to-step differences
+       (``pred[:, 1:] - pred[:, :-1]``) — penalises unrealistic
+       temporal transitions and discontinuities.
+
+The wrapper evaluates the wrapped loss on each aggregation in turn and
+returns the average over the aggregation types. Because the
+``time_steps`` scaler is intentionally excluded from the inner
+``loss_fn`` (temporal aggregation collapses the time dimension), only
+spatial and variable scalers should be listed there.
+
+.. note::
+
+   ``TimeAggregateLossWrapper`` requires an output time dimension
+   greater than one, as it is not meaningful for single-step tasks.
+
+We strongly recommend using the time aggregate loss when training any
+temporal downscaler. The pre-built config variants
+``single_MSE_aggregation`` and ``ensemble_multiscale_aggregation``
+combine it with the primary loss inside a
+:class:`~anemoi.training.losses.combined.CombinedLoss`.
+
 ***************************
  Multiscale Loss Functions
 ***************************
diff --git a/training/docs/modules/tasks.rst b/training/docs/modules/tasks.rst
index 79fcff8643..c8616740c6 100644
--- a/training/docs/modules/tasks.rst
+++ b/training/docs/modules/tasks.rst
@@ -155,6 +155,9 @@
 Example: ``input_timestep="6H"``, ``output_timestep="3H"``,
 ``output_left_boundary=True`` produces output offsets ``[0H, 3H]`` and
 input offsets ``[0H, 6H]``.
 
+The default is to use the time aggregate loss when training any
+temporal downscaler.
+
 ..
automodule:: anemoi.training.tasks.temporal_downscaling :members: :no-undoc-members: diff --git a/training/src/anemoi/training/config/temporal_downscaler.yaml b/training/src/anemoi/training/config/temporal_downscaler.yaml index 2717f69856..fe78ccaf22 100644 --- a/training/src/anemoi/training/config/temporal_downscaler.yaml +++ b/training/src/anemoi/training/config/temporal_downscaler.yaml @@ -7,6 +7,7 @@ defaults: - model: graphtransformer - task: temporal_downscaler - training: single +- override training/training_loss: single_MSE_aggregation - _self_ config_validation: True diff --git a/training/src/anemoi/training/config/temporal_downscaler_ensemble.yaml b/training/src/anemoi/training/config/temporal_downscaler_ensemble.yaml index 8baf76b729..e9f3b710db 100644 --- a/training/src/anemoi/training/config/temporal_downscaler_ensemble.yaml +++ b/training/src/anemoi/training/config/temporal_downscaler_ensemble.yaml @@ -7,6 +7,7 @@ defaults: - model: graphtransformer_ens - task: temporal_downscaler - training: ensemble +- override training/training_loss: ensemble_multiscale_aggregation - _self_ config_validation: True diff --git a/training/src/anemoi/training/config/training/diffusion.yaml b/training/src/anemoi/training/config/training/diffusion.yaml index b6f4536667..4cff5cdeda 100644 --- a/training/src/anemoi/training/config/training/diffusion.yaml +++ b/training/src/anemoi/training/config/training/diffusion.yaml @@ -1,6 +1,7 @@ --- defaults: - scalers: global + - training_loss: single - optimization: default - weight_averaging: null diff --git a/training/src/anemoi/training/config/training/ensemble.yaml b/training/src/anemoi/training/config/training/ensemble.yaml index 3cd709f851..7f7086d804 100644 --- a/training/src/anemoi/training/config/training/ensemble.yaml +++ b/training/src/anemoi/training/config/training/ensemble.yaml @@ -1,6 +1,7 @@ --- defaults: - scalers: global + - training_loss: ensemble - optimization: default - weight_averaging: null @@ -51,33 +52,6 @@ strategy: loss_gradient_scaling: False -# loss function for the model -# To train without multiscale loss, set it to the desired loss directly -training_loss: - datasets: - data: # user-defined key in data - # loss class to initialise, can be anything subclassing torch.nn.Module - _target_: anemoi.training.losses.MultiscaleLossWrapper - # Disk mode: multiscale_config: {loss_matrices_path: /path, loss_matrices: ["file.npz", null]} - # On-the-fly: multiscale_config: {num_scales: 4, base_num_nearest_neighbours: 16, base_sigma: 0.01570} - multiscale_config: null # null = single scale, no smoothing - weights: [1.0] - - per_scale_loss: - _target_: anemoi.training.losses.CRPS - scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights', 'time_steps'] - - # Scalers to include in loss calculation - # A selection of available scalers are listed in training/scalers. - # '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded - # add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added. - # scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights'] - ignore_nans: False - no_autocast: True - alpha: 0.95 - - - # Validation metrics calculation, # This may be a list, in which case all metrics will be calculated # and logged according to their name. 
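With the loss configuration extracted into the `training/training_loss` Hydra group (the new group files appear further down), a variant can be swapped without touching the rest of the training config. A minimal sketch of how the composition behaves, assuming the configs are packaged under `anemoi.training.config` as the file paths in this diff suggest:

```python
# Illustrative only: compose the temporal downscaler config and check which
# training_loss group was selected. The config module path and the printed
# target are assumptions based on the files in this diff, not a tested API.
from hydra import compose, initialize_config_module

with initialize_config_module(config_module="anemoi.training.config", version_base=None):
    cfg = compose(
        config_name="temporal_downscaler",
        # equivalent to the `- override training/training_loss: ...` defaults entry
        overrides=["training/training_loss=single_MSE_aggregation"],
    )

# Per training_loss/single_MSE_aggregation.yaml this should be a CombinedLoss
print(cfg.training.training_loss.datasets.data._target_)
```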
diff --git a/training/src/anemoi/training/config/training/lam.yaml b/training/src/anemoi/training/config/training/lam.yaml index 14d82b7dd2..1cf4e36067 100644 --- a/training/src/anemoi/training/config/training/lam.yaml +++ b/training/src/anemoi/training/config/training/lam.yaml @@ -1,6 +1,7 @@ --- defaults: - scalers: lam + - training_loss: single - optimization: default - weight_averaging: null @@ -48,19 +49,6 @@ strategy: # don't enable this by default until it's been tested and proven beneficial loss_gradient_scaling: False -# loss function for the model -training_loss: - datasets: - data: # user-defined key in data - # loss class to initialise - _target_: anemoi.training.losses.MSELoss - # Scalers to include in loss calculation - # A selection of available scalers are listed in training/scalers/scalers.yaml - # '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded - # add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added. - scalers: ['pressure_level', 'general_variable', 'node_weights', 'time_steps'] - ignore_nans: False - # Validation metrics calculation, # This may be a list, in which case all metrics will be calculated # and logged according to their name. diff --git a/training/src/anemoi/training/config/training/multi.yaml b/training/src/anemoi/training/config/training/multi.yaml index edd8cf1f79..6cec1d1382 100644 --- a/training/src/anemoi/training/config/training/multi.yaml +++ b/training/src/anemoi/training/config/training/multi.yaml @@ -1,6 +1,7 @@ --- defaults: - scalers: multi + - training_loss: single - optimization: default - weight_averaging: null @@ -56,7 +57,6 @@ max_steps: 150000 submodules_to_freeze: [] - # Dataset-specific loss and metrics configuration training_loss: datasets: diff --git a/training/src/anemoi/training/config/training/single.yaml b/training/src/anemoi/training/config/training/single.yaml index f17e5f9906..bdde98ade1 100644 --- a/training/src/anemoi/training/config/training/single.yaml +++ b/training/src/anemoi/training/config/training/single.yaml @@ -1,6 +1,7 @@ --- defaults: - scalers: global + - training_loss: single - optimization: default - weight_averaging: null @@ -39,26 +40,11 @@ strategy: num_gpus_per_model: ${system.hardware.num_gpus_per_model} read_group_size: ${dataloader.read_group_size} -# loss functions - # dynamic rescaling of the loss gradient # see https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2 # don't enable this by default until it's been tested and proven beneficial loss_gradient_scaling: False -# loss function for the model -training_loss: - datasets: - data: # user-defined key in data - # loss class to initialise - _target_: anemoi.training.losses.MSELoss - # Scalers to include in loss calculation - # A selection of available scalers are listed in training/scalers. - # '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded - # add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added. - scalers: ['pressure_level', 'general_variable', 'node_weights', 'time_steps'] - ignore_nans: False - # Validation metrics calculation, # This may be a list, in which case all metrics will be calculated # and logged according to their name. 
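The aggregation variants introduced below (`single_MSE_aggregation` and `ensemble_multiscale_aggregation`) re-weight their two sub-losses so that the combined value equals a plain average over all loss terms; see the `loss_weights` comments in those files. A quick check of that arithmetic, using the worked example from the comments:

```python
# Each sub-loss averages internally: the raw loss over n_timesteps terms,
# the time-aggregate loss over n_agg_ops terms. Re-weighting by term counts
# makes the combination equal sum_all / (n_timesteps + n_agg_ops).
n_timesteps, n_agg_ops = 6, 4
total = n_timesteps + n_agg_ops
loss_weights = [n_timesteps / total, n_agg_ops / total]
assert loss_weights == [0.6, 0.4]
```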
diff --git a/training/src/anemoi/training/config/training/stretched.yaml b/training/src/anemoi/training/config/training/stretched.yaml index af3fd658b9..7222256a8b 100644 --- a/training/src/anemoi/training/config/training/stretched.yaml +++ b/training/src/anemoi/training/config/training/stretched.yaml @@ -1,6 +1,7 @@ --- defaults: - scalers: stretched + - training_loss: single - optimization: default - weight_averaging: null @@ -49,19 +50,6 @@ strategy: # don't enable this by default until it's been tested and proven beneficial loss_gradient_scaling: False -# loss function for the model -training_loss: - datasets: - data: # user-defined key in data - # loss class to initialise - _target_: anemoi.training.losses.MSELoss - # Scalers to include in loss calculation - # A selection of available scalers are listed in training/scalers/scalers.yaml - # '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded - # add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added. - scalers: ['pressure_level', 'general_variable', 'node_weights', 'time_steps'] - ignore_nans: False - # Validation metrics calculation, # This may be a list, in which case all metrics will be calculated # and logged according to their name. diff --git a/training/src/anemoi/training/config/training/training_loss/ensemble.yaml b/training/src/anemoi/training/config/training/training_loss/ensemble.yaml new file mode 100644 index 0000000000..6a12b888e7 --- /dev/null +++ b/training/src/anemoi/training/config/training/training_loss/ensemble.yaml @@ -0,0 +1,22 @@ +# loss function for the model +# To train without multiscale loss, set it to the desired loss directly +datasets: + data: # user-defined key in data + # loss class to initialise, can be anything subclassing torch.nn.Module + _target_: anemoi.training.losses.MultiscaleLossWrapper + # Disk mode: multiscale_config: {loss_matrices_path: /path, loss_matrices: ["file.npz", null]} + # On-the-fly: multiscale_config: {num_scales: 4, base_num_nearest_neighbours: 16, base_sigma: 0.01570} + multiscale_config: null # null = single scale, no smoothing + weights: [1.0] + per_scale_loss: + _target_: anemoi.training.losses.CRPS + scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights', 'time_steps'] + + # Scalers to include in loss calculation + # A selection of available scalers are listed in training/scalers. + # '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded + # add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added. + # scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights'] + ignore_nans: False + no_autocast: True + alpha: 0.95 diff --git a/training/src/anemoi/training/config/training/training_loss/ensemble_multiscale_aggregation.yaml b/training/src/anemoi/training/config/training/training_loss/ensemble_multiscale_aggregation.yaml new file mode 100644 index 0000000000..e2d103bb38 --- /dev/null +++ b/training/src/anemoi/training/config/training/training_loss/ensemble_multiscale_aggregation.yaml @@ -0,0 +1,27 @@ +datasets: + data: + _target_: anemoi.training.losses.combined.CombinedLoss + ignore_nans: False + # loss_weights: [n_timesteps / (n_timesteps + n_agg_ops), n_agg_ops / (n_timesteps + n_agg_ops)] + # Each sub-loss averages internally (raw over timesteps, aggregate over agg ops). + # These weights re-scale so the total matches: sum_all / (n_timesteps + n_agg_ops). 
+ # Example for 6 timesteps and 4 agg ops: [0.6, 0.4] + loss_weights: [0.6, 0.4] + losses: + - _target_: anemoi.training.losses.MultiscaleLossWrapper + multiscale_config: null # null = single scale, no smoothing + weights: [1.0] + per_scale_loss: + _target_: anemoi.training.losses.CRPS + scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights', 'time_steps'] + ignore_nans: False + no_autocast: True + alpha: 0.95 + - _target_: anemoi.training.losses.aggregate.TimeAggregateLossWrapper + scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights'] + time_aggregation_types: [mean, max, min, diff] + loss_fn: + _target_: anemoi.training.losses.CRPS + ignore_nans: False + no_autocast: True + alpha: 0.95 diff --git a/training/src/anemoi/training/config/training/training_loss/single.yaml b/training/src/anemoi/training/config/training/training_loss/single.yaml new file mode 100644 index 0000000000..d2ddccb98c --- /dev/null +++ b/training/src/anemoi/training/config/training/training_loss/single.yaml @@ -0,0 +1,9 @@ +datasets: + data: + _target_: anemoi.training.losses.MSELoss + # Scalers to include in loss calculation + # A selection of available scalers are listed in training/scalers. + # '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded + # add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added. + scalers: ['pressure_level', 'general_variable', 'node_weights', 'time_steps'] + ignore_nans: False diff --git a/training/src/anemoi/training/config/training/training_loss/single_MSE_aggregation.yaml b/training/src/anemoi/training/config/training/training_loss/single_MSE_aggregation.yaml new file mode 100644 index 0000000000..5e335e34f7 --- /dev/null +++ b/training/src/anemoi/training/config/training/training_loss/single_MSE_aggregation.yaml @@ -0,0 +1,19 @@ +datasets: + data: + _target_: anemoi.training.losses.combined.CombinedLoss + ignore_nans: False + # loss_weights: [n_timesteps / (n_timesteps + n_agg_ops), n_agg_ops / (n_timesteps + n_agg_ops)] + # Each sub-loss averages internally (raw over timesteps, aggregate over agg ops). + # These weights re-scale so the total matches: sum_all / (n_timesteps + n_agg_ops). + # Example for 6 timesteps and 4 agg ops: [0.6, 0.4] + loss_weights: [0.6, 0.4] + losses: + - _target_: anemoi.training.losses.MSELoss + scalers: ['pressure_level', 'general_variable', 'node_weights', 'time_steps'] + ignore_nans: False + - _target_: anemoi.training.losses.aggregate.TimeAggregateLossWrapper + scalers: ['pressure_level', 'general_variable', 'node_weights'] + time_aggregation_types: [mean, max, min, diff] + loss_fn: + _target_: anemoi.training.losses.MSELoss + ignore_nans: False diff --git a/training/src/anemoi/training/diagnostics/callbacks/per_timestep_metrics.py b/training/src/anemoi/training/diagnostics/callbacks/per_timestep_metrics.py new file mode 100644 index 0000000000..243bcb4cbd --- /dev/null +++ b/training/src/anemoi/training/diagnostics/callbacks/per_timestep_metrics.py @@ -0,0 +1,141 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
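+#
+# Note: how this callback is enabled is an assumption, not shown in this
+# diff; typically it would be added to the diagnostics callback list in
+# the training config, e.g.:
+#
+#   callbacks:
+#     - _target_: anemoi.training.diagnostics.callbacks.per_timestep_metrics.PerTimestepMetrics
+#       every_n_batches: 5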
+ +"""Callback to log per-timestep validation metrics for temporal downscaling tasks.""" + +import logging +from contextlib import nullcontext + +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import Callback + +from anemoi.training.losses.base import BaseLoss +from anemoi.training.utils.enums import TensorDim + +LOGGER = logging.getLogger(__name__) + + +class PerTimestepMetrics(Callback): + """Log validation metrics broken down by output timestep. + + For tasks where the model predicts multiple + output timesteps at once, this callback slices predictions and targets + along the time dimension and logs per-timestep validation metrics. + + Parameters + ---------- + every_n_batches : int + Frequency of per-timestep evaluation (runs every N validation batches). + Default is 1 (every batch). + """ + + def __init__(self, every_n_batches: int = 1) -> None: + super().__init__() + self.every_n_batches = every_n_batches + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list, # noqa: ARG002 + batch: dict[str, torch.Tensor], + batch_idx: int, + ) -> None: + if batch_idx % self.every_n_batches != 0: + return + + precision_mapping = { + "16-mixed": torch.float16, + "bf16-mixed": torch.bfloat16, + } + prec = trainer.precision + dtype = precision_mapping.get(prec) + + context = ( + torch.autocast(device_type=next(iter(batch.values())).device.type, dtype=dtype) + if dtype is not None + else nullcontext() + ) + + with context, torch.no_grad(): + self._eval_per_timestep(pl_module, batch) + + def _eval_per_timestep(self, pl_module: pl.LightningModule, batch: dict[str, torch.Tensor]) -> None: + """Run model and compute metrics per timestep.""" + # Get inputs and targets via the task + x = pl_module.task.get_inputs(batch, data_indices=pl_module.data_indices) + x = pl_module._expand_ens_dim(x) + + # Run model forward + y_pred = pl_module(x) + + # Get targets + y_full = pl_module.task.get_targets(batch) + y = pl_module._collapse_ens_dim(y_full) + + batch_size = next(iter(batch.values())).shape[0] + + # For each dataset, compute per-timestep metrics + for dataset_name in y_pred: + pred = y_pred[dataset_name] # (bs, time, ens, grid, var) + target = y[dataset_name] # (bs, time, grid, var) + + n_timesteps = target.shape[TensorDim.TIME] + + # Gather ensemble members across the ensemble comm group + if hasattr(pl_module, "ens_comm_subgroup") and pl_module.ens_comm_subgroup is not None: + from anemoi.models.distributed.graph import gather_tensor + + pred = gather_tensor( + pred.clone(), + dim=TensorDim.ENSEMBLE_DIM, + sizes=[pred.size(TensorDim.ENSEMBLE_DIM)] * pl_module.ens_comm_subgroup_size, + mgroup=pl_module.ens_comm_subgroup, + ) + + # Post-process for metrics (in physical space) + post_processor = pl_module.model.post_processors[dataset_name] + metrics_dict = pl_module.metrics[dataset_name] + val_metric_ranges = pl_module.val_metric_ranges[dataset_name] + grid_shard_slice = pl_module.grid_shard_slice.get(dataset_name) + + for t in range(n_timesteps): + # Slice single timestep: remove time dim + pred_t = pred[:, t : t + 1, :, :, :] # keep time dim for post-processor + target_t = target[:, t : t + 1, :, :] + + pred_t_post = post_processor(pred_t, in_place=False) + target_t_post = post_processor(target_t, in_place=False) + + for metric_name, metric in metrics_dict.items(): + if not isinstance(metric, BaseLoss): + continue + + for mkey, indices in val_metric_ranges.items(): + step_name = 
f"val_{metric_name}_metric/{dataset_name}/{mkey}/t_{t + 1}" + + metric_kwargs = { + "scaler_indices": (..., indices), + "grid_shard_slice": grid_shard_slice, + "group": pl_module.model_comm_group, + } + + value = metric(pred_t_post, target_t_post, **metric_kwargs) + + pl_module.log( + step_name, + value, + on_epoch=True, + on_step=False, + prog_bar=False, + logger=pl_module.logger_enabled, + batch_size=batch_size, + sync_dist=True, + ) diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot.py b/training/src/anemoi/training/diagnostics/callbacks/plot.py index 1f41724044..7a924244d8 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/training/src/anemoi/training/diagnostics/callbacks/plot.py @@ -696,7 +696,8 @@ def _plot( parameter_positions = list[int](data_indices.model.output.name_to_index.values()) # reorder parameter_names by position parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] - metadata_variables = pl_module.model.metadata["dataset"].get("variables_metadata") + metadata = pl_module.model.metadata + metadata_variables = metadata["dataset"].get("variables_metadata") if metadata is not None else None # Sort the list using the custom key argsort_indices = argsort_variablename_variablelevel( diff --git a/training/src/anemoi/training/losses/__init__.py b/training/src/anemoi/training/losses/__init__.py index fa4a274f0d..d3bf2b5542 100644 --- a/training/src/anemoi/training/losses/__init__.py +++ b/training/src/anemoi/training/losses/__init__.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from .aggregate import TimeAggregateLossWrapper from .combined import CombinedLoss from .huber import HuberLoss from .kcrps import CRPS @@ -39,6 +40,7 @@ "RMSELoss", "SpectralCRPSLoss", "SpectralL2Loss", + "TimeAggregateLossWrapper", "WeightedMSELoss", "get_loss_function", ] diff --git a/training/src/anemoi/training/losses/aggregate.py b/training/src/anemoi/training/losses/aggregate.py new file mode 100644 index 0000000000..961543127e --- /dev/null +++ b/training/src/anemoi/training/losses/aggregate.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch + +from anemoi.training.losses.base import BaseLossWrapper +from anemoi.training.utils.enums import TensorDim + +if TYPE_CHECKING: + from torch.distributed.distributed_c10d import ProcessGroup + + from anemoi.training.losses.base import BaseLoss + +LOGGER = logging.getLogger(__name__) + + +class TimeAggregateLossWrapper(BaseLossWrapper): + """Wraps a base loss and applies it to time-aggregated predictions. + + Supported time aggregation types: + + - ``"diff"`` - temporal differences (``pred[:, 1:] - pred[:, :-1]``) + - ``"mean"``, ``"min"``, ``"max"`` - applied over the time window + """ + + def __init__( + self, + time_aggregation_types: list[str], + loss_fn: BaseLoss, + ignore_nans: bool = False, + ) -> None: + super().__init__(loss=loss_fn, ignore_nans=ignore_nans) + self.time_aggregation_types = time_aggregation_types + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + squash: bool = True, + *, + scaler_indices: tuple[int, ...] 
| None = None, + without_scalers: list[str] | list[int] | None = None, + grid_shard_slice: slice | None = None, + group: ProcessGroup | None = None, + squash_mode: str | None = None, + **kwargs, + ) -> torch.Tensor: + """Compute the time aggregate loss over all time aggregation types. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape ``(bs, time, ens, latlon, nvar)``. + target : torch.Tensor + Target tensor, shape ``(bs, time, latlon, nvar)``. + squash : bool, optional + Average the variable dimension, by default ``True``. + scaler_indices : tuple[int, ...] | None, optional + Indices to subset the scaler, by default ``None``. + without_scalers : list[str] | list[int] | None, optional + Scalers to exclude, by default ``None``. + grid_shard_slice : slice | None, optional + Grid shard slice, by default ``None``. + group : ProcessGroup | None, optional + Distributed group for reduction, by default ``None``. + squash_mode : str | None, optional + Variable-dimension reduction mode. If omitted, the wrapped loss default is used. + + Returns + ------- + torch.Tensor + Accumulated loss across all aggregation types. + """ + assert ( + pred.shape[1] > 1 + ), "TimeAggregateLossWrapper requires an output time dimension of size > 1 for aggregation." + loss = torch.tensor(0.0, dtype=pred.dtype, device=pred.device, requires_grad=False) + + # Exclude the TIME scaler from inner loss calls since we iterate per-step + # and apply time weights manually. + without_time = without_scalers or [] + if TensorDim.TIME not in without_time and TensorDim.TIME.value not in without_time: + without_time = [*list(without_time), TensorDim.TIME.value] + + # Extract time weights from the shared scaler (if present) + time_weights = None + for dims, scaler in self.loss.scaler.tensors.values(): + if isinstance(dims, int): + dims = (dims,) + if TensorDim.TIME.value in dims or TensorDim.TIME in dims: + time_weights = scaler + break + + shared_kwargs = dict( + squash=squash, + scaler_indices=scaler_indices, + without_scalers=without_time, + grid_shard_slice=grid_shard_slice, + group=group, + **kwargs, + ) + if squash_mode is not None: + shared_kwargs["squash_mode"] = squash_mode + + for agg_op in self.time_aggregation_types: + loss = loss + self._compute_agg_loss(agg_op, pred, target, time_weights, shared_kwargs) + + # Average over the number of aggregation types, matching the old per-term + # normalisation (old code: loss /= num_interp_steps + num_aggregate_ops). + if self.time_aggregation_types: + loss = loss / len(self.time_aggregation_types) + return loss + + def _compute_agg_loss( + self, + agg_op: str, + pred: torch.Tensor, + target: torch.Tensor, + time_weights: torch.Tensor | None, + shared_kwargs: dict, + ) -> torch.Tensor: + """Compute loss for a single aggregation operation.""" + if agg_op == "diff": + return self._compute_diff_loss(pred, target, time_weights, shared_kwargs) + agg_fns = {"mean": torch.mean, "min": torch.amin, "max": torch.amax} + if agg_op not in agg_fns: + msg = f"Unknown aggregation type '{agg_op}'. Supported: 'diff', 'mean', 'min', 'max'." 
+ raise ValueError(msg) + fn = agg_fns[agg_op] + pred_agg = fn(pred, dim=1, keepdim=True) + target_agg = fn(target, dim=1, keepdim=True) + return self.loss(pred_agg, target_agg, **shared_kwargs) + + def _compute_diff_loss( + self, + pred: torch.Tensor, + target: torch.Tensor, + time_weights: torch.Tensor | None, + shared_kwargs: dict, + ) -> torch.Tensor: + """Compute per-step diff loss, optionally weighted by time scaler.""" + pred_agg = pred[:, 1:, ...] - pred[:, :-1, ...] # (bs, time-1, ens, latlon, nvar) + target_agg = target[:, 1:, ...] - target[:, :-1, ...] # (bs, time-1, latlon, nvar) + loss = torch.tensor(0.0, dtype=pred.dtype, device=pred.device, requires_grad=False) + for step in range(pred_agg.shape[1]): + step_loss = self.loss( + pred_agg[:, step : step + 1, ...], + target_agg[:, step : step + 1, ...], + **shared_kwargs, + ) + if time_weights is not None: + step_loss = step_loss * time_weights[step] + loss = loss + step_loss + return loss diff --git a/training/src/anemoi/training/losses/base.py b/training/src/anemoi/training/losses/base.py index e2b072d036..4bf26721e8 100644 --- a/training/src/anemoi/training/losses/base.py +++ b/training/src/anemoi/training/losses/base.py @@ -14,6 +14,7 @@ from abc import abstractmethod from collections.abc import Iterator from enum import StrEnum +from typing import Any from typing import ClassVar import torch @@ -278,6 +279,52 @@ def forward( """ +class BaseLossWrapper(BaseLoss): + """Transparent wrapper around a single inner loss. + + By default, all scaler and metadata methods are delegated to the + wrapped loss so that the wrapper behaves as if it *were* the inner + loss from the perspective of ``CombinedLoss`` and the scaler + machinery. Subclasses only need to override ``forward``. + """ + + def __init__(self, loss: BaseLoss, **kwargs: Any) -> None: + super().__init__(**kwargs) + if not isinstance(loss, BaseLoss): + msg = f"Invalid loss type provided: {type(loss)}. Expected BaseLoss." + raise TypeError(msg) + self.loss = loss + # Share the inner loss's scaler so that scaler additions/updates + # applied to this wrapper are visible to the actual loss computation. + self.scaler = self.loss.scaler + self.supports_sharding = getattr(self.loss, "supports_sharding", True) + + # -- scaler delegation -------------------------------------------------- + + @functools.wraps(ScaleTensor.add_scaler) + def add_scaler(self, dimension: int | tuple[int], scaler: torch.Tensor, *, name: str | None = None) -> None: + self.loss.add_scaler(dimension=dimension, scaler=scaler, name=name) + + @functools.wraps(ScaleTensor.update_scaler) + def update_scaler(self, name: str, scaler: torch.Tensor, *, override: bool = False) -> None: + self.loss.update_scaler(name=name, scaler=scaler, override=override) + + @functools.wraps(ScaleTensor.has_scaler_for_dim) + def has_scaler_for_dim(self, dim: TensorDim) -> bool: + return self.loss.has_scaler_for_dim(dim=dim) + + # -- metadata delegation ------------------------------------------------ + + @property + def needs_shard_layout_info(self) -> bool: + """Delegate to the wrapped loss.""" + return getattr(self.loss, "needs_shard_layout_info", False) + + def iter_leaf_losses(self) -> Iterator["BaseLoss"]: + """Yield leaf losses from the wrapped loss.""" + yield from self.loss.iter_leaf_losses() + + class FunctionalLoss(BaseLoss): """Loss which a user can subclass and provide `calculate_difference`. 
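To make the wrapper's arithmetic concrete, here is a self-contained sketch of the four aggregations with plain MSE standing in for the wrapped loss. It mirrors the semantics above (time collapsed for `mean`/`max`/`min`, per-step `diff` terms, final average over aggregation types) but is not the anemoi API: scalers, ensemble handling and time weights are omitted.

```python
import torch
import torch.nn.functional as F


def time_aggregate_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Toy (bs, time, grid, var) tensors with time > 1."""
    terms = []
    # mean / max / min collapse the time dimension before the inner loss
    for agg in (torch.mean, torch.amax, torch.amin):
        terms.append(F.mse_loss(agg(pred, dim=1), agg(target, dim=1)))
    # "diff": inner loss on consecutive step-to-step differences, per step
    pred_d = pred[:, 1:] - pred[:, :-1]
    target_d = target[:, 1:] - target[:, :-1]
    terms.append(
        sum(F.mse_loss(pred_d[:, s], target_d[:, s]) for s in range(pred_d.shape[1])),
    )
    # average over the number of aggregation types, as in forward() above
    return torch.stack(terms).mean()


loss = time_aggregate_mse(torch.randn(2, 6, 16, 3), torch.randn(2, 6, 16, 3))
```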
diff --git a/training/src/anemoi/training/losses/combined.py b/training/src/anemoi/training/losses/combined.py index fb5ace6699..4b645b6e55 100644 --- a/training/src/anemoi/training/losses/combined.py +++ b/training/src/anemoi/training/losses/combined.py @@ -29,7 +29,7 @@ class CombinedLoss(BaseLoss): """Combined Loss function.""" needs_graph_data: bool = True - # CombinedLoss builds child losses itself, so it needs the filtered scaler + # CombinedLoss builds child losses itself, so it needs the full scaler # set and data indices during construction. factory_context_keys = frozenset( {LossFactoryContextKey.AVAILABLE_SCALERS, LossFactoryContextKey.DATA_INDICES}, @@ -50,18 +50,8 @@ def __init__( Allows multiple losses to be combined into a single loss function, and the components weighted. - As the losses are designed for use within the context of the - anemoi-training configuration, `losses` work best as a dictionary. - - If `losses` is a `tuple[dict]`, the `scalers` key will be extracted - before being passed to `get_loss_function`, and the `scalers` defined - in each loss only applied to the respective loss. Thereby `scalers` - added to this class will be routed correctly. - If `losses` is a `tuple[Callable]`, all `scalers` added to this class - will be added to all underlying losses. - And if `losses` is a `tuple[BaseLoss]`, no scalers added to - this class will be added to the underlying losses, as it is - assumed that will be done by the parent function. + Each child loss controls its own scalers via its `scalers` config key. + All available scalers are passed through to child losses unconditionally. Parameters ---------- @@ -69,8 +59,8 @@ def __init__( if a `tuple[dict]`: Tuple of losses to initialise with `get_loss_function`. Allows for kwargs to be passed, and weighings controlled. - If a loss should only have some of the scalers, set `scalers` in the loss config. - If no scalers are set, all scalers added to this class will be included. + Each child loss specifies its own `scalers` to control which + scalers it receives. if a `tuple[Callable]`: Will be called with `kwargs`, and all scalers added to this class added. if a `tuple[BaseLoss]`: @@ -83,8 +73,7 @@ def __init__( If None, all losses are weighted equally. by default None. available_scalers : dict[str, TENSOR_SPEC] | None, optional - Scaler tensors already filtered by the top-level CombinedLoss configuration. - These are passed down to child losses when present. + All scaler tensors available. Passed through to child losses. data_indices : IndexCollection | None, optional Training data indices needed by child losses that perform variable mapping. kwargs: Any @@ -97,7 +86,6 @@ def __init__( loss_weights=(1.0,), ) CombinedLoss.add_scaler(name = 'scaler_1', ...) - # Only added to the `MSELoss` if specified in it's `scalers`. 
-------- >>> CombinedLoss( losses = [anemoi.training.losses.MSELoss], @@ -110,28 +98,16 @@ def __init__( _target_: anemoi.training.losses.combined.CombinedLoss losses: - _target_: anemoi.training.losses.MSELoss - - _target_: anemoi.training.losses.MAELoss - scalers: ['*'] - loss_weights: [1.0, 0.6] - # All scalers passed to this class will be added to each underlying loss - ``` - - ``` - training_loss: - _target_: anemoi.training.losses.combined.CombinedLoss - losses: - - _target_: anemoi.training.losses.MSELoss - scalers: ['variable'] + scalers: ['variable', 'node_weights'] - _target_: anemoi.training.losses.MAELoss scalers: ['loss_weights_mask'] - scalers: ['*'] - # Only the specified scalers will be added to each loss + loss_weights: [1.0, 0.6] + # Each child loss specifies its own scalers ``` """ super().__init__() self.losses: list[type[BaseLoss]] = [] - self._loss_scaler_specification: dict[int, list[str]] = {} losses = (*(losses or []), *extra_losses) if loss_weights is None: @@ -143,28 +119,22 @@ def __init__( for i, loss in enumerate(losses): if isinstance(loss, DictConfig | dict): loss_config = dict(loss) - scaler_spec = loss_config.pop("scalers", ["*"]) - self._loss_scaler_specification[i] = scaler_spec - # Only propagate scaler declarations when explicitly provided. - if available_scalers: - loss_config["scalers"] = scaler_spec self.losses.append( get_loss_function( DictConfig(loss_config), scalers=available_scalers, data_indices=data_indices, - **dict(kwargs), + graph_data=kwargs.get("graph_data"), + data_node_name=kwargs.get("data_node_name"), ), ) elif isinstance(loss, type): - self._loss_scaler_specification[i] = ["*"] self.losses.append(loss(**kwargs)) else: assert isinstance(loss, BaseLoss) - self._loss_scaler_specification[i] = loss.scaler self.losses.append(loss) - self.add_module(str(i), self.losses[-1]) # (self.losses[-1].name + str(i), self.losses[-1]) + self.add_module(str(i), self.losses[-1]) self.loss_weights = loss_weights del self.scaler # Remove scaler property from parent class, as it is not used here @@ -223,15 +193,13 @@ def forward( @functools.wraps(ScaleTensor.add_scaler, assigned=("__doc__", "__annotations__")) def add_scaler(self, dimension: int | tuple[int], scaler: torch.Tensor, *, name: str | None = None) -> None: - for i, spec in self._loss_scaler_specification.items(): - if "*" in spec or name in spec: - self.losses[i].add_scaler(dimension=dimension, scaler=scaler, name=name) + for loss in self.losses: + loss.add_scaler(dimension=dimension, scaler=scaler, name=name) @functools.wraps(ScaleTensor.update_scaler, assigned=("__doc__", "__annotations__")) def update_scaler(self, name: str, scaler: torch.Tensor, *, override: bool = False) -> None: - for i, spec in self._loss_scaler_specification.items(): - if "*" in spec or name in spec: - self.losses[i].update_scaler(name=name, scaler=scaler, override=override) + for loss in self.losses: + loss.update_scaler(name=name, scaler=scaler, override=override) def has_scaler_for_dim(self, dim: TensorDim) -> bool: return any(loss.has_scaler_for_dim(dim=dim) for loss in self.losses) diff --git a/training/src/anemoi/training/losses/loss.py b/training/src/anemoi/training/losses/loss.py index fab996c425..9d5bbd83bd 100644 --- a/training/src/anemoi/training/losses/loss.py +++ b/training/src/anemoi/training/losses/loss.py @@ -27,6 +27,8 @@ METRIC_RANGE_DTYPE = dict[str, list[int]] +NESTED_LOSSES = ["anemoi.training.losses.MultiscaleLossWrapper"] +WRAPPED_LOSSES = 
["anemoi.training.losses.aggregate.TimeAggregateLossWrapper"] LOGGER = logging.getLogger(__name__) @@ -104,6 +106,38 @@ def _extract_constructor_context( return context.for_loss_class(get_class(target)) +def _propagate_combined_scalers(loss_config: dict, scalers_to_include: list) -> None: + """Propagate parent scalers to CombinedLoss sub-losses that don't specify their own.""" + for sub_loss in loss_config.get("losses", []): + if ( + isinstance(sub_loss, dict) + and "scalers" not in sub_loss + and "MultiscaleLossWrapper" not in sub_loss.get("_target_", "") + ): + sub_loss["scalers"] = list(scalers_to_include) + + +def _build_wrapped_loss( + loss_config: dict, + scalers_to_include: list, + scalers: dict[str, TENSOR_SPEC] | None, + data_indices: "IndexCollection | None", +) -> BaseLoss: + """Instantiate a WRAPPED_LOSSES target (e.g. TimeAggregateLossWrapper).""" + inner_loss_config = loss_config.pop("loss_fn") + inner_loss = get_loss_function(OmegaConf.create(inner_loss_config), scalers, data_indices) + wrapper = instantiate(loss_config, loss_fn=inner_loss) + # Apply any scalers specified on the wrapper itself (delegated to the inner loss). + if scalers_to_include and scalers: + resolved = ( + [s for s in scalers if f"!{s}" not in scalers_to_include] + if "*" in scalers_to_include + else list(scalers_to_include) + ) + _apply_scalers(wrapper, resolved, scalers, data_indices) + return wrapper + + # Future import breaks other type hints TODO Harrison Cook def get_loss_function( config: DictConfig, @@ -155,9 +189,14 @@ def get_loss_function( target_variables = loss_config.pop("target_variables", None) graph_extra = {"data_node_name": data_node_name} if data_node_name is not None else {} + target = loss_config.get("_target_") - per_scale_loss_config = loss_config.pop("per_scale_loss", None) - if per_scale_loss_config is not None: + # For CombinedLoss, propagate parent scalers to sub-losses that don't specify their own. + if "CombinedLoss" in (target or "") and scalers_to_include: + _propagate_combined_scalers(loss_config, scalers_to_include) + + if target in NESTED_LOSSES: + per_scale_loss_config = loss_config.pop("per_scale_loss") per_scale_loss = get_loss_function( OmegaConf.create(per_scale_loss_config), scalers, @@ -173,13 +212,22 @@ def get_loss_function( **_graph_data_kwargs(target_cls, graph_data, graph_extra), ) - if scalers is None: - scalers = {} + if target in WRAPPED_LOSSES: + return _build_wrapped_loss(loss_config, scalers_to_include, scalers, data_indices) + + scalers = scalers or {} if "*" in scalers_to_include: scalers_to_include = [s for s in list(scalers.keys()) if f"!{s}" not in scalers_to_include] available_scalers = _filter_scalers(scalers_to_include, scalers) if has_scalers_config else None + # If the target class requests AVAILABLE_SCALERS (e.g. CombinedLoss), always + # pass the full unfiltered scalers so child losses can control their own. 
+ if ( + hasattr(target_cls, "factory_context_keys") + and LossFactoryContextKey.AVAILABLE_SCALERS in target_cls.factory_context_keys + ): + available_scalers = scalers factory_context = LossFactoryContext( available_scalers=available_scalers, data_indices=data_indices, diff --git a/training/src/anemoi/training/losses/multiscale.py b/training/src/anemoi/training/losses/multiscale.py index ba6237b739..c3d3391d29 100644 --- a/training/src/anemoi/training/losses/multiscale.py +++ b/training/src/anemoi/training/losses/multiscale.py @@ -9,6 +9,7 @@ import logging +from collections.abc import Iterator from pathlib import Path import einops @@ -26,12 +27,12 @@ from anemoi.models.layers.graph_provider import ProjectionGraphProvider from anemoi.models.layers.sparse_projector import SparseProjector from anemoi.training.losses.base import BaseLoss +from anemoi.training.losses.base import BaseLossWrapper LOGGER = logging.getLogger(__name__) -class MultiscaleLossWrapper(BaseLoss): - """Apply the same base loss across progressively smoothed target fields.""" +class MultiscaleLossWrapper(BaseLossWrapper): name: str = "MultiscaleLossWrapper" needs_graph_data: bool = True @@ -93,7 +94,7 @@ def __init__( loss_matrices : list[Path | str] | None Deprecated. Pass inside *multiscale_config* instead. """ - super().__init__(ignore_nans=ignore_nans) + super().__init__(loss=per_scale_loss, ignore_nans=ignore_nans) _has_matrices = bool(loss_matrices) # [None] still signals file mode (identity scale) if _has_matrices or loss_matrices_path is not None: @@ -118,8 +119,6 @@ def __init__( len(weights) == self.num_scales ), f"Number of weights ({len(weights)}) must match number of scales ({self.num_scales})" self.weights = weights - self.loss = per_scale_loss - self.scaler = self.loss.scaler self.supports_sharding = True self.mloss = None self.projector = SparseProjector(autocast=autocast) @@ -128,19 +127,9 @@ def __init__( def needs_shard_layout_info(self) -> bool: return True - def update_scaler(self, name: str, scaler: torch.Tensor, *, override: bool = False) -> None: - """Update the scaler values for the internal loss. 
- - Parameters - ---------- - name : str - Name of the scaler to update - scaler : torch.Tensor - New scaler values - override : bool, optional - Whether to override existing scaler values, by default False - """ - self.loss.update_scaler(name=name, scaler=scaler, override=override) + def iter_leaf_losses(self) -> Iterator["BaseLoss"]: + """MultiscaleLossWrapper is a leaf: it performs substantive computation.""" + yield self def _load_smoothing_matrices( self, diff --git a/training/src/anemoi/training/losses/utils.py b/training/src/anemoi/training/losses/utils.py index 0285c35461..b62725c630 100644 --- a/training/src/anemoi/training/losses/utils.py +++ b/training/src/anemoi/training/losses/utils.py @@ -13,8 +13,8 @@ import logging from typing import TYPE_CHECKING +from anemoi.training.losses.base import BaseLossWrapper from anemoi.training.losses.combined import CombinedLoss -from anemoi.training.losses.multiscale import MultiscaleLossWrapper from anemoi.training.losses.variable_mapper import LossVariableMapper from anemoi.training.utils.enums import TensorDim @@ -61,13 +61,12 @@ def print_variable_scaling(loss: BaseLoss, data_indices: IndexCollection) -> dic variable_scaling[f"{base_key}{suffix}"] = print_variable_scaling(sub_loss, data_indices) return variable_scaling - if isinstance(loss, MultiscaleLossWrapper): - return print_variable_scaling(loss.loss, data_indices) - if isinstance(loss, LossVariableMapper): subset_vars = enumerate(loss.predicted_variables) # LossVariableMapper forwards scalers to its inner loss, so get scaling from there scaler_source = loss.loss.scaler + elif isinstance(loss, BaseLossWrapper): + return print_variable_scaling(loss.loss, data_indices) else: subset_vars = enumerate(data_indices.model.output.name_to_index.keys()) scaler_source = loss.scaler @@ -76,7 +75,7 @@ def print_variable_scaling(loss: BaseLoss, data_indices: IndexCollection) -> dic log_text = f"Final Variable Scaling in {loss.__class__.__name__}: " scaling_values, scaling_sum = {}, 0.0 for idx, name in subset_vars: - value = float(variable_scaling[idx]) + value = float(variable_scaling[idx]) if idx < variable_scaling.shape[0] else 1.0 log_text += f"{name}: {value:.4g}, " scaling_values[name] = value scaling_sum += value diff --git a/training/src/anemoi/training/losses/variable_mapper.py b/training/src/anemoi/training/losses/variable_mapper.py index 55c7ab074c..fca36579d6 100644 --- a/training/src/anemoi/training/losses/variable_mapper.py +++ b/training/src/anemoi/training/losses/variable_mapper.py @@ -17,6 +17,7 @@ from anemoi.models.data_indices.collection import IndexCollection from anemoi.training.losses.base import BaseLoss +from anemoi.training.losses.base import BaseLossWrapper from anemoi.training.losses.scaler_tensor import ScaleTensor from anemoi.training.utils.enums import TensorDim from anemoi.training.utils.index_space import IndexSpace @@ -24,7 +25,7 @@ LOGGER = logging.getLogger(__name__) -class LossVariableMapper(BaseLoss): +class LossVariableMapper(BaseLossWrapper): """Loss wrapper to filter variables to compute the loss on.""" def __init__( @@ -52,18 +53,9 @@ def __init__( target_variables, ), "predicted and target variables must have the same length for loss computation" - super().__init__() + super().__init__(loss=loss) self._loss_scaler_specification = {} - if not isinstance(loss, BaseLoss): - msg = f"Invalid loss type provided: {type(loss)}. Expected BaseLoss." 
- raise TypeError(msg) - self.loss = loss - if hasattr(self.loss, "scaler"): - # Share the inner loss scaler so scaler membership and updates remain visible - # to training/task utilities that inspect `loss.scaler`. - self.scaler = self.loss.scaler - self.supports_sharding = getattr(self.loss, "supports_sharding", False) self.predicted_variables = list(predicted_variables) if predicted_variables is not None else None self.target_variables = list(target_variables) if target_variables is not None else None self.data_indices: IndexCollection | None = None @@ -72,11 +64,6 @@ def __init__( if data_indices is not None: self.set_data_indices(data_indices) - @property - def needs_shard_layout_info(self) -> bool: - """Whether the wrapped loss requires explicit shard-layout metadata.""" - return getattr(self.loss, "needs_shard_layout_info", False) - def _get_predicted_indices_for_scaler_variable_axis(self, variable_size: int) -> list[int] | None: if variable_size == 1: # Broadcast scalers do not need filtering. @@ -155,15 +142,11 @@ def add_scaler(self, dimension: int | tuple[int], scaler: torch.Tensor, *, name: @functools.wraps(ScaleTensor.update_scaler) def update_scaler(self, name: str, scaler: torch.Tensor, *, override: bool = False) -> None: # Keep update behavior consistent with add_scaler for VARIABLE-axis scalers. - if hasattr(self.loss, "scaler") and name in self.loss.scaler.tensors: + if name in self.loss.scaler.tensors: dimension = self.loss.scaler.tensors[name][0] scaler = self._filter_variable_axis_scaler(dimension, scaler) self.loss.update_scaler(name=name, scaler=scaler, override=override) - @functools.wraps(ScaleTensor.has_scaler_for_dim) - def has_scaler_for_dim(self, dim: TensorDim) -> bool: - return self.loss.has_scaler_for_dim(dim=dim) - @staticmethod def _to_layout(layout: IndexSpace | str, *, layout_name: str) -> IndexSpace: if isinstance(layout, IndexSpace): @@ -321,7 +304,7 @@ def forward( without_scalers: list[str] | list[int] | None = None, grid_shard_slice: slice | None = None, group: ProcessGroup | None = None, - squash_mode: str = "avg", + squash_mode: str | None = None, pred_layout: IndexSpace | str | None = None, target_layout: IndexSpace | str | None = None, **kwargs, @@ -364,9 +347,10 @@ def forward( "without_scalers": without_scalers, "grid_shard_slice": grid_shard_slice, "group": group, - "squash_mode": squash_mode, }, ) + if squash_mode is not None: + loss_kwargs["squash_mode"] = squash_mode empty_metric_selection = False if isinstance(scaler_indices, tuple): diff --git a/training/src/anemoi/training/schemas/training.py b/training/src/anemoi/training/schemas/training.py index 28f353c61a..388ea1caf1 100644 --- a/training/src/anemoi/training/schemas/training.py +++ b/training/src/anemoi/training/schemas/training.py @@ -252,7 +252,6 @@ class ImplementedLossesUsingBaseLossSchema(StrEnum): mae = "anemoi.training.losses.MAELoss" logcosh = "anemoi.training.losses.LogCoshLoss" huber = "anemoi.training.losses.HuberLoss" - combined = "anemoi.training.losses.combined.CombinedLoss" fcl = "anemoi.training.losses.spectral.FourierCorrelationLoss" lsd = "anemoi.training.losses.spectral.LogSpectralDistance" logfft2d = "anemoi.training.losses.spectral.LogFFT2Distance" @@ -263,8 +262,8 @@ class ImplementedLossesUsingBaseLossSchema(StrEnum): class BaseLossSchema(BaseModel): target_: ImplementedLossesUsingBaseLossSchema = Field(..., alias="_target_") "Loss function object from anemoi.training.losses." 
- scalers: list[str] = Field(example=["variable"]) # TODO(Mario): Validate scalers are defined - "Scalers to include in loss calculation" + scalers: list[str] = Field(default_factory=list, example=["variable"]) + "Scalers to include in loss calculation. Defaults to empty (no scaling)." ignore_nans: bool = False "Allow nans in the loss and apply methods ignoring nans for measuring the loss." predicted_variables: list[str] | None = None @@ -297,6 +296,8 @@ class MultiscaleConfigDiskSchema(BaseModel): loss_matrices_path: str | None = None loss_matrices: list[str | None] + scalers: list[str] | None = None + "Scalers to apply to the wrapped loss (delegated to inner per_scale_loss)." class MultiscaleConfigOnTheFlySchema(BaseModel): @@ -338,6 +339,21 @@ class MultiScaleLossSchema(BaseModel): loss_matrices_path: str | None = None loss_matrices: list[str | None] | None = None + @field_validator("per_scale_loss", mode="before") + @classmethod + def add_empty_scalers_to_inner(cls, v: Any) -> Any: + """Inject empty scalers for inner loss; scalers flow through the wrapper.""" + if isinstance(v, dict) and "scalers" not in v: + v["scalers"] = [] + else: + from omegaconf import DictConfig + from omegaconf.omegaconf import open_dict + + if isinstance(v, DictConfig) and "scalers" not in v: + with open_dict(v): + v["scalers"] = [] + return v + @model_validator(mode="after") def check_no_deprecated_mixed_with_on_the_fly(self) -> Self: if isinstance(self.multiscale_config, MultiscaleConfigOnTheFlySchema) and ( @@ -352,6 +368,33 @@ def check_no_deprecated_mixed_with_on_the_fly(self) -> Self: return self +class TimeAggregateLossWrapperSchema(BaseModel): + """Schema for TimeAggregateLossWrapper used inside CombinedLoss.""" + + target_: Literal["anemoi.training.losses.aggregate.TimeAggregateLossWrapper"] = Field(..., alias="_target_") + time_aggregation_types: list[Literal["diff", "mean", "min", "max"]] = Field(min_length=1) + "Time aggregation operations to apply over the time dimension before computing the loss." + loss_fn: BaseLossSchema | CRPSSchema + "Inner loss function applied to each time-aggregated output." + scalers: list[str] | None = None + "Scalers to apply to the wrapped loss (delegated to inner loss_fn)." + + @field_validator("loss_fn", mode="before") + @classmethod + def add_empty_scalers_to_inner(cls, v: Any) -> Any: + """Inject empty scalers for inner loss; scalers flow through the wrapper.""" + if isinstance(v, dict) and "scalers" not in v: + v["scalers"] = [] + else: + from omegaconf import DictConfig + from omegaconf.omegaconf import open_dict + + if isinstance(v, DictConfig) and "scalers" not in v: + with open_dict(v): + v["scalers"] = [] + return v + + class HuberLossSchema(BaseLossSchema): delta: float = 1.0 "Threshold for Huber loss." @@ -387,18 +430,30 @@ def _loss_discriminator(v: Any) -> str: return "spectral" if target == "anemoi.training.losses.HuberLoss": return "huber" + if target == "anemoi.training.losses.aggregate.TimeAggregateLossWrapper": + return "time_aggregate" return "base" class CombinedLossSchema(BaseLossSchema): + """Schema for CombinedLoss. + + Top-level ``scalers`` act as defaults for sub-losses that don't specify their own. + Sub-losses that explicitly set ``scalers`` override the parent value. + """ + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + target_: Literal["anemoi.training.losses.combined.CombinedLoss"] = Field(..., alias="_target_") + "CombinedLoss target." 
losses: list[ Annotated[ Annotated[BaseLossSchema, Tag("base")] | Annotated[HuberLossSchema, Tag("huber")] | Annotated[CRPSSchema, Tag("crps")] | Annotated[SpectralLossSchema, Tag("spectral")] - | Annotated[MultiScaleLossSchema, Tag("multiscale")], + | Annotated[MultiScaleLossSchema, Tag("multiscale")] + | Annotated[TimeAggregateLossWrapperSchema, Tag("time_aggregate")], Discriminator(_loss_discriminator), ] ] = Field(min_length=1) @@ -406,24 +461,35 @@ class CombinedLossSchema(BaseLossSchema): loss_weights: list[int | float] | None = None "Weightings of losses, if not set, all losses are weighted equally." - @field_validator("losses", mode="before") + @model_validator(mode="before") @classmethod - def add_empty_scalers(cls, losses: Any) -> Any: - """Add empty scalers to loss functions that use them (not MultiscaleLossWrapper).""" - from omegaconf import OmegaConf + def propagate_scalers_to_children(cls, data: Any) -> Any: + """Propagate parent scalers to sub-losses that don't specify their own. + + MultiscaleLossWrapper is skipped because it manages scalers via per_scale_loss. + """ + from omegaconf import DictConfig from omegaconf.omegaconf import open_dict + parent_scalers = data.get("scalers", []) if hasattr(data, "get") else [] + if not parent_scalers: + return data + + losses = data.get("losses", []) if hasattr(data, "get") else [] for loss in losses: - target = loss.get("_target_", "") if hasattr(loss, "get") else "" + if not hasattr(loss, "get"): + continue + target = loss.get("_target_", "") + # MultiscaleLossWrapper manages scalers on per_scale_loss, not at top level if "MultiscaleLossWrapper" in str(target): continue if "scalers" not in loss: - if OmegaConf.is_config(loss): + if isinstance(loss, DictConfig): with open_dict(loss): - loss["scalers"] = [] - else: - loss["scalers"] = [] - return losses + loss["scalers"] = list(parent_scalers) + elif isinstance(loss, dict): + loss["scalers"] = list(parent_scalers) + return data @model_validator(mode="after") def check_length_of_weights_and_losses(self) -> Self: @@ -441,6 +507,7 @@ def check_length_of_weights_and_losses(self) -> Self: | Annotated[CombinedLossSchema, Tag("combined")] | Annotated[CRPSSchema, Tag("crps")] | Annotated[SpectralLossSchema, Tag("spectral")] + | Annotated[TimeAggregateLossWrapperSchema, Tag("time_aggregate")] | Annotated[MultiScaleLossSchema, Tag("multiscale")], Discriminator(_loss_discriminator), ] @@ -514,10 +581,10 @@ class BaseTrainingSchema(BaseModel): "Config for gradient clipping." strategy: StrategySchemas "Strategy to use." - weight_averaging: WeightAveragingSchema | None = Field(default=None) - "Config for weight averaging (SWA or EMA). Set to null to disable." training_loss: DatasetDict[LossSchemas] "Training loss configuration." + weight_averaging: WeightAveragingSchema | None = Field(default=None) + "Config for weight averaging (SWA or EMA). Set to null to disable." loss_gradient_scaling: bool = False "Dynamic rescaling of the loss gradient. Not yet tested." 
scalers: DatasetDict[dict[str, ScalerSchema]] diff --git a/training/tests/integration/conftest.py b/training/tests/integration/conftest.py index 04cccd8d12..c7def0afee 100644 --- a/training/tests/integration/conftest.py +++ b/training/tests/integration/conftest.py @@ -250,23 +250,40 @@ def lam_config_with_graph( return cfg, urls +def _get_multiscale_cfgs(training_loss_cfg: DictConfig) -> list[DictConfig]: + """Extract multiscale_config dicts that contain loss_matrices.""" + multiscale_cfg = training_loss_cfg.get("multiscale_config") + if multiscale_cfg is not None and "loss_matrices" in multiscale_cfg: + return [multiscale_cfg] + if "losses" in training_loss_cfg: + results = [] + for sub_loss in training_loss_cfg.losses: + mc = sub_loss.get("multiscale_config") + if mc is not None and "loss_matrices" in mc: + results.append(mc) + return results + return [] + + def handle_truncation_matrices(cfg: DictConfig, get_test_data: GetTestData) -> DictConfig: url_loss_matrices = cfg.system.input.loss_matrices_path tmp_path_loss_matrices = None training_losses_cfg = get_multiple_datasets_config(cfg.training.training_loss) for dataset_name, training_loss_cfg in training_losses_cfg.items(): - multiscale_cfg = training_loss_cfg.get("multiscale_config") - if multiscale_cfg is None: - continue - for file in multiscale_cfg.get("loss_matrices") or []: - if file is not None: - tmp_path_loss_matrices = get_test_data(url_loss_matrices + file) + multiscale_cfgs = _get_multiscale_cfgs(training_loss_cfg) + + for multiscale_cfg in multiscale_cfgs: + for file in multiscale_cfg.get("loss_matrices") or []: + if file is not None: + tmp_path_loss_matrices = get_test_data(url_loss_matrices + file) + if tmp_path_loss_matrices is not None: + OmegaConf.set_struct(multiscale_cfg, False) + multiscale_cfg.loss_matrices_path = str(Path(tmp_path_loss_matrices).parent) + if tmp_path_loss_matrices is not None: resolved_path = str(Path(tmp_path_loss_matrices).parent) cfg.system.input.loss_matrices_path = Path(tmp_path_loss_matrices).parent - OmegaConf.set_struct(multiscale_cfg, False) - multiscale_cfg.loss_matrices_path = resolved_path val_multiscale_cfg = cfg.training.validation_metrics.datasets[dataset_name].multiscale.multiscale_config OmegaConf.set_struct(val_multiscale_cfg, False) diff --git a/training/tests/unit/diagnostics/callbacks/test_per_timestep_metrics.py b/training/tests/unit/diagnostics/callbacks/test_per_timestep_metrics.py new file mode 100644 index 0000000000..db08d5b6e9 --- /dev/null +++ b/training/tests/unit/diagnostics/callbacks/test_per_timestep_metrics.py @@ -0,0 +1,224 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +"""Tests for PerTimestepMetrics callback.""" + +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +import torch + +from anemoi.training.diagnostics.callbacks.per_timestep_metrics import PerTimestepMetrics +from anemoi.training.losses.base import BaseLoss + +BS = 2 +TIME = 6 +ENS = 4 +GRID = 16 +NVAR = 3 + + +@pytest.fixture +def callback() -> PerTimestepMetrics: + return PerTimestepMetrics(every_n_batches=1) + + +@pytest.fixture +def callback_every_2() -> PerTimestepMetrics: + return PerTimestepMetrics(every_n_batches=2) + + +class FakeLoss(BaseLoss): + """Minimal BaseLoss subclass for testing.""" + + def __init__(self) -> None: + super().__init__() + + def forward(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: object) -> torch.Tensor: # noqa: ARG002 + return torch.tensor(1.0) + + @property + def name(self) -> str: + return "fake" + + +def _make_pl_module( + n_timesteps: int = TIME, + n_ens: int = ENS, + n_grid: int = GRID, + n_var: int = NVAR, +) -> MagicMock: + """Create a mocked pl_module with the attributes needed by the callback.""" + pl_module = MagicMock() + + pred = torch.randn(BS, n_timesteps, n_ens, n_grid, n_var) + target = torch.randn(BS, n_timesteps, n_grid, n_var) + + # task.get_inputs returns input dict + pl_module.task.get_inputs.return_value = {"data": torch.randn(BS, 2, n_grid, n_var)} + pl_module._expand_ens_dim.return_value = {"data": torch.randn(BS, 2, n_ens, n_grid, n_var)} + + # model forward returns predictions + pl_module.__call__ = MagicMock(return_value={"data": pred}) + pl_module.return_value = {"data": pred} + + # task.get_targets returns targets + y_full = {"data": target.unsqueeze(2)} # add ens dim for _collapse_ens_dim + pl_module.task.get_targets.return_value = y_full + pl_module._collapse_ens_dim.return_value = {"data": target} + + # No ensemble comm group (single GPU case) + pl_module.ens_comm_subgroup = None + + pl_module.model.post_processors = {"data": lambda x, **_: x} + + fake_loss = FakeLoss() + pl_module.metrics = {"data": {"fkcrps": fake_loss}} + + pl_module.val_metric_ranges = { + "data": { + "pl": torch.arange(0, 2), + "sfc": torch.arange(2, 3), + }, + } + + pl_module.grid_shard_slice = {"data": None} + pl_module.model_comm_group = None + pl_module.logger_enabled = True + pl_module.data_indices = MagicMock() + + return pl_module + + +def _make_trainer(precision: str = "32-true") -> MagicMock: + trainer = MagicMock() + trainer.precision = precision + return trainer + + +def _make_batch(n_timesteps: int = TIME) -> dict[str, torch.Tensor]: + """Create a batch dict with the expected structure.""" + total_steps = 2 + n_timesteps + return {"data": torch.randn(BS, total_steps, GRID, NVAR)} + + +class TestPerTimestepMetrics: + def test_init_default(self) -> None: + cb = PerTimestepMetrics() + assert cb.every_n_batches == 1 + + def test_init_custom(self) -> None: + cb = PerTimestepMetrics(every_n_batches=5) + assert cb.every_n_batches == 5 + + def test_skips_non_matching_batch(self, callback_every_2: PerTimestepMetrics) -> None: + """Callback should skip batches that don't match every_n_batches.""" + trainer = _make_trainer() + pl_module = _make_pl_module() + batch = _make_batch() + + # batch_idx=1 should be skipped (1 % 2 != 0) + callback_every_2.on_validation_batch_end(trainer, pl_module, [], batch, batch_idx=1) + pl_module.task.get_inputs.assert_not_called() + + def test_runs_on_matching_batch(self, callback_every_2: PerTimestepMetrics) -> None: + """Callback should run on batches matching 
every_n_batches.""" + trainer = _make_trainer() + pl_module = _make_pl_module() + batch = _make_batch() + + callback_every_2.on_validation_batch_end(trainer, pl_module, [], batch, batch_idx=0) + pl_module.task.get_inputs.assert_called_once() + + def test_logs_per_timestep_metrics(self, callback: PerTimestepMetrics) -> None: + """Callback should log metrics for each timestep and variable group.""" + trainer = _make_trainer() + pl_module = _make_pl_module() + batch = _make_batch() + + callback.on_validation_batch_end(trainer, pl_module, [], batch, batch_idx=0) + + # Should have logged: TIME timesteps * 2 var groups * 1 metric = 12 calls + assert pl_module.log.call_count == TIME * 2 + + # Check metric names + logged_names = [call.args[0] for call in pl_module.log.call_args_list] + for t in range(1, TIME + 1): + assert f"val_fkcrps_metric/data/pl/t_{t}" in logged_names + assert f"val_fkcrps_metric/data/sfc/t_{t}" in logged_names + + def test_log_kwargs(self, callback: PerTimestepMetrics) -> None: + """Check that log is called with correct kwargs.""" + trainer = _make_trainer() + pl_module = _make_pl_module() + batch = _make_batch() + + callback.on_validation_batch_end(trainer, pl_module, [], batch, batch_idx=0) + + # Check first log call kwargs + _, kwargs = pl_module.log.call_args_list[0] + assert kwargs["on_epoch"] is True + assert kwargs["on_step"] is False + assert kwargs["prog_bar"] is False + assert kwargs["sync_dist"] is True + assert kwargs["batch_size"] == BS + + def test_handles_single_timestep(self, callback: PerTimestepMetrics) -> None: + """Should work with a single output timestep.""" + trainer = _make_trainer() + pl_module = _make_pl_module(n_timesteps=1) + batch = _make_batch(n_timesteps=1) + + callback.on_validation_batch_end(trainer, pl_module, [], batch, batch_idx=0) + + # 1 timestep * 2 groups = 2 log calls + assert pl_module.log.call_count == 2 + logged_names = [call.args[0] for call in pl_module.log.call_args_list] + assert "val_fkcrps_metric/data/pl/t_1" in logged_names + assert "val_fkcrps_metric/data/sfc/t_1" in logged_names + + def test_skips_non_baseloss_metrics(self, callback: PerTimestepMetrics) -> None: + """Non-BaseLoss metrics should be skipped.""" + trainer = _make_trainer() + pl_module = _make_pl_module() + batch = _make_batch() + + pl_module.metrics["data"]["non_loss"] = MagicMock(spec=[]) + + callback.on_validation_batch_end(trainer, pl_module, [], batch, batch_idx=0) + + # Only BaseLoss metrics logged: TIME * 2 groups + assert pl_module.log.call_count == TIME * 2 + + def test_uses_autocast_for_mixed_precision(self, callback: PerTimestepMetrics) -> None: + """Should apply autocast when precision is mixed.""" + trainer = _make_trainer(precision="16-mixed") + pl_module = _make_pl_module() + batch = _make_batch() + + # Should not raise + callback.on_validation_batch_end(trainer, pl_module, [], batch, batch_idx=0) + assert pl_module.log.call_count == TIME * 2 + + def test_ensemble_gather(self, callback: PerTimestepMetrics) -> None: + """Should call gather_tensor when ens_comm_subgroup is set.""" + pl_module = _make_pl_module() + batch = _make_batch() + + pl_module.ens_comm_subgroup = MagicMock() + pl_module.ens_comm_subgroup_size = 2 + + with patch( + "anemoi.models.distributed.graph.gather_tensor", + side_effect=lambda x, **_: x, + ): + callback._eval_per_timestep(pl_module, batch) + + assert pl_module.log.call_count == TIME * 2 diff --git a/training/tests/unit/losses/test_aggregate_loss.py b/training/tests/unit/losses/test_aggregate_loss.py new file mode 
100644 index 0000000000..39048310b3 --- /dev/null +++ b/training/tests/unit/losses/test_aggregate_loss.py @@ -0,0 +1,416 @@ +# (C) Copyright 2026 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest +import torch + +from anemoi.training.losses.aggregate import TimeAggregateLossWrapper +from anemoi.training.losses.base import BaseLoss +from anemoi.training.losses.kcrps import CRPS +from anemoi.training.losses.mae import MAELoss +from anemoi.training.losses.multiscale import MultiscaleLossWrapper +from anemoi.training.utils.enums import TensorDim + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_loss() -> MAELoss: + """Return an MAE loss with a unit grid scaler (4 grid points).""" + loss = MAELoss() + loss.add_scaler(TensorDim.GRID, torch.ones(4), name="unit_grid") + return loss + + +def _make_crps_loss() -> CRPS: + """Return a CRPS loss with a unit grid scaler (4 grid points).""" + loss = CRPS(no_autocast=False) + loss.add_scaler(TensorDim.GRID, torch.ones(4), name="unit_grid") + return loss + + +# Shapes used throughout: (bs=1, time=3, ens=1, latlon=4, nvar=2) +BS, TIME, ENS, LATLON, NVAR = 1, 3, 1, 4, 2 +# CRPS requires ens > 1 +ENS_CRPS = 3 + + +@pytest.fixture +def pred() -> torch.Tensor: + return torch.rand(BS, TIME, ENS, LATLON, NVAR) + + +@pytest.fixture +def target() -> torch.Tensor: + return torch.rand(BS, TIME, LATLON, NVAR) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_is_base_loss() -> None: + wrapper = TimeAggregateLossWrapper(["mean"], _make_loss()) + assert isinstance(wrapper, BaseLoss) + + +def test_stores_loss_and_agg_types() -> None: + inner = _make_loss() + wrapper = TimeAggregateLossWrapper(["mean", "diff"], inner) + assert wrapper.loss is inner + assert wrapper.time_aggregation_types == ["mean", "diff"] + + +# --------------------------------------------------------------------------- +# Output shape / type +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("agg_op", ["mean", "min", "max", "diff"]) +def test_returns_scalar_tensor(agg_op: str, pred: torch.Tensor, target: torch.Tensor) -> None: + wrapper = TimeAggregateLossWrapper([agg_op], _make_loss()) + result = wrapper(pred, target) + assert isinstance(result, torch.Tensor) + assert result.numel() == 1 + + +def test_multiple_agg_ops_return_scalar(pred: torch.Tensor, target: torch.Tensor) -> None: + wrapper = TimeAggregateLossWrapper(["mean", "max", "diff"], _make_loss()) + result = wrapper(pred, target) + assert result.numel() == 1 + + +# --------------------------------------------------------------------------- +# Empty aggregation list +# --------------------------------------------------------------------------- + + +def test_empty_aggregation_returns_zero(pred: torch.Tensor, target: torch.Tensor) -> None: + wrapper = TimeAggregateLossWrapper([], _make_loss()) + result = wrapper(pred, target) + assert torch.allclose(result, torch.zeros(1)) + + 
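+
+# A minimal reference sketch of the contract exercised below (using this
+# module's own names): for a single reduction type such as "mean", the wrapper
+# is expected to reduce the time dimension first and then call the inner loss,
+#
+#     inner(pred.mean(dim=1, keepdim=True), target.mean(dim=1, keepdim=True))
+#
+# while with several types it averages the per-type losses, e.g.
+# TimeAggregateLossWrapper(["mean", "diff"], inner)(pred, target) equals the
+# mean of the ["mean"]-only and ["diff"]-only wrapper losses.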
+
+# ---------------------------------------------------------------------------
+# Correctness: accumulation across multiple time aggregation types
+# ---------------------------------------------------------------------------
+
+
+def test_loss_accumulates_across_agg_ops(pred: torch.Tensor, target: torch.Tensor) -> None:
+    """Combined wrapper loss equals average of individual wrapper losses."""
+    inner = _make_loss()
+
+    wrapper_mean = TimeAggregateLossWrapper(["mean"], inner)
+    wrapper_diff = TimeAggregateLossWrapper(["diff"], inner)
+    wrapper_both = TimeAggregateLossWrapper(["mean", "diff"], inner)
+
+    loss_mean = wrapper_mean(pred, target)
+    loss_diff = wrapper_diff(pred, target)
+    loss_both = wrapper_both(pred, target)
+
+    assert torch.allclose(loss_both, (loss_mean + loss_diff) / 2, atol=1e-6)
+
+
+# ---------------------------------------------------------------------------
+# Correctness: "diff" aggregation uses temporal differences
+# ---------------------------------------------------------------------------
+
+
+def test_diff_aggregation_computes_temporal_differences() -> None:
+    """The diff wrapper should apply loss on (pred[:,1:]-pred[:,:-1]) vs (target[:,1:]-target[:,:-1])."""
+    inner = _make_loss()
+
+    pred = torch.rand(BS, TIME, ENS, LATLON, NVAR)
+    target = torch.rand(BS, TIME, LATLON, NVAR)
+
+    pred_diff = pred[:, 1:, ...] - pred[:, :-1, ...]
+    target_diff = target[:, 1:, ...] - target[:, :-1, ...]
+
+    wrapper_diff = TimeAggregateLossWrapper(["diff"], inner)
+    # The wrapper iterates per diff-step to handle time scalers correctly.
+    expected = torch.tensor(0.0)
+    for step in range(pred_diff.shape[1]):
+        expected = expected + inner(pred_diff[:, step : step + 1, ...], target_diff[:, step : step + 1, ...])
+    result = wrapper_diff(pred, target)
+
+    assert torch.allclose(result, expected, atol=1e-6)
+
+
+# ---------------------------------------------------------------------------
+# Correctness: "mean"/"min"/"max" aggregation reduces over time dim
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("agg_op", ["mean", "min", "max"])
+def test_reduction_aggregation_reduces_time_dim(agg_op: str) -> None:
+    inner = _make_loss()
+    pred = torch.rand(BS, TIME, ENS, LATLON, NVAR)
+    target = torch.rand(BS, TIME, LATLON, NVAR)
+
+    if agg_op == "min":
+        pred_agg = torch.amin(pred, dim=1, keepdim=True)
+        target_agg = torch.amin(target, dim=1, keepdim=True)
+    elif agg_op == "max":
+        pred_agg = torch.amax(pred, dim=1, keepdim=True)
+        target_agg = torch.amax(target, dim=1, keepdim=True)
+    else:
+        pred_agg = torch.mean(pred, dim=1, keepdim=True)
+        target_agg = torch.mean(target, dim=1, keepdim=True)
+
+    expected = inner(pred_agg, target_agg)
+    result = TimeAggregateLossWrapper([agg_op], inner)(pred, target)
+
+    assert torch.allclose(result, expected, atol=1e-6)
+
+
+# ---------------------------------------------------------------------------
+# CRPS tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("agg_op", ["mean", "min", "max", "diff"])
+def test_crps_returns_scalar_tensor(agg_op: str) -> None:
+    """TimeAggregateLossWrapper with CRPS should return a scalar for each agg type."""
+    pred = torch.rand(BS, TIME, ENS_CRPS, LATLON, NVAR)
+    target = torch.rand(BS, TIME, LATLON, NVAR)
+    wrapper = TimeAggregateLossWrapper([agg_op], _make_crps_loss())
+    result = wrapper(pred, target)
+    assert isinstance(result, torch.Tensor)
+    assert result.numel() == 1
+
+
+def test_crps_multiple_agg_ops_return_scalar() -> None:
+    """Multiple aggregation types should accumulate into a single scalar."""
+    pred = torch.rand(BS, TIME, ENS_CRPS, LATLON, NVAR)
+    target = torch.rand(BS, TIME, LATLON, NVAR)
+    wrapper = TimeAggregateLossWrapper(["mean", "diff"], _make_crps_loss())
+    result = wrapper(pred, target)
+    assert result.numel() == 1
+
+
+def test_crps_loss_accumulates_across_agg_ops() -> None:
+    """Combined CRPS wrapper loss equals average of individual wrapper losses."""
+    inner = _make_crps_loss()
+    pred = torch.rand(BS, TIME, ENS_CRPS, LATLON, NVAR)
+    target = torch.rand(BS, TIME, LATLON, NVAR)
+
+    loss_mean = TimeAggregateLossWrapper(["mean"], inner)(pred, target)
+    loss_diff = TimeAggregateLossWrapper(["diff"], inner)(pred, target)
+    loss_both = TimeAggregateLossWrapper(["mean", "diff"], inner)(pred, target)
+
+    assert torch.allclose(loss_both, (loss_mean + loss_diff) / 2, atol=1e-6)
+
+
+@pytest.mark.parametrize("agg_op", ["mean", "min", "max"])
+def test_crps_reduction_reduces_time_dim(agg_op: str) -> None:
+    """CRPS wrapper with time-reduction passes keepdim=True aggregated tensors to inner loss."""
+    inner = _make_crps_loss()
+    pred = torch.rand(BS, TIME, ENS_CRPS, LATLON, NVAR)
+    target = torch.rand(BS, TIME, LATLON, NVAR)
+
+    if agg_op == "min":
+        pred_agg = torch.amin(pred, dim=1, keepdim=True)
+        target_agg = torch.amin(target, dim=1, keepdim=True)
+    elif agg_op == "max":
+        pred_agg = torch.amax(pred, dim=1, keepdim=True)
+        target_agg = torch.amax(target, dim=1, keepdim=True)
+    else:
+        pred_agg = torch.mean(pred, dim=1, keepdim=True)
+        target_agg = torch.mean(target, dim=1, keepdim=True)
+
+    expected = inner(pred_agg, target_agg)
+    result = TimeAggregateLossWrapper([agg_op], inner)(pred, target)
+
+    assert torch.allclose(result, expected, atol=1e-6)
+
+
+def test_crps_wrapper_forwards_explicit_squash_mode() -> None:
+    inner = _make_crps_loss()
+    pred = torch.rand(BS, TIME, ENS_CRPS, LATLON, NVAR)
+    target = torch.rand(BS, TIME, LATLON, NVAR)
+    pred_mean = torch.mean(pred, dim=1, keepdim=True)
+    target_mean = torch.mean(target, dim=1, keepdim=True)
+
+    expected = inner(pred_mean, target_mean, squash_mode="avg")
+    result = TimeAggregateLossWrapper(["mean"], inner)(pred, target, squash_mode="avg")
+
+    assert torch.allclose(result, expected, atol=1e-6)
+
+
+# ---------------------------------------------------------------------------
+# Unknown aggregation type raises ValueError
+# ---------------------------------------------------------------------------
+
+
+def test_unknown_agg_op_raises(pred: torch.Tensor, target: torch.Tensor) -> None:
+    wrapper = TimeAggregateLossWrapper(["sum"], _make_loss())
+    with pytest.raises(ValueError, match="Unknown aggregation type"):
+        wrapper(pred, target)
+
+
+# ---------------------------------------------------------------------------
+# ignore_nans flag is forwarded to BaseLoss
+# ---------------------------------------------------------------------------
+
+
+def test_ignore_nans_flag() -> None:
+    wrapper = TimeAggregateLossWrapper(["mean"], _make_loss(), ignore_nans=True)
+    assert wrapper.avg_function is torch.nanmean
+    assert wrapper.sum_function is torch.nansum
+
+
+def test_default_no_ignore_nans() -> None:
+    wrapper = TimeAggregateLossWrapper(["mean"], _make_loss())
+    assert wrapper.avg_function is torch.mean
+    assert wrapper.sum_function is torch.sum
+
+
+# ---------------------------------------------------------------------------
+# Transparent wrapper: scaler delegation
+# ---------------------------------------------------------------------------
+
+
+def test_scaler_is_shared_with_inner_loss() -> None:
+    inner = _make_loss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner)
+    assert wrapper.scaler is inner.scaler
+
+
+def test_add_scaler_reaches_inner_loss() -> None:
+    inner = MAELoss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner)
+    scaler = torch.ones(NVAR)
+    wrapper.add_scaler(TensorDim.VARIABLE, scaler, name="var_scaler")
+    assert inner.scaler.has_scaler_for_dim(TensorDim.VARIABLE)
+
+
+def test_update_scaler_delegates_to_inner_loss() -> None:
+    inner = _make_loss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner)
+    new_grid = torch.ones(4) * 2.0
+    wrapper.update_scaler("unit_grid", new_grid, override=True)
+    assert torch.allclose(inner.scaler.unit_grid, new_grid)
+
+
+def test_has_scaler_for_dim_delegates() -> None:
+    inner = _make_loss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner)
+    assert wrapper.has_scaler_for_dim(TensorDim.GRID) is True
+    assert wrapper.has_scaler_for_dim(TensorDim.VARIABLE) is False
+
+
+# ---------------------------------------------------------------------------
+# Transparent wrapper: metadata delegation
+# ---------------------------------------------------------------------------
+
+
+def test_supports_sharding_matches_inner() -> None:
+    inner = _make_loss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner)
+    assert wrapper.supports_sharding == inner.supports_sharding
+
+
+def test_supports_sharding_propagates_false() -> None:
+    inner = _make_loss()
+    inner.supports_sharding = False
+    wrapper = TimeAggregateLossWrapper(["mean"], inner)
+    assert wrapper.supports_sharding is False
+
+
+def test_needs_shard_layout_info_default_false() -> None:
+    inner = _make_loss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner)
+    assert wrapper.needs_shard_layout_info is False
+
+
+def test_iter_leaf_losses_yields_inner_leaves() -> None:
+    inner = _make_loss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner)
+    leaves = list(wrapper.iter_leaf_losses())
+    assert leaves == [inner]
+    assert wrapper not in leaves
+
+
+# ---------------------------------------------------------------------------
+# Nested composition: TimeAggregateLossWrapper(MultiscaleLossWrapper(...))
+# ---------------------------------------------------------------------------
+
+
+def _make_multiscale_wrapper(inner: BaseLoss | None = None) -> "MultiscaleLossWrapper":
+    """Build a single-scale MultiscaleLossWrapper (no smoothing matrices)."""
+    if inner is None:
+        inner = _make_loss()
+    return MultiscaleLossWrapper(
+        per_scale_loss=inner,
+        weights=[1.0],
+    )
+
+
+def test_nested_scaler_shared_through_chain() -> None:
+    leaf = _make_loss()
+    ms = _make_multiscale_wrapper(leaf)
+    wrapper = TimeAggregateLossWrapper(["mean"], ms)
+    # All three should share the same scaler
+    assert wrapper.scaler is ms.scaler
+    assert ms.scaler is leaf.scaler
+
+
+def test_nested_add_scaler_reaches_leaf() -> None:
+    leaf = MAELoss()
+    ms = _make_multiscale_wrapper(leaf)
+    wrapper = TimeAggregateLossWrapper(["mean"], ms)
+    wrapper.add_scaler(TensorDim.GRID, torch.ones(4), name="grid_w")
+    assert leaf.scaler.has_scaler_for_dim(TensorDim.GRID)
+
+
+def test_nested_iter_leaf_losses_reaches_innermost() -> None:
+    leaf = _make_loss()
+    ms = _make_multiscale_wrapper(leaf)
+    wrapper = TimeAggregateLossWrapper(["mean"], ms)
+    # MultiscaleLossWrapper inherits default iter_leaf_losses (yields self),
+    # so the leaf list should be [ms], not [wrapper]
+    leaves = list(wrapper.iter_leaf_losses())
+    assert wrapper not in leaves
+    assert ms in leaves
+
+
+# ---------------------------------------------------------------------------
+# CombinedLoss integration
+# ---------------------------------------------------------------------------
+
+
+def test_combined_loss_scaler_reaches_wrapped_inner() -> None:
+    inner1 = MAELoss()
+    inner2 = MAELoss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner2)
+
+    # Verify that add_scaler on the wrapper propagates to the inner loss
+    grid_scaler = torch.ones(4)
+    wrapper.add_scaler(TensorDim.GRID, grid_scaler, name="node_weights")
+    inner1.add_scaler(TensorDim.GRID, grid_scaler, name="node_weights")
+
+    # Both leaf losses should have the scaler
+    assert inner1.scaler.has_scaler_for_dim(TensorDim.GRID)
+    assert inner2.scaler.has_scaler_for_dim(TensorDim.GRID)
+
+
+def test_combined_loss_iter_leaf_losses_includes_wrapped() -> None:
+    from anemoi.training.losses.combined import CombinedLoss
+
+    inner1 = _make_loss()
+    inner2 = _make_loss()
+    wrapper = TimeAggregateLossWrapper(["mean"], inner2)
+
+    combined = CombinedLoss(losses=[inner1, wrapper])
+    leaves = list(combined.iter_leaf_losses())
+    assert inner1 in leaves
+    assert inner2 in leaves
+    assert wrapper not in leaves
diff --git a/training/tests/unit/losses/test_combined_loss.py b/training/tests/unit/losses/test_combined_loss.py
index 12001706b4..36a05d98b8 100644
--- a/training/tests/unit/losses/test_combined_loss.py
+++ b/training/tests/unit/losses/test_combined_loss.py
@@ -44,10 +44,9 @@ def test_combined_loss() -> None:
         {
             "_target_": "anemoi.training.losses.CombinedLoss",
             "losses": [
-                {"_target_": "anemoi.training.losses.MSELoss"},
-                {"_target_": "anemoi.training.losses.MAELoss"},
+                {"_target_": "anemoi.training.losses.MSELoss", "scalers": ["test"]},
+                {"_target_": "anemoi.training.losses.MAELoss", "scalers": ["test"]},
             ],
-            "scalers": ["test"],
             "loss_weights": [1.0, 0.5],
         },
     ),
@@ -68,10 +67,9 @@ def test_combined_loss_invalid_loss_weights() -> None:
         {
             "_target_": "anemoi.training.losses.combined.CombinedLoss",
             "losses": [
-                {"_target_": "anemoi.training.losses.MSELoss"},
-                {"_target_": "anemoi.training.losses.MAELoss"},
+                {"_target_": "anemoi.training.losses.MSELoss", "scalers": ["test"]},
+                {"_target_": "anemoi.training.losses.MAELoss", "scalers": ["test"]},
             ],
-            "scalers": ["test"],
             "loss_weights": [1.0, 0.5, 1],
         },
     ),
@@ -106,7 +104,6 @@ def test_combined_loss_seperate_scalers() -> None:
                 {"_target_": "anemoi.training.losses.MSELoss", "scalers": ["test"]},
                 {"_target_": "anemoi.training.losses.MAELoss", "scalers": ["test2"]},
             ],
-            "scalers": ["test", "test2"],
             "loss_weights": [1.0, 0.5],
         },
     ),
@@ -208,7 +205,6 @@ def test_combined_loss_with_filtered_target_only_subloss_preserves_scaler_remapp
                 },
             ],
             "loss_weights": [1.0, 0.5],
-            "scalers": ["*"],
         },
     ),
     scalers={
diff --git a/training/tests/unit/schemas/test_training_schemas.py b/training/tests/unit/schemas/test_training_schemas.py
index 5cd182ddda..81a9253dc0 100644
--- a/training/tests/unit/schemas/test_training_schemas.py
+++ b/training/tests/unit/schemas/test_training_schemas.py
@@ -15,6 +15,36 @@ from anemoi.training.schemas.training import MultiscaleConfigOnTheFlySchema
 from anemoi.training.schemas.training import MultiScaleLossSchema
 from anemoi.training.schemas.training import OptimizerSchema
+from anemoi.training.schemas.training import TimeAggregateLossWrapperSchema
+
+_TIME_AGG_CFG = {
+    "_target_": "anemoi.training.losses.aggregate.TimeAggregateLossWrapper",
+    "time_aggregation_types": ["mean", "diff"],
+    "loss_fn": {
+        "_target_": "anemoi.training.losses.MSELoss",
+        "scalers": ["node_weights"],
+    },
+}
+
+
+def test_time_aggregate_loss_config_valid() -> None:
+    """TimeAggregateLossWrapperSchema accepts a valid config."""
+    schema = TimeAggregateLossWrapperSchema(**_TIME_AGG_CFG)
+    assert schema.time_aggregation_types == ["mean", "diff"]
+
+
+def test_time_aggregate_loss_config_invalid_agg_type() -> None:
+    """Unknown aggregation type is rejected."""
+    cfg = {**_TIME_AGG_CFG, "time_aggregation_types": ["sum"]}
+    with pytest.raises(ValidationError):
+        TimeAggregateLossWrapperSchema(**cfg)
+
+
+def test_time_aggregate_loss_config_empty_agg_types() -> None:
+    """Empty aggregation list is rejected (min_length=1)."""
+    cfg = {**_TIME_AGG_CFG, "time_aggregation_types": []}
+    with pytest.raises(ValidationError):
+        TimeAggregateLossWrapperSchema(**cfg)


 def test_optimizer_schema_allows_extra_keys() -> None:
@@ -124,7 +154,6 @@ def test_multiscale_loss_deprecated_loss_matrices_path_with_on_the_fly_config_re

 _COMBINED_LOSS_BASE = {
     "_target_": "anemoi.training.losses.combined.CombinedLoss",
-    "scalers": [],
 }


@@ -132,7 +161,6 @@ def test_combined_loss_with_scalers_valid() -> None:
     CombinedLossSchema(
         **{
             **_COMBINED_LOSS_BASE,
-            "scalers": ["*"],
             "losses": [
                 {"_target_": "anemoi.training.losses.MSELoss", "scalers": ["nan_mask_weights"]},
                 {"_target_": "anemoi.training.losses.MAELoss", "scalers": ["nan_mask_weights"]},