fix: Add anemoi weight averaging classes in order to work with Imputers and Scalers #1113

Draft
jakob-schloer wants to merge 1 commit into main from fix/ema_with_imputers

@jakob-schloer
Collaborator

@jakob-schloer jakob-schloer commented May 12, 2026

Description

For weight averaging, a copy of the model is created at the beginning of training, and its parameters are updated (averaged) over the course of training. The PyTorch Lightning weight-averaging classes pair parameters and buffers positionally via zip, which fails in Anemoi when used together with:

  1. Imputers (ConstantImputer and friends) register scratch buffers (nan_locations, loss_mask_training) that get reassigned with new shapes on the first forward pass. The averaged model's deep copy still holds the original shape=(0,) tensor, so the per-batch update_parameters call of the PL AveragedModel crashes with a shape mismatch.

  2. Loss scalers that update during training (NaNMaskScaler, etc.) call ScaleTensor.update_scaler every batch. This pops and re-registers the scaler buffer, shuffling the buffer order in the live model relative to the averaged snapshot. With matching shapes, the positional zip silently mis-pairs tensors; with mismatched shapes, it crashes.
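
The reordering problem in item 2 can be illustrated with a minimal, dependency-free sketch. Plain Python dicts stand in for a module's named buffers (both preserve insertion order), and update_scaler here is a hypothetical stand-in for the pop-and-re-register behaviour described above, not the actual ScaleTensor method:

```python
# Plain dicts stand in for the named buffers of the live and averaged models.
live = {"scaler_a": [1.0, 1.0], "scaler_b": [2.0, 2.0]}
averaged = {name: list(values) for name, values in live.items()}  # snapshot at setup


def update_scaler(buffers, name, new_values):
    """Pop and re-register an entry, which moves it to the end of the ordering."""
    buffers.pop(name)
    buffers[name] = new_values


update_scaler(live, "scaler_a", [1.5, 1.5])

# Positional pairing: the averaged "scaler_a" now lines up with the live "scaler_b".
positional_pairs = list(zip(averaged.keys(), live.keys()))

# Name-based pairing: each averaged entry is looked up by key in the live model.
name_pairs = [(name, name) for name in averaged if name in live]

print(positional_pairs)
print(name_pairs)
```

Because both entries have the same length, the positional pairing raises no error here; it simply averages the wrong tensors together, which is the silent failure mode described above.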

Changes

  • anemoi.training.diagnostics.callbacks.weight_averaging.AveragedModel: name-based matching, shape filtering, and a float/non-float split (non-float buffers are synced from the source rather than averaged).
  • WeightAveraging: overrides setup, _swap_models, and _copy_average_to_current so that all transfer paths between the live and averaged models use name-based matching.
  • EMAWeightAveraging and SWAWeightAveraging: thin subclasses mirroring pytorch_lightning.callbacks.EMAWeightAveraging (and the analogous SWA configuration).
  • _get_weight_averaging_callback still instantiates whatever the user targets, but logs a warning if a stock pytorch_lightning.callbacks.* class is configured, since those will crash when used in combination with imputers or updating scalers.
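
The update rule listed for AveragedModel can be sketched roughly as follows. This is not the code from this PR: ema_update is a hypothetical helper, flat Python lists stand in for tensors, and skipping entries whose shape differs is an assumption about what "shape filtering" means.

```python
def ema_update(averaged, live, decay):
    """Update `averaged` in place from `live`, pairing entries by name.

    Both arguments map names to flat lists of numbers that stand in for tensors.
    """
    for name, avg in averaged.items():
        src = live.get(name)
        if src is None:
            continue  # entry no longer exists in the live model
        if len(src) != len(avg):
            continue  # assumed shape filter: skip buffers that were resized
        if all(isinstance(v, float) for v in src):
            # Float entries are averaged: avg = decay * avg + (1 - decay) * src.
            avg[:] = [decay * a + (1.0 - decay) * s for a, s in zip(avg, src)]
        else:
            # Non-float entries (counters, masks) are copied from the source.
            avg[:] = list(src)
    return averaged


# The live model has a different entry order and a resized scratch buffer.
averaged = {"weight": [0.0, 0.0], "step": [0], "nan_locations": []}
live = {"step": [5], "nan_locations": [1, 0, 1], "weight": [1.0, 1.0]}
ema_update(averaged, live, decay=0.5)
print(averaged)
```

In this sketch the float entry is averaged, the integer counter is copied, and the resized scratch buffer is left untouched, regardless of the order in which the live model lists its entries.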

Usage

weight_averaging:
   _target_: anemoi.training.diagnostics.callbacks.weight_averaging.EMAWeightAveraging
   decay: 0.999

Notes

Since we still allow stock pytorch_lightning.callbacks.* classes, this is not a breaking change.
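
A warning check of the kind described for _get_weight_averaging_callback might look like the sketch below. The helper name warn_if_stock_callback, the prefix constant, and the message text are all hypothetical; only the idea of warning on stock pytorch_lightning.callbacks.* targets comes from the description above.

```python
import logging

LOGGER = logging.getLogger(__name__)

# Module prefix of the stock callbacks named in the description.
STOCK_PREFIX = "pytorch_lightning.callbacks."


def warn_if_stock_callback(target: str) -> bool:
    """Return True, and log a warning, when `target` names a stock PL callback."""
    is_stock = target.startswith(STOCK_PREFIX)
    if is_stock:
        LOGGER.warning(
            "%s pairs parameters and buffers positionally and may fail with "
            "imputers or updating scalers; consider the anemoi equivalent.",
            target,
        )
    return is_stock


stock = warn_if_stock_callback("pytorch_lightning.callbacks.EMAWeightAveraging")
custom = warn_if_stock_callback(
    "anemoi.training.diagnostics.callbacks.weight_averaging.EMAWeightAveraging"
)
```

The check only warns and never blocks instantiation, which is what keeps existing configurations working.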

As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/

By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.


📚 Documentation preview 📚: https://anemoi-training--1113.org.readthedocs.build/en/1113/


📚 Documentation preview 📚: https://anemoi-graphs--1113.org.readthedocs.build/en/1113/


📚 Documentation preview 📚: https://anemoi-models--1113.org.readthedocs.build/en/1113/

@github-project-automation github-project-automation Bot moved this to To be triaged in Anemoi-dev May 12, 2026
@jakob-schloer jakob-schloer requested a review from ssmmnn11 May 12, 2026 07:47
@jakob-schloer jakob-schloer self-assigned this May 12, 2026
@github-actions github-actions Bot added the bug Something isn't working label May 12, 2026
@HCookie HCookie self-requested a review May 12, 2026 12:28
@HCookie HCookie added the ATS Approval Not Needed No approval needed by ATS label May 12, 2026
@HCookie HCookie moved this from To be triaged to Reviewers needed in Anemoi-dev May 12, 2026
@anaprietonem anaprietonem marked this pull request as draft May 14, 2026 14:27
@anaprietonem
Contributor

@jakob-schloer I am marking this as a draft because, as I understand it, it might be better to check whether we can adapt the dynamic scalers and imputers so that we don't have to hack the EMA callback so much. @ssmmnn11 is looking at preparing a branch with some suggestions.

Labels

ATS Approval Not Needed, bug, training

Projects

Status: Reviewers needed

3 participants