* [AnyFlow] FAR: standalone causal-mask builder + torch.compile follow-up
Follow-up to #13745. Extracts FAR mask construction to a module-level
helper and adds an `attention_mask` forward kwarg so
AnyFlowFARTransformer3DModel can be wrapped in
`torch.compile(fullgraph=True)`. The pipeline pre-builds the mask during
KV-cache prefill so users get end-to-end fullgraph compile.
* Public method `AnyFlowFARTransformer3DModel.build_attention_mask(...)`
(modes: "train", "cache") plus private module-level helper
`_build_anyflow_far_causal_block_mask(...)`.
* `_build_freqs` cache lookup/write bypassed under
`torch.compiler.is_compiling()` to avoid a Dynamo guard recompile on
the second compiled call (applied in the bidirectional model source;
synced to FAR via `# Copied from`).
* `TestAnyFlowFARTransformer3DCompile(TorchCompileTesterMixin)` —
recompilation_and_graph_break, repeated_blocks, and group_offloading
pass on H200; AOT is `@pytest.mark.skip`'d (torch.export rejects
BlockMask as a pytree input).
* Base `get_dummy_inputs` omits `attention_mask` so every non-compile
test class exercises the in-forward fallback; the compile class
overrides to inject a pre-built mask.
* Bit-exact: pre-built path vs internal-build fallback max|Δ|=0.0e+00.
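To make the idea of a chunk-wise causal mask concrete, here is a minimal dense sketch in plain Python. It is not the `_build_anyflow_far_causal_block_mask` helper: the real helper returns a `BlockMask`, its "train" and "cache" modes may apply additional rules, and its arguments are not shown in this commit. The function names and the `tokens_per_frame` parameter below are hypothetical.

```python
def chunk_ids(chunk_partition, tokens_per_frame=1):
    """Return the chunk index of every token, in sequence order."""
    ids = []
    for chunk_index, num_frames in enumerate(chunk_partition):
        ids.extend([chunk_index] * (num_frames * tokens_per_frame))
    return ids


def block_causal_mask(chunk_partition, tokens_per_frame=1):
    """mask[q][k] is True when query token q may attend to key token k.

    Assumption for this sketch: a token sees its own chunk and all earlier chunks.
    """
    ids = chunk_ids(chunk_partition, tokens_per_frame)
    return [[k_chunk <= q_chunk for k_chunk in ids] for q_chunk in ids]


# Toy partition: three chunks of 1, 2 and 1 frames, one token per frame.
mask = block_causal_mask((1, 2, 1))
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

Building such a mask once, outside the compiled region, and passing it in as an input is the general pattern that the `attention_mask` forward kwarg enables.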
* [AnyFlow] docs: full author list, repo demo examples, slimmer pipeline page
* Full author list and NVIDIA → NUS → MIT institution order; TL;DR +
abstract + Available Models bullets.
* Rewritten pipeline-selection tip describing both pipelines symmetrically.
* T2V / I2V / V2V examples now use the canonical 81-frame setup and the
demo prompts / conditioning assets shipped under
`NVlabs/AnyFlow/assets/evaluation/` (linked via raw.githubusercontent.com).
* Drop the inline "Optimizing Memory" and "torch.compile" sections — those
notes will live in the NVlabs/AnyFlow repo's own performance guide rather
than the diffusers pipeline reference.
* Sync zh user guide and the two model-API stubs.
* [AnyFlow] FAR: move chunk_partition default into transformer config
- AnyFlowFARTransformer3DModel.__init__ now accepts chunk_partition via
@register_to_config (default (1, 3, 3, 3, 3, 3, 3, 2) for the released
81-frame checkpoints, matching the field on Hub).
- AnyFlowFARPipeline.__call__ no longer requires chunk_partition; defaults
to self.transformer.config.chunk_partition. Per-call override still
supported for V2V / non-default num_frames.
- Drop the AnyFlowFARPipeline.default_chunk_partition class attribute.
- Update docs (en pipelines/models, zh using-diffusers) and the conversion
script to match.
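The fallback order described above (per-call override first, then the transformer config) can be sketched as follows. This is an illustration, not the pipeline's code: `resolve_chunk_partition` and `num_latent_frames` are hypothetical helpers, and the length check rests on two assumptions that this commit does not state, namely that `chunk_partition` counts latent frames and that the VAE compresses time by a factor of 4. The default partition does sum to 21, which is consistent with 81 frames under that assumption.

```python
DEFAULT_CHUNK_PARTITION = (1, 3, 3, 3, 3, 3, 3, 2)  # default quoted in the commit message


def num_latent_frames(num_frames, temporal_compression=4):
    """Assumed mapping from video frames to latent frames (hypothetical)."""
    return (num_frames - 1) // temporal_compression + 1


def resolve_chunk_partition(config_partition, override=None, num_frames=None):
    """Prefer the per-call override, otherwise fall back to the config value."""
    partition = tuple(override) if override is not None else tuple(config_partition)
    if num_frames is not None:
        expected = num_latent_frames(num_frames)
        if sum(partition) != expected:
            raise ValueError(
                f"chunk_partition sums to {sum(partition)}, expected {expected} latent frames"
            )
    return partition


# Default 81-frame setup: no override needed.
print(resolve_chunk_partition(DEFAULT_CHUNK_PARTITION, num_frames=81))
# Shorter clip: the caller supplies a partition that matches its own length.
print(resolve_chunk_partition(DEFAULT_CHUNK_PARTITION, override=(1, 3, 3), num_frames=25))
```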
* [AnyFlow] FAR pipeline: fix `timesteps` shadowing across chunks
Inside the per-chunk rollout loop, the local variable `timesteps` was
reassigned to `self.scheduler.timesteps` after `set_timesteps()`. On the
next chunk iteration the same name was passed back into
`set_timesteps(timesteps=...)`, where a non-None value enters the
*custom-schedule* branch — `apply_shift` re-runs on already-shifted
values, double-shifting the schedule for every chunk after the first.
Concretely, with `shift=5` and `num_inference_steps=4`:
- chunk 0 timesteps: [1000, 937.5, 833.3, 625] (correct)
- chunk 1+ timesteps: [1000, 986.8, 961.5, 892.9] (double-shifted)
The later steps drift toward `t=1000` instead of toward `t=0`, the
flow-map model is conditioned on the wrong source sigma, and the chunk
KV cache accumulates errors that show up as artifacts in later video
frames.
Fix: rebind the cached schedule to a fresh local name
(`scheduler_timesteps`) so the outer-scope `timesteps` kwarg (the
user-provided custom schedule, when any) stays untouched across chunks.
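The effect of shifting twice can be reproduced in a few lines. The sketch below assumes that the scheduler's `apply_shift` uses the common flow-matching time shift, sigma' = shift * sigma / (1 + (shift - 1) * sigma), on evenly spaced sigmas. That formula is not stated in this commit, so treat it as an assumption; it does reproduce the chunk 0 values listed above. `base_sigmas` is a hypothetical helper.

```python
def apply_shift(sigmas, shift):
    """Assumed time-shift formula: shift * s / (1 + (shift - 1) * s)."""
    return [shift * s / (1 + (shift - 1) * s) for s in sigmas]


def base_sigmas(num_inference_steps):
    """Evenly spaced sigmas from 1.0 downwards (hypothetical helper)."""
    return [1 - i / num_inference_steps for i in range(num_inference_steps)]


shift = 5.0
once = apply_shift(base_sigmas(4), shift)   # what chunk 0 receives
twice = apply_shift(once, shift)            # what a re-fed schedule turns into

print([round(1000 * s, 1) for s in once])
print([round(1000 * s, 1) for s in twice])
```

Every intermediate value moves closer to 1000 on the second application, which is the drift toward `t=1000` described above.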
Layer-by-layer verification against the NVlabs reference implementation
on H200 (elephant prompt, seed 0, 4 NFE, 81 frames):
- chunk 0 inference: bit-exact (0.0 mean diff)
- chunk 1 step 0: 0.194 → 0.014 (-93%)
- chunk 7 last step: 0.564 → 0.274 (-51%)
* [AnyFlow] FAR: doc-builder line wrap for chunk_partition docstrings
Pure rewrap to satisfy `doc-builder style --max_len 119`. Two docstrings
introduced in 96077b2 (the `chunk_partition` config arg on the FAR
transformer + the matching pipeline kwarg) wrapped a few characters
short of the line budget. No semantic change.
* [AnyFlow] docs: drop author names from docstrings, link FAR via HF papers, say chunk-wise
- Remove author-name attributions from the transformer / pipeline class
docstrings and file-header comments; the paper-citation header on the
doc page keeps the full author list, the in-code references just point at
the [AnyFlow] / [FAR] papers.
- Link FAR via its Hugging Face papers page
(https://huggingface.co/papers/2503.19325) instead of a raw arxiv.org URL,
matching the AnyFlow reference style and the rest of the diffusers docs.
- Describe AnyFlow FAR generation as "chunk-wise autoregressive": the
pipeline autoregresses over chunks (`chunk_partition`), not single frames.
* [AnyFlow] FAR: address review nits
- pipeline: reuse the standard `timesteps` variable name for the per-chunk
scheduler timesteps; freeze the caller-provided custom schedule in
`custom_timesteps`/`custom_sigmas` before the loop so it isn't re-fed into
`set_timesteps` and double-shifted on later chunks.
- transformer: clarify the no-mask fallback comment to spell out the
`torch.compile(fullgraph=True)` graph-break behavior and the
`build_attention_mask` workaround.
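The loop pattern behind the first nit can be illustrated with a toy scheduler. `ToyScheduler`, `rollout_buggy`, and `rollout_fixed` are hypothetical stand-ins written for this sketch; only the shape of the `set_timesteps(timesteps=...)` call and the `custom_timesteps` name come from the commit message, and the shift formula is the same assumption as in any standard flow-matching shift. In this toy the shift compounds on every re-fed chunk, which may differ from how the real scheduler behaves on later chunks.

```python
class ToyScheduler:
    """Hypothetical stand-in for the real scheduler; only the shift logic is modeled."""

    def __init__(self, shift=5.0):
        self.shift = shift
        self.timesteps = None

    def set_timesteps(self, num_inference_steps=None, timesteps=None):
        if timesteps is None:
            sigmas = [1 - i / num_inference_steps for i in range(num_inference_steps)]
        else:
            # Custom-schedule branch: caller-provided values are treated as unshifted.
            sigmas = [t / 1000 for t in timesteps]
        shifted = [self.shift * s / (1 + (self.shift - 1) * s) for s in sigmas]
        self.timesteps = [1000 * s for s in shifted]


def rollout_buggy(scheduler, num_chunks, num_inference_steps, timesteps=None):
    schedules = []
    for _ in range(num_chunks):
        scheduler.set_timesteps(num_inference_steps, timesteps=timesteps)
        timesteps = scheduler.timesteps  # overwrites the kwarg, so it is re-fed next chunk
        schedules.append([round(t, 1) for t in timesteps])
    return schedules


def rollout_fixed(scheduler, num_chunks, num_inference_steps, timesteps=None):
    custom_timesteps = timesteps  # frozen before the loop
    schedules = []
    for _ in range(num_chunks):
        scheduler.set_timesteps(num_inference_steps, timesteps=custom_timesteps)
        timesteps = scheduler.timesteps  # safe: never passed back into set_timesteps
        schedules.append([round(t, 1) for t in timesteps])
    return schedules


buggy = rollout_buggy(ToyScheduler(), num_chunks=2, num_inference_steps=4)
fixed = rollout_fixed(ToyScheduler(), num_chunks=2, num_inference_steps=4)
print("buggy chunks identical:", buggy[0] == buggy[1])
print("fixed chunks identical:", fixed[0] == fixed[1])
```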
---------
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
docs/source/en/api/pipelines/anyflow.md: 49 additions & 75 deletions
@@ -20,68 +20,28 @@ specific language governing permissions and limitations under the License.

 # AnyFlow

-[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA.
+[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) from NVIDIA, National University of Singapore, and Massachusetts Institute of Technology, by Yuchao Gu, Guian Fang, Yuxin Jiang, Weijia Mao, Song Han, Han Cai, Mike Zheng Shou.
+
+> **TL;DR:** AnyFlow is the first any-step video diffusion framework built on flow maps, which enables a single model (bidirectional or causal) to adapt to arbitrary inference budgets.

 *Few-step video generation has been significantly advanced by consistency models. However, their performance often degrades in any-step video diffusion models due to the fixed-point formulation. To address this limitation, we present AnyFlow, the first any-step video diffusion distillation framework built on flow maps. Instead of learning only the mapping z_t → z_0, AnyFlow learns transitions z_t → z_r over arbitrary time intervals, enabling a single model to adapt to different inference budgets. We design an improved forward flow map training recipe that fine-tunes pretrained video diffusion models into flow map models, and introduce Flow Map Backward Simulation to enable on-policy distillation for flow map models. Extensive experiments across both bidirectional and causal architectures, at scales ranging from 1.3B to 14B, on text-to-video and image-to-video tasks demonstrate that AnyFlow outperforms consistency-based baselines while preserving high fidelity and flexible sampling under varying step budgets.*

-The original training code is at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow). The project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow).
+The AnyFlow pipelines were contributed by the AnyFlow Team. The original code is available on [GitHub](https://github.com/NVlabs/AnyFlow), the project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow), and pretrained models can be found in the [nvidia/anyflow](https://huggingface.co/collections/nvidia/anyflow) collection on Hugging Face.
All four are grouped under the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection.
-
 > [!TIP]
-> Choose `AnyFlowPipeline` for traditional bidirectional text-to-video generation. Choose `AnyFlowFARPipeline` for streaming I2V, video continuation (V2V), or any setup that benefits from frame-by-frame autoregressive sampling.
-
-> [!TIP]
-> AnyFlow supports any-step sampling: a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining. Quality scales monotonically with steps in our benchmarks.
-
-### Optimizing Memory and Inference Speed
-
-<hfoptions id="optimization">
-<hfoption id="memory">
-
-```py
-import torch
-from diffusers import AnyFlowPipeline
-from diffusers.hooks import apply_group_offloading
> `AnyFlowPipeline` is designed for bidirectional diffusion models in text-to-video (T2V) generation. `AnyFlowFARPipeline` is a chunk-wise causal diffusion model that supports text-to-video (T2V) generation, image-to-video (I2V) generation, and video continuation (V2V).