Commit 26b2760

Merge branch 'main' into add-jit-diffusion
2 parents dcbd753 + 65aff37 commit 26b2760

532 files changed

Lines changed: 44386 additions & 5558 deletions

.ai/AGENTS.md

Lines changed: 21 additions & 11 deletions
@@ -4,30 +4,40 @@

 Strive to write code as simple and explicit as possible.

-- Minimize small helper/utility functions — inline the logic instead. A reader should be able to follow the full flow without jumping between functions.
-- No defensive code or unused code paths — do not add fallback paths, safety checks, or configuration options "just in case". When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
+- Prefer inlining small helper/utility functions over factoring them out — a reader should be able to follow the full flow without jumping between functions. If a private helper has only one caller, inlining it at the call site is usually the cleaner choice.
+- No defensive code, unused code paths, or legacy stubs — do not add fallback paths, safety checks, or configuration options "just in case"; do not carry unused method parameters "for API consistency", backwards-compatibility aliases for names that never shipped, or deprecation shims for code that was never released. When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
 - Do not guess user intent and silently correct behavior. Make the expected inputs clear in the docstring, and raise a concise error for unsupported cases rather than adding complex fallback logic.

----
+Before opening the PR, self-review against [review-rules.md](review-rules.md), which collects the most common mistakes we catch in review.

-### Dependencies
-- No new mandatory dependency without discussion (e.g. `einops`)
-- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
+---

 ## Code formatting
+
 - `make style` and `make fix-copies` should be run as the final step before opening a PR

 ### Copied Code
+
 - Many classes are kept in sync with a source via a `# Copied from ...` header comment
 - Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
 - Remove the header to intentionally break the link

 ### Models
-- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
-- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
-- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
+
+- See [models.md](models.md) for model conventions, attention pattern, implementation rules, dependencies, and gotchas.
+- See the [model-integration](./skills/model-integration/SKILL.md) skill for the full integration workflow, file structure, test setup, and other details.
+
+### Pipelines & Schedulers
+
+- See [pipelines.md](pipelines.md) for pipeline conventions, patterns, and gotchas.
+
+### Modular Pipelines
+
+- See [modular.md](modular.md) for modular pipeline conventions, patterns, and gotchas.

 ## Skills

-Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
-Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).
+Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. Available skills include:
+
+- [model-integration](./skills/model-integration/SKILL.md) (adding/converting pipelines)
+- [parity-testing](./skills/parity-testing/SKILL.md) (debugging numerical parity).

.ai/models.md

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
1+
# Model conventions and rules
2+
3+
Shared reference for model-related conventions, patterns, and gotchas.
4+
Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.md`.
5+
6+
## Coding style
7+
8+
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
9+
- Avoid graph breaks for `torch.compile` compatibility: do not use NumPy operations in forward implementations, or any other pattern that breaks `torch.compile` with `fullgraph=True`.
10+
- No new mandatory dependency without discussion (e.g. `einops`). Guard optional dependencies with `is_X_available()` and provide a dummy in `utils/dummy_*.py`.
11+
12+
## Common model conventions
13+
14+
* Models use `ModelMixin` with `register_to_config` for config serialization.
15+
* When adding a new transformer (or reviewing one), skim `src/diffusers/models/transformers/transformer_flux.py`, `src/diffusers/models/transformers/transformer_flux2.py`, `src/diffusers/models/transformers/transformer_qwenimage.py`, and `src/diffusers/models/transformers/transformer_wan.py` first to establish the pattern. Most conventions (mixin set, file structure, naming, gradient-checkpointing implementation, `_no_split_modules` settings, etc.) are easiest to internalize by comparison rather than from a fixed list.
16+
17+
## Attention pattern
18+
19+
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
20+
21+
```python
# transformer_mymodel.py

class MyModelAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states, attention_mask=None, ...):
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        # reshape, apply rope, etc.
        hidden_states = dispatch_attention_fn(
            query, key, value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        hidden_states = hidden_states.flatten(2, 3)
        return attn.to_out[0](hidden_states)


class MyModelAttention(nn.Module, AttentionModuleMixin):
    _default_processor_cls = MyModelAttnProcessor
    _available_processors = [MyModelAttnProcessor]

    def __init__(self, query_dim, heads=8, dim_head=64, ...):
        super().__init__()
        self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
        self.set_processor(MyModelAttnProcessor())

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        return self.processor(self, hidden_states, attention_mask, **kwargs)
```
58+
59+
### Attention masks
60+
61+
What you pass as `attn_mask=` to `dispatch_attention_fn` determines which backends work:
62+
63+
- **No mask needed → pass `None`, not an all-zero tensor.** A dense 4D additive float mask of all `0.0` changes nothing numerically but still hard-raises on `flash` / `_flash_3` / `_sage` (see `attention_dispatch.py:2328, 2544, 3266`). Only materialize a mask when it carries information. This is the Flux / Flux2 / Wan pattern: no mask, works on every backend, and relies on the model having been trained to tolerate consistent padding.
64+
- **Padding mask → bool `(B, L)` or `(B, 1, 1, L)`.** Only pass when the batch actually contains different-length sequences (i.e. there is real padding). If all sequences are the same length, set the mask to `None` — many backends (flash, sage, aiter) raise `ValueError` on any non-None mask, and even SDPA-based backends pay unnecessary overhead processing a no-op mask. See `pipeline_qwenimage.py` `encode_prompt` for the pattern: `if mask.all(): mask = None`. When a mask is needed, use bool format — it stays compatible with the `*_varlen` kernels via `_normalize_attn_mask` (`attention_dispatch.py:639`), which reduces bool masks to `cu_seqlens`. Dense additive-float masks *cannot* be reduced this way and so lose the varlen path.
65+
- **Other mask types (structural, BlockMask, etc.)** — if the model requires a different mask pattern, figure out how to support as many backends as possible (e.g. use `window_size` kwarg for sliding window on flash, `BlockMask` for Flex) and document which backends are supported for that model.
66+
- **Don't declare `attention_mask` (or `encoder_hidden_states_mask`) in the forward signature if you ignore it.** "For API stability with other transformers" is not a reason; readers assume a declared param is honored, and downstream pipelines will pass padding masks that silently get dropped. Some existing models in the repo carry unused mask params for historical reasons — e.g. `QwenDoubleStreamAttnProcessor2_0.__call__` declares `encoder_hidden_states_mask` but never reads it (the joint mask is routed through `attention_mask` instead), and the block-level forward in `transformer_qwenimage.py` declares it but always receives `None`. This is a legacy behavior and should not be replicated in new models.
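
The padding-mask rule from the list above can be sketched without any framework. This is only an illustration under stated assumptions: nested Python lists stand in for a bool `(B, L)` tensor, and `build_padding_mask` is a hypothetical helper name that does not exist in diffusers. In real pipeline code the equivalent step is the `if mask.all(): mask = None` check mentioned above.

```python
from typing import List, Optional


def build_padding_mask(lengths: List[int], max_len: int) -> Optional[List[List[bool]]]:
    """Return a bool (B, L) padding mask, or None when no sequence is padded.

    Hypothetical helper: nested lists stand in for a bool tensor.
    True marks a real token, False marks padding.
    """
    mask = [[pos < length for pos in range(max_len)] for length in lengths]
    # Same idea as `if mask.all(): mask = None`: a mask without any False entry
    # carries no information, so drop it and keep every backend usable.
    if all(all(row) for row in mask):
        return None
    return mask


# Every sequence fills the full length, so no mask is needed.
no_padding = build_padding_mask([4, 4], max_len=4)
# The second sequence has two padded positions, so a real mask is returned.
with_padding = build_padding_mask([4, 2], max_len=4)
```

Returning `None` rather than an all-`True` mask is the point of the sketch: the caller can forward the result to `attn_mask=` unchanged, and the no-padding case stays compatible with backends that reject any non-`None` mask.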
67+
68+
## Model class attributes
69+
70+
Each `ModelMixin` subclass can declare class-level attributes that configure optimization features. Each attribute corresponds to a user-facing API — the attribute controls how that feature behaves for the model. When adding a new transformer, set all that apply — skim `transformer_flux.py`, `transformer_wan.py`, `transformer_qwenimage.py` for examples.
71+
72+
### `_no_split_modules`
73+
74+
**API:** `Model.from_pretrained(..., device_map="auto")` — called in `model_loading_utils.py:87` via `model._get_no_split_modules()`, which feeds the list to `accelerate`'s `infer_auto_device_map(no_split_module_classes=...)`.
75+
76+
Lists which `nn.Module` subclasses must stay on a single device (i.e. never have their children placed on different devices).
77+
78+
- **`None` (default)** — `from_pretrained(..., device_map="auto")` raises `ValueError` (`modeling_utils.py:1863`).
79+
- **`[]`** — split anywhere you like.
80+
- **`["MyBlock"]`** — keep all `MyBlock` instances intact on one device.
81+
82+
**Why it's needed.** When `accelerate` splits a model across devices, it installs hooks on leaf modules that move inputs to the module's device before `forward` runs. Any inline operation (`+`, `*`, `torch.cat`) that combines tensors from different submodules has no hook — if those submodules landed on different devices, it crashes with "tensors on different devices". The fix is either: (a) list the parent module in `_no_split_modules` so all its children stay co-located, or (b) pack the operation into its own `nn.Module`. Inline ops on outputs from the **same** submodule call are fine since they're already on the same device.
83+
When deciding which modules to list, inspect `forward` methods at every level of the module tree — not just the top-level model, but also its submodules recursively. Any module with inline ops combining tensors from different children or stored parameters needs to be listed.
84+
85+
Every transformer in the repo declares it — new transformers should too. It's cheap and prevents a confusing error when users try `device_map="auto"`.
86+
87+
```python
_no_split_modules = ["MyModelTransformerBlock"]
```
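
To make the failure mode concrete, the following sketch checks a device placement for blocks whose children ended up on different devices. It is an illustration only, not how `accelerate` computes or validates a device map. The helper `blocks_split_across_devices`, the module names, and the placement itself are all hypothetical.

```python
from typing import Dict, List


def blocks_split_across_devices(device_map: Dict[str, str], block_prefixes: List[str]) -> List[str]:
    """Return the block prefixes whose submodules sit on more than one device."""
    split = []
    for prefix in block_prefixes:
        devices = {
            device
            for name, device in device_map.items()
            # Match the block itself or its children; the "." keeps
            # "blocks.1" from also matching "blocks.10".
            if name == prefix or name.startswith(prefix + ".")
        }
        if len(devices) > 1:
            split.append(prefix)
    return split


# Hypothetical placement in which block 1 was split between two GPUs.
device_map = {
    "blocks.0.attn": "cuda:0",
    "blocks.0.ff": "cuda:0",
    "blocks.1.attn": "cuda:0",
    "blocks.1.ff": "cuda:1",
}
split = blocks_split_across_devices(device_map, ["blocks.0", "blocks.1"])
```

In this placement an inline residual inside block 1 that adds the `attn` output to the `ff` output would combine a `cuda:0` tensor with a `cuda:1` tensor. Listing the block class in `_no_split_modules` rules such a placement out.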
90+
91+
### `_repeated_blocks`
92+
93+
**API:** `model.compile_repeated_blocks(*args, **kwargs)` — walks all submodules, compiles each one whose `__class__.__name__` matches an entry in this list (`modeling_utils.py:1552`). Arguments are forwarded to `torch.compile`.
94+
95+
Lists the class names of the repeated sub-modules (e.g. transformer blocks) for regional compilation instead of compiling the entire model. Must match the class `__name__` exactly.
96+
97+
```python
# Flux: two block types
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
# Wan: one block type
_repeated_blocks = ["WanTransformerBlock"]
```
103+
104+
Typically these are the layers that run many times (e.g. the transformer blocks in the denoising loop), since those benefit most from compilation. If empty or not set, `compile_repeated_blocks()` raises `ValueError`.
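
The name matching described above can be sketched in plain Python. This is not the code in `modeling_utils.py`. The `Module` class is a hypothetical stand-in for `nn.Module` so that the example runs without torch, and `select_repeated_blocks` only returns the matches where the real method would compile them.

```python
from typing import Iterator, List


class Module:
    """Hypothetical stand-in for nn.Module: only tracks child modules."""

    def __init__(self, *children: "Module") -> None:
        self.children = list(children)

    def modules(self) -> Iterator["Module"]:
        # Yield this module and every descendant, depth first.
        yield self
        for child in self.children:
            yield from child.modules()


class MyModelTransformerBlock(Module):
    pass


class MyModelEmbedding(Module):
    pass


class MyModel(Module):
    _repeated_blocks = ["MyModelTransformerBlock"]


def select_repeated_blocks(model: Module) -> List[Module]:
    """Collect the submodules whose class name is listed in `_repeated_blocks`."""
    names = getattr(model, "_repeated_blocks", None)
    if not names:
        raise ValueError("`_repeated_blocks` is empty or not set.")
    # Exact comparison against the class __name__: "TransformerBlock" alone would not match.
    return [m for m in model.modules() if m.__class__.__name__ in names]


model = MyModel(MyModelEmbedding(), MyModelTransformerBlock(), MyModelTransformerBlock())
selected = select_repeated_blocks(model)
```

Because the comparison is an exact string match, renaming a block class without updating `_repeated_blocks` leaves nothing to compile.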
105+
106+
### `_skip_layerwise_casting_patterns`
107+
108+
**API:** `model.enable_layerwise_casting(storage_dtype=..., compute_dtype=...)` — applies hooks that store weights in a low-precision dtype and cast to compute dtype on each forward. Modules matching these patterns are skipped (`modeling_utils.py:435`).
109+
110+
List of regex/substring patterns matching module names that should **stay in full precision**. Typically precision-sensitive layers: patch embeddings, positional embeddings, normalization layers.
111+
112+
```python
# Common pattern — skip embeddings and norms:
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
# Flux pattern:
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
```
118+
119+
If `None`, no modules are skipped (everything gets cast). Modules in `_keep_in_fp32_modules` are also skipped automatically.
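
A quick way to sanity-check a pattern list is to run it against the model's module names before relying on it. The sketch below assumes that a pattern matches when `re.search` finds it anywhere in the dotted module name, which also covers plain substrings. Treat that matching rule, the helper name `modules_kept_in_full_precision`, and the module names as assumptions for illustration.

```python
import re
from typing import List, Optional, Sequence


def modules_kept_in_full_precision(
    module_names: Sequence[str], skip_patterns: Optional[Sequence[str]]
) -> List[str]:
    """Return the module names that match at least one skip pattern.

    Assumption: a pattern matches if `re.search` finds it anywhere in the name,
    which also covers plain substrings such as "norm".
    """
    if skip_patterns is None:
        return []  # nothing is skipped, every module gets cast
    return [
        name
        for name in module_names
        if any(re.search(pattern, name) for pattern in skip_patterns)
    ]


# Hypothetical module names, in the dotted form produced by `named_modules()`.
names = ["pos_embed", "blocks.0.attn.to_q", "blocks.0.norm1", "proj_out"]
kept = modules_kept_in_full_precision(names, ["pos_embed", "norm"])
```

Short patterns such as `"norm"` are broad on purpose, so it is worth listing the matches for a new model to confirm that no large projection layer is kept in full precision by accident.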
120+
121+
### `_keep_in_fp32_modules`
122+
123+
**API:** `Model.from_pretrained(..., torch_dtype=torch.bfloat16)` — during loading, modules matching these patterns are kept in `float32` even when the rest of the model is cast to the requested dtype (`modeling_utils.py:1160`). Also respected by `enable_layerwise_casting()`.
124+
125+
List of module name patterns for modules that are numerically unstable in lower precision — timestep embeddings, scale/shift tables, normalization parameters.
126+
127+
```python
# Wan pattern:
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
```
131+
132+
If `None` (default), all modules follow the requested `torch_dtype`.
133+
134+
### `_cp_plan`
135+
136+
**API:** `model.enable_parallelism(config=parallel_config)` — when the config includes `context_parallel_config`, this plan is used by `apply_context_parallel()` to shard tensors across GPUs for sequence parallelism (`modeling_utils.py:1665`).
137+
138+
Dict describing how to partition the model's tensors for context parallelism. Maps parameter/activation names to their sharding strategy.
139+
140+
```python
# Minimal example (see transformer_flux.py, transformer_wan.py for full plans):
_cp_plan = {
    "": { ... },  # default sharding for unnamed tensors
    "rope": { ... },  # RoPE-specific sharding
}
```
147+
148+
If `None` (default), `enable_parallelism()` with `context_parallel_config` raises `ValueError` unless a `cp_plan` is passed explicitly as an argument. To derive a plan for a new model, study the mechanism in `hooks/context_parallel.py` and `_modeling_parallel.py`, compare existing plans in `transformer_flux.py` and `transformer_wan.py`, then test and adjust — correct plans depend on the model's data flow and require validation.
149+
150+
### `_supports_gradient_checkpointing`
151+
152+
**API:** `model.enable_gradient_checkpointing()` — walks submodules for a `gradient_checkpointing` attribute, flips it to `True`, and sets `_gradient_checkpointing_func` (`modeling_utils.py:285`).
153+
154+
Boolean gate. If `False` (default), calling that method raises `ValueError`. All transformers in the repo support this. To add support: (1) set the class attribute to `True`, (2) add `self.gradient_checkpointing = False` in `__init__`, (3) add `if torch.is_grad_enabled() and self.gradient_checkpointing:` branches in `forward` that call `self._gradient_checkpointing_func`. See gotcha #3.
155+
156+
## Gotchas
157+
158+
1. **Forgetting to register imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports — both the sub-package `__init__.py` and the top-level `src/diffusers/__init__.py` (which has `_import_structure` and `_lazy_modules`). Missing either causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
159+
160+
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
161+
162+
163+
3. **Capability flags without matching implementation.** For example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently skips checkpointing and runs a normal forward.
164+
4. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.
165+
166+
5. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
167+
- **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
168+
- **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
169+
```python
is_mps = hidden_states.device.type == "mps"
is_npu = hidden_states.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
```
174+
See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model.
