Skip to content

feat(training, models)!: add transport diffusion and stochastic interpolant#1096

Open
ssmmnn11 wants to merge 21 commits into
mainfrom
feat/transport-si-diffusion
Open

feat(training, models)!: add transport diffusion and stochastic interpolant#1096
ssmmnn11 wants to merge 21 commits into
mainfrom
feat/transport-si-diffusion

Conversation

@ssmmnn11
Copy link
Copy Markdown
Member

@ssmmnn11 ssmmnn11 commented May 6, 2026

Introduce transport model for stochastic interpolants and diffusion EDM.

Stochastic interpolants provide a general framework for learning continuous paths between probability distributions, with Flow Matching being a special case. This makes them a flexible framework for forecasting, time-interpolation and downscaling.

training config
    training.transport_objective: diffusion | stochastic_interpolant
    training.prediction_mode:     state | tendency
    model.model.transport.objective: diffusion | stochastic_interpolant
          |
          v
  TransportTraining
    - owns the training step
    - selects PredictionMode
    - selects TransportObjective
          |
          +-----------------------------+
          |                             |
          v                             v
  PredictionMode                  TransportObjective
  state / tendency                diffusion / stochastic_interpolant
    - builds clean target           - builds source/noise
    - decides target space          - corrupts target
    - reconstructs metrics          - defines loss target
          |                             |
          +-------------+---------------+
                        v
  PreparedTransportObjective
    conditioned_target  -> corrupted/noised target passed to model
    condition           -> sigma or SI time
    loss_target         -> clean target or SI drift
    weights             -> EDM weights or None
    aux                 -> SI source/interpolant/time etc.
                        |
                        v
           model.forward(
             x,
             conditioned_target,
             condition,
           )
                        |
                        v
           loss / validation
  • PredictionMode decides what to predict:

    • state
    • tendency
  • TransportObjective decides how to train:

    • diffusion: target as target + sigma * source, predicts clean endpoint, uses EDM weighting.
    • stochastic_interpolant: builds bridge state and trains the model to predict bridge drift.
  • TransportSourceBuilder decides the source / anchor field:

    • gaussian
    • zero
    • reference_state
    • plus scale and additional additive Gaussian noise
  • SI bridge-noise schedules are now:

    • brownian_bridge default
  • Flow-matching-like training is SI with:

    • Gaussian source
    • linear alpha/beta
    • si_noise_scale: 0.0
    • deterministic sampler such as heun.
  • During inference, the model-side TransportModelObjective dispatches to the sampler:

    • diffusion -> EDM samplers like heun, dpmpp_2m
    • SI -> euler/heun for deterministic vector field, or euler_maruyama for noisy sampling.
a b

📚 Documentation preview 📚: https://anemoi-training--1096.org.readthedocs.build/en/1096/


📚 Documentation preview 📚: https://anemoi-graphs--1096.org.readthedocs.build/en/1096/


📚 Documentation preview 📚: https://anemoi-models--1096.org.readthedocs.build/en/1096/

@ssmmnn11 ssmmnn11 requested a review from JPXKQX May 6, 2026 13:04
@ssmmnn11 ssmmnn11 self-assigned this May 6, 2026
@github-project-automation github-project-automation Bot moved this to To be triaged in Anemoi-dev May 6, 2026
@ssmmnn11 ssmmnn11 changed the title Add transport diffusion and stochastic interpolant feat: add transport diffusion and stochastic interpolant May 6, 2026
ssmmnn11 added 3 commits May 6, 2026 15:05
…fusion

# Conflicts:
#	models/src/anemoi/models/models/transport_encoder_processor_decoder.py
@ssmmnn11
Copy link
Copy Markdown
Member Author

ssmmnn11 commented May 6, 2026

tests

@ssmmnn11 ssmmnn11 changed the title feat: add transport diffusion and stochastic interpolant feat(training, models): add transport diffusion and stochastic interpolant May 7, 2026
ssmmnn11 added 3 commits May 7, 2026 12:58
…fusion

# Conflicts:
#	models/src/anemoi/models/models/transport_encoder_processor_decoder.py
#	models/src/anemoi/models/samplers/transport_samplers.py
#	training/src/anemoi/training/train/methods/diffusion.py
…fusion

