
Commit d94ad81

Enderfga and dg845 authored
[AnyFlow] FAR: standalone causal-mask builder + torch.compile follow-up (#13792)
* [AnyFlow] FAR: standalone causal-mask builder + torch.compile follow-up

  Follow-up to #13745. Extracts FAR mask construction to a module-level helper and adds an `attention_mask` forward kwarg so AnyFlowFARTransformer3DModel can be wrapped in `torch.compile(fullgraph=True)`. The pipeline pre-builds the mask during KV-cache prefill so users get end-to-end fullgraph compile.

  * Public method `AnyFlowFARTransformer3DModel.build_attention_mask(...)` (modes: "train", "cache") plus private module-level helper `_build_anyflow_far_causal_block_mask(...)`.
  * `_build_freqs` cache lookup/write bypassed under `torch.compiler.is_compiling()` to avoid a Dynamo guard recompile on the second compiled call (applied in bidi source; synced to FAR via `# Copied from`).
  * `TestAnyFlowFARTransformer3DCompile(TorchCompileTesterMixin)`: recompilation_and_graph_break, repeated_blocks, and group_offloading pass on H200; AOT is `@pytest.mark.skip`'d (torch.export rejects BlockMask as a pytree input).
  * Base `get_dummy_inputs` omits `attention_mask` so every non-compile test class exercises the in-forward fallback; the compile class overrides to inject a pre-built mask.
  * Bit-exact: pre-built path vs internal-build fallback max|Δ|=0.0e+00.

* [AnyFlow] docs: full author list, repo demo examples, slimmer pipeline page

  * Full author list and NVIDIA → NUS → MIT institution order; TL;DR + abstract + Available Models bullets.
  * Rewritten pipeline-selection tip describing both pipelines symmetrically.
  * T2V / I2V / V2V examples now use the canonical 81-frame setup and the demo prompts / conditioning assets shipped under `NVlabs/AnyFlow/assets/evaluation/` (linked via raw.githubusercontent.com).
  * Drop the inline "Optimizing Memory" and "torch.compile" sections; those notes will live in the NVlabs/AnyFlow repo's own performance guide rather than the diffusers pipeline reference.
  * Sync zh user guide and the two model-API stubs.

* [AnyFlow] FAR: move chunk_partition default into transformer config

  - AnyFlowFARTransformer3DModel.__init__ now accepts chunk_partition via @register_to_config (default (1, 3, 3, 3, 3, 3, 3, 2) for the released 81-frame checkpoints, matching the field on Hub).
  - AnyFlowFARPipeline.__call__ no longer requires chunk_partition; defaults to self.transformer.config.chunk_partition. Per-call override still supported for V2V / non-default num_frames.
  - Drop the AnyFlowFARPipeline.default_chunk_partition class attribute.
  - Update docs (en pipelines/models, zh using-diffusers) and the conversion script to match.

* [AnyFlow] FAR pipeline: fix `timesteps` shadowing across chunks

  Inside the per-chunk rollout loop, the local variable `timesteps` was reassigned to `self.scheduler.timesteps` after `set_timesteps()`. On the next chunk iteration the same name was passed back into `set_timesteps(timesteps=...)`, where a non-None value enters the *custom-schedule* branch: `apply_shift` re-runs on already-shifted values, double-shifting the schedule for every chunk after the first.

  Concretely, with `shift=5` and `num_inference_steps=4`:

  - chunk 0 timesteps: [1000, 937.5, 833.3, 625] (correct)
  - chunk 1+ timesteps: [1000, 986.8, 961.3, 892.9] (double-shifted)

  The later steps drift toward `t=1000` instead of toward `t=0`, the flow-map model is conditioned on the wrong source sigma, and the chunk KV cache accumulates errors that show up as artifacts in later video frames.

  Fix: rebind the cached schedule to a fresh local name (`scheduler_timesteps`) so the outer-scope `timesteps` kwarg (the user-provided custom schedule, when any) stays untouched across chunks.

  Layer-by-layer verification against the NVlabs reference implementation on H200 (elephant prompt, seed 0, 4 NFE, 81 frames):

  - chunk 0 inference: bit-exact (0.0 mean diff)
  - chunk 1 step 0: 0.194 → 0.014 (-93%)
  - chunk 7 last step: 0.564 → 0.274 (-51%)

* [AnyFlow] FAR: doc-builder line wrap for chunk_partition docstrings

  Pure rewrap to satisfy `doc-builder style --max_len 119`. Two docstrings introduced in 96077b2 (the `chunk_partition` config arg on the FAR transformer + the matching pipeline kwarg) wrapped a few characters short of the line budget. No semantic change.

* [AnyFlow] docs: drop author names from docstrings, link FAR via HF papers, say chunk-wise

  - Remove author-name attributions from the transformer / pipeline class docstrings and file-header comments; the paper-citation header on the doc page keeps the full author list, the in-code references just point at the [AnyFlow] / [FAR] papers.
  - Link FAR via its Hugging Face papers page (https://huggingface.co/papers/2503.19325) instead of a raw arxiv.org URL, matching the AnyFlow reference style and the rest of the diffusers docs.
  - Describe AnyFlow FAR generation as "chunk-wise autoregressive": the pipeline autoregresses over chunks (`chunk_partition`), not single frames.

* [AnyFlow] FAR: address review nits

  - pipeline: reuse the standard `timesteps` variable name for the per-chunk scheduler timesteps; freeze the caller-provided custom schedule in `custom_timesteps`/`custom_sigmas` before the loop so it isn't re-fed into `set_timesteps` and double-shifted on later chunks.
  - transformer: clarify the no-mask fallback comment to spell out the `torch.compile(fullgraph=True)` graph-break behavior and the `build_attention_mask` workaround.

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
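
The double-shift described in the `timesteps` shadowing fix can be reproduced numerically. The sketch below assumes the common flow-matching shift `s -> shift * s / (1 + (shift - 1) * s)` applied to linearly spaced sigmas; the scheduler's actual `apply_shift` is not shown in this commit, so the formula and the local function of the same name are assumptions for illustration only. Under that assumption, one pass gives the chunk 0 schedule and a second pass pushes every interior step toward `t=1000`, landing on or near the double-shifted values quoted in the commit message.

```python
def apply_shift(sigmas, shift):
    """Assumed flow-matching time shift: s -> shift * s / (1 + (shift - 1) * s)."""
    return [shift * s / (1 + (shift - 1) * s) for s in sigmas]


num_inference_steps = 4
shift = 5.0

# Unshifted, linearly spaced sigmas: 1.0, 0.75, 0.5, 0.25.
base_sigmas = [1 - i / num_inference_steps for i in range(num_inference_steps)]

# Chunk 0: the schedule is shifted exactly once.
once = apply_shift(base_sigmas, shift)

# Chunk 1+ before the fix: the already-shifted schedule is fed back in and shifted again.
twice = apply_shift(once, shift)

timesteps_once = [round(1000 * s, 1) for s in once]
timesteps_twice = [round(1000 * s, 1) for s in twice]

print("shifted once: ", timesteps_once)
print("shifted twice:", timesteps_twice)
```

Freezing the caller-provided schedule before the chunk loop, as the final commit does, guarantees the shift is applied once per chunk regardless of how the loop variable is named.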
1 parent 5ec96c3 commit d94ad81

11 files changed

Lines changed: 418 additions & 310 deletions


docs/source/en/api/models/anyflow_far_transformer3d.md

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -13,19 +13,22 @@ specific language governing permissions and limitations under the License.
1313
# AnyFlowFARTransformer3DModel
1414

1515
The causal (FAR) 3D Transformer used by [`AnyFlowFARPipeline`](../pipelines/anyflow#anyflowfarpipeline)
16-
the FAR variant of [AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS
17-
ShowLab × NVIDIA). It extends the v0.35.1 Wan2.1 backbone with three additions:
16+
the FAR variant of [AnyFlow](https://huggingface.co/papers/2605.13724). See the
17+
[`AnyFlowFARPipeline`](../pipelines/anyflow) page for paper, authors, and released checkpoints. It extends
18+
the v0.35.1 Wan2.1 backbone with three additions:
1819

19-
1. **FAR causal block-mask** via `torch.nn.attention.flex_attention`, supporting frame-level autoregressive
20-
generation as introduced in [FAR (Gu et al., 2025)](https://arxiv.org/abs/2503.19325).
20+
1. **FAR causal block-mask** via `torch.nn.attention.flex_attention`, supporting chunk-wise autoregressive
21+
generation as introduced in [FAR](https://huggingface.co/papers/2503.19325).
2122
2. **Compressed-frame patch embedding** (`far_patch_embedding`) for context (already-generated) frames,
2223
warm-started from the full-resolution `patch_embedding` at construction time via trilinear interpolation.
2324
3. **Dual-timestep flow-map embedding** (same as
2425
[`AnyFlowTransformer3DModel`](anyflow_transformer3d)) — every forward call conditions on both the source
2526
timestep ``t`` and the target timestep ``r``.
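
The chunk-wise causal pattern behind the first addition can be illustrated at frame granularity. This is only a sketch under simplifying assumptions: the helper name `chunk_causal_mask` is hypothetical, it works on latent frames rather than tokens, and it ignores the compressed context frames and the `flex_attention` `BlockMask` machinery that the actual model uses. It shows the core rule only: a frame may attend to every frame in its own chunk and in all earlier chunks.

```python
def chunk_causal_mask(chunk_partition):
    """Return mask[q][k] == True when query frame q may attend to key frame k."""
    # Map each latent frame index to the index of the chunk that contains it.
    chunk_of_frame = []
    for chunk_idx, chunk_len in enumerate(chunk_partition):
        chunk_of_frame.extend([chunk_idx] * chunk_len)

    n = len(chunk_of_frame)
    # Attention is bidirectional inside a chunk and causal across chunks.
    return [[chunk_of_frame[k] <= chunk_of_frame[q] for k in range(n)] for q in range(n)]


mask = chunk_causal_mask([1, 2, 2])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x....
# xxx..
# xxx..
# xxxxx
# xxxxx
```

Each row is a query frame; the staircase advances one chunk at a time rather than one frame at a time, which is what distinguishes chunk-wise from frame-level autoregression.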
2627

27-
The chunk schedule (`chunk_partition`) is **not** baked into the model config. It is a per-call argument to
28-
`forward`, so the same checkpoint handles different `num_frames` configurations without retraining.
28+
The default chunk schedule (`chunk_partition`) is stored in the model config; the released NVIDIA AnyFlow-FAR
29+
checkpoints use `[1, 3, 3, 3, 3, 3, 3, 2]` for the canonical 81-frame setting. `forward` accepts a per-call
30+
`chunk_partition` override, so the same checkpoint also handles other `num_frames` configurations without
31+
retraining.
2932

3033
```python
3134
from diffusers import AnyFlowFARTransformer3DModel

docs/source/en/api/models/anyflow_transformer3d.md

Lines changed: 4 additions & 3 deletions
@@ -16,10 +16,11 @@ The bidirectional 3D Transformer used by [`AnyFlowPipeline`](../pipelines/anyflo
1616
v0.35.1 Wan2.1 backbone with one structural change: the timestep embedder is replaced by
1717
``AnyFlowDualTimestepTextImageEmbedding``, so every forward call conditions on both the source timestep
1818
``t`` and the target timestep ``r``. This is the embedding required to learn the flow map
19-
:math:`\Phi_{r\leftarrow t}` introduced in
20-
[AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS ShowLab × NVIDIA).
19+
$\Phi_{r\leftarrow t}$ introduced in
20+
[AnyFlow](https://huggingface.co/papers/2605.13724). See the [`AnyFlowPipeline`](../pipelines/anyflow) page
21+
for paper, authors, and released checkpoints.
2122

22-
For frame-level autoregressive (FAR causal) generation, use
23+
For chunk-wise autoregressive (FAR causal) generation, use
2324
[`AnyFlowFARTransformer3DModel`](anyflow_far_transformer3d) instead.
2425

2526
```python

docs/source/en/api/pipelines/anyflow.md

Lines changed: 49 additions & 75 deletions
@@ -20,68 +20,28 @@ specific language governing permissions and limitations under the License.
2020

2121
# AnyFlow
2222

23-
[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA.
23+
[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) from NVIDIA, National University of Singapore, and Massachusetts Institute of Technology, by Yuchao Gu, Guian Fang, Yuxin Jiang, Weijia Mao, Song Han, Han Cai, Mike Zheng Shou.
24+
25+
> **TL;DR:** AnyFlow is the first any-step video diffusion framework built on flow maps, which enables a single model (bidirectional or causal) to adapt to arbitrary inference budgets.
2426
2527
*Few-step video generation has been significantly advanced by consistency models. However, their performance often degrades in any-step video diffusion models due to the fixed-point formulation. To address this limitation, we present AnyFlow, the first any-step video diffusion distillation framework built on flow maps. Instead of learning only the mapping z_t → z_0, AnyFlow learns transitions z_t → z_r over arbitrary time intervals, enabling a single model to adapt to different inference budgets. We design an improved forward flow map training recipe that fine-tunes pretrained video diffusion models into flow map models, and introduce Flow Map Backward Simulation to enable on-policy distillation for flow map models. Extensive experiments across both bidirectional and causal architectures, at scales ranging from 1.3B to 14B, on text-to-video and image-to-video tasks demonstrate that AnyFlow outperforms consistency-based baselines while preserving high fidelity and flexible sampling under varying step budgets.*
2628

27-
The original training code is at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow). The project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow).
29+
The AnyFlow pipelines were contributed by the AnyFlow Team. The original code is available on [GitHub](https://github.com/NVlabs/AnyFlow), the project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow), and pretrained models can be found in the [nvidia/anyflow](https://huggingface.co/collections/nvidia/anyflow) collection on Hugging Face.
2830

29-
The following AnyFlow checkpoints are supported:
31+
Available Models:
3032

3133
| Checkpoint | Backbone | Description |
32-
|------------|----------|-------------|
33-
| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V, lightweight |
34-
| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V, full quality |
34+
|---|---|---|
35+
| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V |
36+
| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V |
3537
| [`nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers) | FAR + Wan2.1 1.3B | Causal T2V / I2V / V2V |
3638
| [`nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers) | FAR + Wan2.1 14B | Causal T2V / I2V / V2V |
3739

38-
All four are grouped under the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection.
39-
4040
> [!TIP]
41-
> Choose `AnyFlowPipeline` for traditional bidirectional text-to-video generation. Choose `AnyFlowFARPipeline` for streaming I2V, video continuation (V2V), or any setup that benefits from frame-by-frame autoregressive sampling.
42-
43-
> [!TIP]
44-
> AnyFlow supports any-step sampling: a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining. Quality scales monotonically with steps in our benchmarks.
45-
46-
### Optimizing Memory and Inference Speed
47-
48-
<hfoptions id="optimization">
49-
<hfoption id="memory">
50-
51-
```py
52-
import torch
53-
from diffusers import AnyFlowPipeline
54-
from diffusers.hooks import apply_group_offloading
55-
56-
pipe = AnyFlowPipeline.from_pretrained(
57-
"nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
58-
)
59-
apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level")
60-
pipe.vae.enable_slicing()
61-
pipe.vae.enable_tiling()
62-
```
63-
64-
</hfoption>
65-
<hfoption id="inference speed">
66-
67-
```py
68-
import torch
69-
from diffusers import AnyFlowPipeline
70-
71-
pipe = AnyFlowPipeline.from_pretrained(
72-
"nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
73-
).to("cuda")
74-
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
75-
```
76-
77-
</hfoption>
78-
</hfoptions>
41+
> `AnyFlowPipeline` is designed for bidirectional diffusion models and text-to-video (T2V) generation. `AnyFlowFARPipeline` is designed for chunk-wise causal diffusion models and supports text-to-video (T2V) generation, image-to-video (I2V) generation, and video continuation (V2V).
7942
8043
### Generation with AnyFlow (Bidirectional T2V)
8144

82-
<hfoptions id="anyflow-bidi">
83-
<hfoption id="usage">
84-
8545
```py
8646
import torch
8747
from diffusers import AnyFlowPipeline
@@ -91,14 +51,16 @@ pipe = AnyFlowPipeline.from_pretrained(
9151
"nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
9252
).to("cuda")
9353

94-
prompt = "A red panda eating bamboo in a forest, cinematic lighting"
95-
video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0]
96-
export_to_video(video, "out.mp4", fps=16)
54+
prompt = (
55+
"An astronaut runs smoothly and appears almost weightless on the lunar surface, "
56+
"as seen from a low-angle shot that highlights the vast, desolate background of the moon. "
57+
"The moon's craters and rocky terrain are clearly visible, creating a stark contrast against "
58+
"the running astronaut who moves with graceful, fluid motions."
59+
)
60+
video = pipe(prompt, num_inference_steps=4, num_frames=81).frames[0]
61+
export_to_video(video, "anyflow_t2v.mp4", fps=16)
9762
```
9863

99-
</hfoption>
100-
</hfoptions>
101-
10264
### Generation with AnyFlow (FAR Causal)
10365

10466
The causal pipeline selects between T2V / I2V / V2V via the ``video`` (or ``video_latents``) argument:
@@ -108,10 +70,10 @@ clip for V2V continuation. If you already have pre-encoded latents in the model
10870
``video_latents=<tensor>`` to skip VAE encoding. ``video`` and ``video_latents`` are mutually exclusive.
10971

11072
> [!IMPORTANT]
111-
> `AnyFlowFARPipeline.default_chunk_partition = [1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) is matched to the
112-
> released checkpoints' canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When
113-
> you change `num_frames`, you must also pass a matching `chunk_partition` summing to
114-
> `(num_frames - 1) // 4 + 1`, otherwise the pipeline raises an `AssertionError`.
73+
> The released checkpoints bake `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) into the transformer
74+
> config, matched to the canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When
75+
> you change `num_frames`, pass a matching `chunk_partition` summing to `(num_frames - 1) // 4 + 1`,
76+
> otherwise the pipeline raises a `ValueError`.
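
The sum rule in the note above can be checked before calling the pipeline. The following is a minimal sketch, not the pipeline's own validation code: the helper names `num_latent_frames` and `check_chunk_partition` are hypothetical, and the only facts taken from the docs are the temporal stride of 4, the formula `(num_frames - 1) // 4 + 1`, and the `ValueError` on mismatch.

```python
VAE_TEMPORAL_STRIDE = 4  # temporal stride stated in the note above


def num_latent_frames(num_frames):
    """Latent frames produced for `num_frames` raw frames."""
    return (num_frames - 1) // VAE_TEMPORAL_STRIDE + 1


def check_chunk_partition(num_frames, chunk_partition):
    """Raise ValueError unless the partition covers exactly the latent frames."""
    expected = num_latent_frames(num_frames)
    actual = sum(chunk_partition)
    if actual != expected:
        raise ValueError(
            f"chunk_partition sums to {actual}, but num_frames={num_frames} "
            f"needs {expected} latent frames"
        )
    return expected


# Default partition for the canonical 81-frame setting.
print(check_chunk_partition(81, [1, 3, 3, 3, 3, 3, 3, 2]))  # 21
# V2V override whose first chunk covers 3 latent context frames.
print(check_chunk_partition(81, [3, 3, 3, 3, 3, 3, 3]))  # 21
# A shorter clip needs its own partition.
print(check_chunk_partition(33, [1, 4, 4]))  # 9
```

Running such a check up front surfaces a mismatched partition before any model weights are touched.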
11577
11678
<hfoptions id="anyflow-far">
11779
<hfoption id="t2v">
@@ -125,12 +87,12 @@ pipe = AnyFlowFARPipeline.from_pretrained(
12587
"nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
12688
).to("cuda")
12789

128-
video = pipe(
129-
prompt="A cat surfing a wave, sunset",
130-
num_inference_steps=4,
131-
num_frames=81,
132-
).frames[0]
133-
export_to_video(video, "out.mp4", fps=16)
90+
prompt = (
91+
"An astronaut runs smoothly and appears almost weightless on the lunar surface, "
92+
"as seen from a low-angle shot that highlights the vast, desolate background of the moon."
93+
)
94+
video = pipe(prompt, num_inference_steps=4, num_frames=81).frames[0]
95+
export_to_video(video, "anyflow_far_t2v.mp4", fps=16)
13496
```
13597

13698
</hfoption>
@@ -146,18 +108,25 @@ pipe = AnyFlowFARPipeline.from_pretrained(
146108
"nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
147109
).to("cuda")
148110

149-
# Wrap the conditioning image as a one-frame video tensor: (1, 1, 3, H, W) in [0, 1].
150-
first_frame = load_image("path/to/first_frame.png").resize((832, 480))
111+
# Example conditioning image from the AnyFlow repo.
112+
first_frame = load_image(
113+
"https://raw.githubusercontent.com/NVlabs/AnyFlow/main/assets/evaluation/example/images/1.jpg"
114+
).resize((832, 480))
151115
arr = np.asarray(first_frame).astype("float32") / 255.0 # (480, 832, 3)
152-
context_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(1).to("cuda")
116+
context_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(1).to("cuda") # (1, 1, 3, 480, 832)
153117

118+
prompt = (
119+
"A towering, battle-scarred humanoid robot, reminiscent of a Transformer with powerful, segmented armor "
120+
"and glowing red optics, walking through the skeletal remains of a city ruin. Twisted metal and shattered "
121+
"concrete crunch under its heavy steps, as the robot scans the desolate, dust-choked skyline under a dark sky."
122+
)
154123
video = pipe(
155-
prompt="a cat walks across a sunlit lawn",
124+
prompt=prompt,
156125
video=context_tensor,
157126
num_inference_steps=4,
158127
num_frames=81,
159128
).frames[0]
160-
export_to_video(video, "out.mp4", fps=16)
129+
export_to_video(video, "anyflow_far_i2v.mp4", fps=16)
161130
```
162131

163132
</hfoption>
@@ -173,21 +142,26 @@ pipe = AnyFlowFARPipeline.from_pretrained(
173142
"nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
174143
).to("cuda")
175144

176-
# Context clip — 9 raw frames map to 3 latent frames (9 = 4·2 + 1, 3 = 2 + 1).
177-
context_frames = load_video("path/to/context.mp4")[:9]
145+
# Example conditioning clip from the AnyFlow repo — take the first 9 frames (3 latent frames at VAE temporal stride 4).
146+
context_frames = load_video(
147+
"https://raw.githubusercontent.com/NVlabs/AnyFlow/main/assets/evaluation/example/videos/2.mp4"
148+
)[:9]
178149
arr = np.stack([np.asarray(f.resize((832, 480))) for f in context_frames]).astype("float32") / 255.0
179-
# np.stack gives (T, H, W, C) = (9, 480, 832, 3) → permute to (T, C, H, W) then add batch.
180150
context_tensor = torch.from_numpy(arr).permute(0, 3, 1, 2).unsqueeze(0).to("cuda") # (1, 9, 3, 480, 832)
181151

152+
prompt = (
153+
"A focused trail runner's powerful strides through a dense, sun-dappled forest. "
154+
"The camera tracks alongside, highlighting muscular exertion, sweat, and determined facial expression."
155+
)
182156
video = pipe(
183-
prompt="continue the story",
157+
prompt=prompt,
184158
video=context_tensor,
185159
num_inference_steps=4,
186160
num_frames=81,
187161
# Override chunk_partition so the first chunk covers exactly the 3 latent context frames.
188162
chunk_partition=[3, 3, 3, 3, 3, 3, 3],
189163
).frames[0]
190-
export_to_video(video, "out.mp4", fps=16)
164+
export_to_video(video, "anyflow_far_v2v.mp4", fps=16)
191165
```
192166

193167
</hfoption>

docs/source/zh/using-diffusers/anyflow.md

Lines changed: 7 additions & 34 deletions
@@ -22,7 +22,7 @@ quality often drops as NFE increases.
2222
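re-noising between sampling steps; the on-policy distillation stage additionally uses **DMD reverse-divergence supervision** + **Flow-Map backward simulation**
2323
(3-segment shortcut) to close the exposure-bias gap left by consistency distillation.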
2424

25-
AnyFlow was developed by Yuchao Gu, Guian Fang, and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA. The original training code is at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow), the project page is [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow), and the 4 released checkpoints are grouped in the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection.
25+
AnyFlow is a collaboration between NVIDIA, the National University of Singapore (NUS), and MIT, by Yuchao Gu, Guian Fang, Yuxin Jiang, Weijia Mao, Song Han, Han Cai, and Mike Zheng Shou. The original training code is at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow), the project page is [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow), and the 4 released checkpoints are grouped in the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection.
2626

2727
This guide covers the practical essentials: how to choose a pipeline, how to use any-step sampling, and how to fit AnyFlow into T2V / I2V / V2V workflows.
2828

@@ -100,7 +100,7 @@ prompt = "A red panda eating bamboo in a forest, cinematic lighting"
100100
for nfe in [1, 2, 4, 8, 16, 32]:
101101
# Rebuild the generator on every iteration so that NFE is the only variable when comparing across step counts.
102102
generator = torch.Generator("cuda").manual_seed(0)
103-
video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0]
103+
video = pipe(prompt, num_inference_steps=nfe, num_frames=81, generator=generator).frames[0]
104104
export_to_video(video, f"out_nfe{nfe}.mp4", fps=16)
105105
```
106106

@@ -125,11 +125,11 @@ The causal pipeline supports three task modes with the same distilled model, **via `vid
125125
The number of frames in the context tensor must satisfy `T = 4n + 1`, so that it aligns with the VAE temporal stride.
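
The `T = 4n + 1` constraint on the context tensor can be checked, and a clip trimmed to the nearest valid length, with a few lines of arithmetic. This is a hedged sketch: both helper names are hypothetical and nothing here is part of the pipeline API; it only encodes the stride-4 relation stated above.

```python
def context_latent_frames(num_context_frames, stride=4):
    """Latent frames for a context clip; requires T = stride * n + 1."""
    if num_context_frames < 1 or (num_context_frames - 1) % stride != 0:
        raise ValueError(
            f"context length {num_context_frames} is not of the form {stride}n + 1"
        )
    return (num_context_frames - 1) // stride + 1


def largest_valid_context(num_available, stride=4):
    """Largest T <= num_available with T = stride * n + 1."""
    if num_available < 1:
        raise ValueError("need at least one frame")
    return (num_available - 1) // stride * stride + 1


print(context_latent_frames(1))   # 1 latent frame (single image, I2V)
print(context_latent_frames(9))   # 3 latent frames (9-frame clip, V2V)
print(largest_valid_context(12))  # 9: trim a 12-frame clip to 9 frames
```

For V2V, the resulting latent count is what the first entry of `chunk_partition` should cover, as in the `[3, 3, 3, 3, 3, 3, 3]` override used with a 9-frame context clip.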
126126

127127
> [!IMPORTANT]
128-
> The FAR pipeline rolls out chunk by chunk, so `num_frames` must match the chunk schedule. The default
129-
> `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) corresponds to the released checkpoints' canonical `num_frames=81`
130-
> (21 = (81 - 1) // 4 + 1). When you change `num_frames`, you **must** explicitly pass a matching `chunk_partition` whose sum equals
131-
> `(num_frames - 1) // 4 + 1`, otherwise the pipeline raises an `AssertionError`. For example, `num_frames=33` corresponds to 9 latent
132-
> frames, for which `chunk_partition=[1, 4, 4]` works.
128+
> The FAR pipeline rolls out chunk by chunk, so `num_frames` must match the chunk schedule. The released checkpoints store
129+
> `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) in the transformer config, corresponding to the canonical
130+
> `num_frames=81` (21 = (81 - 1) // 4 + 1). When you change `num_frames`, you **must** explicitly pass a matching `chunk_partition`
131+
> whose sum equals `(num_frames - 1) // 4 + 1`, otherwise the pipeline raises a `ValueError`. For example, `num_frames=33` corresponds to
132+
> 9 latent frames, for which `chunk_partition=[1, 4, 4]` works.
133133
134134
```py
135135
import numpy as np
@@ -183,33 +183,6 @@ export_to_video(video, "v2v.mp4", fps=16)
183183
If you already have VAE-encoded latents, you can pass `video_latents=<tensor>` directly to skip the `vae_encode` step
184184
(mutually exclusive with `video`).
185185

186-
## Memory and Inference Speed
187-
188-
The 14B AnyFlow model runs on a single 40 GB GPU with group offloading + VAE slicing:
189-
190-
```py
191-
import torch
192-
from diffusers import AnyFlowPipeline
193-
from diffusers.hooks import apply_group_offloading
194-
195-
pipe = AnyFlowPipeline.from_pretrained(
196-
"nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
197-
)
198-
apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level")
199-
pipe.vae.enable_slicing()
200-
pipe.vae.enable_tiling()
201-
```
202-
203-
For latency, `torch.compile` works well on the transformer (the heaviest module):
204-
205-
```py
206-
pipe = pipe.to("cuda")
207-
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
208-
```
209-
210-
The compilation overhead is amortized after a few steps; combined with AnyFlow's low NFE (4-8 steps), `torch.compile` gives a clear speedup over eager
211-
mode on the 14B model.
212-
213186
## LoRA Fine-Tuning
214187

215188
Both pipelines reuse [`WanLoraLoaderMixin`](../api/loaders/lora), so LoRAs trained for the corresponding Wan2.1 backbone

scripts/convert_anyflow_to_diffusers.py

Lines changed: 10 additions & 2 deletions
@@ -57,13 +57,21 @@
5757
"AnyFlow-FAR-Wan2.1-1.3B-Diffusers": {
5858
"base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
5959
"transformer_cls": AnyFlowFARTransformer3DModel,
60-
"transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]},
60+
"transformer_kwargs": {
61+
"full_chunk_limit": 3,
62+
"compressed_patch_size": [1, 4, 4],
63+
"chunk_partition": [1, 3, 3, 3, 3, 3, 3, 2],
64+
},
6165
"pipeline_cls": AnyFlowFARPipeline,
6266
},
6367
"AnyFlow-FAR-Wan2.1-14B-Diffusers": {
6468
"base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
6569
"transformer_cls": AnyFlowFARTransformer3DModel,
66-
"transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]},
70+
"transformer_kwargs": {
71+
"full_chunk_limit": 3,
72+
"compressed_patch_size": [1, 4, 4],
73+
"chunk_partition": [1, 3, 3, 3, 3, 3, 3, 2],
74+
},
6775
"pipeline_cls": AnyFlowFARPipeline,
6876
},
6977
"AnyFlow-Wan2.1-T2V-1.3B-Diffusers": {
