feat(training, models)!: add transport diffusion and stochastic interpolant #1096
ssmmnn11 wants to merge 21 commits into
…fusion # Conflicts: # models/src/anemoi/models/models/transport_encoder_processor_decoder.py # models/src/anemoi/models/samplers/transport_samplers.py # training/src/anemoi/training/train/methods/diffusion.py
…o feat/transport-si-diffusion
…fusion # Conflicts: # training/src/anemoi/training/diagnostics/callbacks/plot.py # training/src/anemoi/training/diagnostics/plots.py # training/src/anemoi/training/train/methods/diffusion.py # training/tests/unit/diagnostics/test_plotting_callbacks.py
@icedoom888 for awareness.
For downscaling, a clean solution is to add "residuals" as a new prediction objective (it could also be folded into "tendency", but that would probably end up too hacky).
JoffreyDumontLeBrazidec left a comment
Very nice work, Simon. It's great to be able to try out the different options.
Some comments:
1/ Some knob combinations are allowed but surprising:
diffusion + source.kind: zero
diffusion + reference_state
We should either disallow clearly degenerate combinations, or emit warnings and document them as experimental.
For example, source.kind: reference_state with objective: diffusion is potentially useful for downscaling, where the source might be an upsampled LR state.
But I feel that could be confusing for users. Maybe document that Gaussian is the recommended/default source for EDM diffusion.
2/ Why not heun_maruyama for SI sampling? That is what we use in the diffusion path.
Since SI transport is very general in this PR, I would consider decoupling heun/euler from maruyama/nothing.
→ Uncouple them in the config:
sampler: heun/euler
stochastic: true/false
3/ Also, for the SI stochastic sampler: once we add stochastic noise during sampling, the noise can move samples away from the bridge. In the SI SDE formulation, this is usually corrected by a score term that pulls the displaced samples back onto the bridge marginal. In this PR, it looks like the SI part only estimates the drift/velocity. It would be nice to state in the docs that the noisy_sampling euler_maruyama option is a heuristic noise injection for diversity.
4/ The PR exposes diffusion and stochastic_interpolant as two separate options. SI is a general framework that could, in theory, recover edm_diffusion mathematically. The diffusion option, on the other hand, is actually edm_diffusion: a very specific kind of diffusion that works very well empirically. I would rename diffusion to edm_diffusion to emphasize this and avoid suggesting it is a generic diffusion objective.
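To make points 2/ and 3/ concrete, here is a minimal sketch of what decoupled knobs could look like. Everything in it is an assumption for illustration: `SamplerConfig`, `sample`, `noise_scale` and the scalar state are hypothetical and are not the PR's API. The noise term is exactly the heuristic injection described in 3/, with no score correction.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, Optional

# Hypothetical signature: velocity(x, t) returns dx/dt at state x and time t.
Velocity = Callable[[float, float], float]


@dataclass(frozen=True)
class SamplerConfig:
    """Hypothetical config: integration scheme and noise are separate knobs."""

    scheme: str = "heun"      # "euler" or "heun"
    stochastic: bool = False  # inject heuristic noise after each step
    noise_scale: float = 0.1  # only used when stochastic=True
    num_steps: int = 50

    def __post_init__(self) -> None:
        if self.scheme not in ("euler", "heun"):
            raise ValueError(f"unknown scheme: {self.scheme!r}")
        if self.num_steps < 1:
            raise ValueError("num_steps must be >= 1")


def sample(
    velocity: Velocity,
    x0: float,
    cfg: SamplerConfig,
    rng: Optional[random.Random] = None,
) -> float:
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 on a scalar state."""
    rng = rng or random.Random(0)
    dt = 1.0 / cfg.num_steps
    x = x0
    for i in range(cfg.num_steps):
        t = i * dt
        v = velocity(x, t)
        if cfg.scheme == "heun":
            # Heun: average the slope at the start and at the Euler-predicted end.
            v = 0.5 * (v + velocity(x + dt * v, t + dt))
        x = x + dt * v
        if cfg.stochastic and i < cfg.num_steps - 1:
            # Heuristic noise injection for diversity. There is no score term
            # here, so nothing pulls the sample back onto the bridge marginal.
            x += cfg.noise_scale * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x
```

With this shape, `SamplerConfig(scheme="heun", stochastic=True)` would play the role of heun_maruyama without needing a dedicated sampler class.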
Hi Simon, I think this is a really nice and elegant framework. If I were to pick one thing, I feel the VectorFieldEulerSampler, VectorFieldHeunSampler, and StochasticInterpolantEulerMaruyamaSampler could be modularised a bit further. Right now each class mixes the numerical integration scheme with the specific SDE/ODE being solved, which means adding a new sampler requires a full new implementation. An alternative design would be to separate the two concerns: a generic numerical solver class (e.g. a RungeKuttaSolver) that takes the integration scheme as a parameter, and a separate equation/dynamics class that defines the drift and diffusion terms. New samplers would then just be parameter choices rather than new classes, something like SDESolver(scheme="euler_maruyama", dynamics=StochasticInterpolantDynamics()). That said, I appreciate this may be intentional for clarity and simplicity at this stage. It is just something to consider if the sampler zoo grows. Overall really cool!
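A rough sketch of the split suggested above, under stated assumptions: `SDESolver` borrows the name from the comment, but its signature, the `Dynamics` protocol, `LinearDecay`, and the scalar state are all hypothetical. The Heun step applies predictor-corrector to the drift only and reuses one noise increment for both stages, which is a simplification that only makes sense for state-independent diffusion.

```python
import math
import random
from typing import Callable, Dict, Protocol


class Dynamics(Protocol):
    """The equation being solved: dx = drift dt + diffusion dW."""

    def drift(self, x: float, t: float) -> float: ...

    def diffusion(self, x: float, t: float) -> float: ...


class LinearDecay:
    """Toy dynamics dx = -x dt with no noise, standing in for a learned model."""

    def drift(self, x: float, t: float) -> float:
        return -x

    def diffusion(self, x: float, t: float) -> float:
        return 0.0


def euler_maruyama_step(dyn: Dynamics, x: float, t: float, dt: float, dw: float) -> float:
    return x + dyn.drift(x, t) * dt + dyn.diffusion(x, t) * dw


def heun_step(dyn: Dynamics, x: float, t: float, dt: float, dw: float) -> float:
    # Predictor-corrector on the drift; the same noise increment is used in both stages.
    noise = dyn.diffusion(x, t) * dw
    k1 = dyn.drift(x, t)
    k2 = dyn.drift(x + k1 * dt + noise, t + dt)
    return x + 0.5 * (k1 + k2) * dt + noise


Step = Callable[[Dynamics, float, float, float, float], float]
SCHEMES: Dict[str, Step] = {"euler_maruyama": euler_maruyama_step, "heun": heun_step}


class SDESolver:
    """Generic integrator: the scheme is a parameter, the equation is injected."""

    def __init__(self, scheme: str, dynamics: Dynamics, num_steps: int = 100, seed: int = 0):
        if scheme not in SCHEMES:
            raise ValueError(f"unknown scheme: {scheme!r}")
        self._step = SCHEMES[scheme]
        self._dynamics = dynamics
        self._num_steps = num_steps
        self._rng = random.Random(seed)

    def solve(self, x0: float, t0: float = 0.0, t1: float = 1.0) -> float:
        dt = (t1 - t0) / self._num_steps
        x = x0
        for i in range(self._num_steps):
            dw = math.sqrt(dt) * self._rng.gauss(0.0, 1.0)  # Brownian increment
            x = self._step(self._dynamics, x, t0 + i * dt, dt, dw)
        return x
```

Adding a new scheme then means registering one step function in `SCHEMES`, and adding a new equation means writing one small dynamics class; neither requires touching the other.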
Very good point! I will put it on our to-do list and leave this as a follow-up for now.
Very true. I revised according to the suggestions: we will support deterministic sampling only for now.
…fusion # Conflicts: # training/src/anemoi/training/config/training/diffusion.yaml
…fusion # Conflicts: # training/src/anemoi/training/config/training/diffusion.yaml
``prediction_mode: tendency`` for tendency-space targets. The model must
use :class:`AnemoiTransportModelEncProcDec` or
:class:`AnemoiTransportTendModelEncProcDec`; the plain GNN model is not
supported.
Would it make sense to move this to a note or a warning to highlight it?
.. warning::

   The plain GNN model is not supported.
Introduce a transport model for stochastic interpolants and EDM diffusion.
Stochastic interpolants provide a general framework for learning continuous paths between probability distributions, with Flow Matching being a special case. This makes them a flexible framework for forecasting, time-interpolation and downscaling.
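For orientation, a common textbook way to write a stochastic interpolant between a source sample $x_0$ and a target sample $x_1$ is sketched below. The symbols $\alpha$, $\beta$, $\gamma$ and the linear flow-matching choice are assumptions taken from the general literature; the schedules actually implemented in this PR may differ.

```latex
\begin{align}
  % Interpolant between source x_0 and target x_1, with latent noise z ~ N(0, I)
  x_t &= \alpha(t)\, x_0 + \beta(t)\, x_1 + \gamma(t)\, z, \qquad t \in [0, 1], \\
  % Boundary conditions so that the path starts at x_0 and ends at x_1
  \alpha(0) &= \beta(1) = 1, \qquad \alpha(1) = \beta(0) = 0, \qquad \gamma(0) = \gamma(1) = 0, \\
  % Time derivative of the path, used as the regression target for the velocity field
  \dot{x}_t &= \dot{\alpha}(t)\, x_0 + \dot{\beta}(t)\, x_1 + \dot{\gamma}(t)\, z, \\
  % Flow-matching special case: linear path, no bridge noise
  \alpha(t) &= 1 - t, \quad \beta(t) = t, \quad \gamma(t) = 0
  \;\;\Longrightarrow\;\;
  x_t = (1 - t)\, x_0 + t\, x_1, \quad \dot{x}_t = x_1 - x_0 .
\end{align}
```

In the last line the velocity target is constant along each path, which is why flow matching appears as the noise-free, linear special case of the general framework.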
PredictionMode decides what to predict:
TransportObjective decides how to train:
TransportSourceBuilder decides the source / anchor field:
SI bridge-noise schedules are now:
Flow-matching-like training is SI with:
During inference, the model-side TransportModelObjective dispatches to the sampler:
📚 Documentation preview 📚: https://anemoi-training--1096.org.readthedocs.build/en/1096/
📚 Documentation preview 📚: https://anemoi-graphs--1096.org.readthedocs.build/en/1096/
📚 Documentation preview 📚: https://anemoi-models--1096.org.readthedocs.build/en/1096/