# Conflicts:
#	training/src/anemoi/training/diagnostics/callbacks/plot.py
#	training/src/anemoi/training/diagnostics/plots.py
#	training/src/anemoi/training/train/methods/diffusion.py
#	training/tests/unit/diagnostics/test_plotting_callbacks.py
@mchantry
Copy link
Copy Markdown
Member

@icedoom888 for awareness.

@mchantry mchantry changed the title feat(training, models): add transport diffusion and stochastic interpolant feat(training, models)!: add transport diffusion and stochastic interpolant May 13, 2026
@mchantry mchantry added ATS Approved Approved by ATS and removed ATS Approval Needed Approval needed by ATS labels May 13, 2026
@JoffreyDumontLeBrazidec
Copy link
Copy Markdown
Member

for downscaling, a clean solution is to add "residuals" as a new prediction objective (could also be part of the "tendency" but probably will result in something too hacky).

Copy link
Copy Markdown
Member

@JoffreyDumontLeBrazidec JoffreyDumontLeBrazidec left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice work Simon. It’s great to be able to try out the different options

Some comments:

1/ Some knob combinations are allowed but surprising
diffusion + source.kind: zero
diffusion + reference_state
We should either disallow clearly degenerate combinations or emit warnings/document them as experimental.
For example source.kind: reference_state with objective: diffusion is potentially useful for downscaling, where the source might be an upsampled LR state.
But I feel that could be confusing for users. Maybe document that Gaussian is the recommended/default source for EDM diffusion

2/ Why not heun_maruyama for SI sampling ? in the diffusion path, this is what we use.
Since SI transport is in this PR very generalistic I would consider making heun/euler and maruyama/nothing disconnected
→ Uncouple in the config
sampler: heun/euler
stochastic: true/false

3/ Also, for the SI stochastic sampler, once we add stochastic noise during sampling, noise can move samples away from the bridge. In the SI SDE formulation, this is usually corrected by score term to bring back the pushed away samples in the bridge marginal. Here in the PR, it looks like the SI part only estimates the drift/velocity. Could be nice to have in the doc that the noisy_sampling euler_maruyama is a heuristic noise injection for diversity.

4/ The PR separates diffusion and stochastic_interpolant in two separate options. SI is a sort of general framework where we could in theory meet edm_diffusion mathematically. On the other hand we have diffusion which is actually edm_diffusion so a very specific kind of diffusion which works empirically very well. I would rename diffusion in edm_diffusion to emphasize this and avoid suggesting it is a generic diffusion objective.

Copy link
Copy Markdown
Member

@JoffreyDumontLeBrazidec JoffreyDumontLeBrazidec left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice work Simon. It’s great to be able to try out the different options

Some comments:

1/ Some knob combinations are allowed but surprising
diffusion + source.kind: zero
diffusion + reference_state
We should either disallow clearly degenerate combinations or emit warnings/document them as experimental.
For example source.kind: reference_state with objective: diffusion is potentially useful for downscaling, where the source might be an upsampled LR state.
But I feel that could be confusing for users. Maybe document that Gaussian is the recommended/default source for EDM diffusion

2/ Why not heun_maruyama for SI sampling ? in the diffusion path, this is what we use.
Since SI transport is in this PR very generalistic I would consider making heun/euler and maruyama/nothing disconnected
--> Uncouple in the config
sampler: heun/euler
stochastic: true/false

3/ Also, for the SI stochastic sampler, once we add stochastic noise during sampling, noise can move samples away from the bridge. In the SI SDE formulation, this is usually corrected by score term to bring back the pushed away samples in the bridge marginal. Here in the PR, it looks like the SI part only estimates the drift/velocity. Could be nice to have in the doc that the noisy_sampling euler_maruyama is a heuristic noise injection for diversity.

4/ The PR separates diffusion and stochastic_interpolant in two separate options. SI is a sort of general framework where we could in theory meet edm_diffusion mathematically. On the other hand we have diffusion which is actually edm_diffusion so a very specific kind of diffusion which works empirically very well. I would rename diffusion in edm_diffusion to emphasize this and avoid suggesting it is a generic diffusion objective.

@github-project-automation github-project-automation Bot moved this from To be triaged to For merging in Anemoi-dev May 14, 2026
@WeiPanMaths
Copy link
Copy Markdown

Hi Simon, I think this is a really nice and elegant framework. Currently, if I were to pick something out, I feel the VectorFieldEulerSampler, VectorFieldHeunSampler, and StochasticInterpolantEulerMaruyamaSampler could be modularised a bit further. Right now each class mixes the numerical integration scheme with the specific SDE/ODE being solved, which means adding a new sampler requires a full new implementation.

An alternative design could be to separate the two concerns — a generic numerical solver class (e.g. a RungeKuttaSolver) that takes the integration scheme as a parameter, and a separate equation/dynamics class that defines the drift and diffusion terms. New samplers would then just be parameter choices rather than new classes. Something like a SDESolver(scheme="euler_maruyama", dynamics=StochasticInterpolantDynamics()).

That said, I appreciate this may be intentional for clarity and simplicity at this stage — just something to consider if the sampler zoo grows. Overall really cool!

@ssmmnn11
Copy link
Copy Markdown
Member Author

Hi Simon, I think this is a really nice and elegant framework. Currently, if I were to pick something out, I feel the VectorFieldEulerSampler, VectorFieldHeunSampler, and StochasticInterpolantEulerMaruyamaSampler could be modularised a bit further. Right now each class mixes the numerical integration scheme with the specific SDE/ODE being solved, which means adding a new sampler requires a full new implementation.

An alternative design could be to separate the two concerns — a generic numerical solver class (e.g. a RungeKuttaSolver) that takes the integration scheme as a parameter, and a separate equation/dynamics class that defines the drift and diffusion terms. New samplers would then just be parameter choices rather than new classes. Something like a SDESolver(scheme="euler_maruyama", dynamics=StochasticInterpolantDynamics()).

That said, I appreciate this may be intentional for clarity and simplicity at this stage — just something to consider if the sampler zoo grows. Overall really cool!

Very good point! I will put it on our todo list and leave this as a follow up for now.

@ssmmnn11
Copy link
Copy Markdown
Member Author

Very nice work Simon. It’s great to be able to try out the different options

Some comments:

1/ Some knob combinations are allowed but surprising diffusion + source.kind: zero diffusion + reference_state We should either disallow clearly degenerate combinations or emit warnings/document them as experimental. For example source.kind: reference_state with objective: diffusion is potentially useful for downscaling, where the source might be an upsampled LR state. But I feel that could be confusing for users. Maybe document that Gaussian is the recommended/default source for EDM diffusion

2/ Why not heun_maruyama for SI sampling ? in the diffusion path, this is what we use. Since SI transport is in this PR very generalistic I would consider making heun/euler and maruyama/nothing disconnected --> Uncouple in the config sampler: heun/euler stochastic: true/false

3/ Also, for the SI stochastic sampler, once we add stochastic noise during sampling, noise can move samples away from the bridge. In the SI SDE formulation, this is usually corrected by score term to bring back the pushed away samples in the bridge marginal. Here in the PR, it looks like the SI part only estimates the drift/velocity. Could be nice to have in the doc that the noisy_sampling euler_maruyama is a heuristic noise injection for diversity.

4/ The PR separates diffusion and stochastic_interpolant in two separate options. SI is a sort of general framework where we could in theory meet edm_diffusion mathematically. On the other hand we have diffusion which is actually edm_diffusion so a very specific kind of diffusion which works empirically very well. I would rename diffusion in edm_diffusion to emphasize this and avoid suggesting it is a generic diffusion objective.

Very true. I revised according to suggestions - we will support det. sampling for now.

ssmmnn11 added 2 commits May 18, 2026 13:26
…fusion

# Conflicts:
#	training/src/anemoi/training/config/training/diffusion.yaml
``prediction_mode: tendency`` for tendency-space targets. The model must
use :class:`AnemoiTransportModelEncProcDec` or
:class:`AnemoiTransportTendModelEncProcDec`; the plain GNN model is not
supported.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to move this to a note or a warning to highlight it?

.. warning::
     The plain GNN model is not supported.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: For merging

Development

Successfully merging this pull request may close these issues.

5 participants