Skip to content

Commit 646ab6e

Browse files
authored
Merge branch 'main' into autoencoderkl-tests-refactor
2 parents 63afbae + e16719a commit 646ab6e

221 files changed

Lines changed: 13867 additions & 882 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ai/AGENTS.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
Strive to write code as simple and explicit as possible.
66

7-
- Minimize small helper/utility functions — inline the logic instead. A reader should be able to follow the full flow without jumping between functions.
8-
- No defensive code or unused code paths — do not add fallback paths, safety checks, or configuration options "just in case". When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
7+
- Prefer inlining small helper/utility functions over factoring them out — a reader should be able to follow the full flow without jumping between functions. If a private helper has only one caller, inlining it at the call site is usually the cleaner choice.
8+
- No defensive code, unused code paths, or legacy stubs — do not add fallback paths, safety checks, or configuration options "just in case"; do not carry unused method parameters "for API consistency", backwards-compatibility aliases for names that never shipped, or deprecation shims for code that was never released. When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
99
- Do not guess user intent and silently correct behavior. Make the expected inputs clear in the docstring, and raise a concise error for unsupported cases rather than adding complex fallback logic.
1010

11+
Before opening the PR, self-review against [review-rules.md](review-rules.md), which collects the most common mistakes we catch in review.
12+
1113
---
1214

1315
## Code formatting
@@ -27,13 +29,11 @@ Strive to write code as simple and explicit as possible.
2729

2830
### Pipelines & Schedulers
2931

30-
- Pipelines inherit from `DiffusionPipeline`
31-
- Schedulers use `SchedulerMixin` with `ConfigMixin`
32-
- Use `@torch.no_grad()` on pipeline `__call__`
33-
- Support `output_type="latent"` for skipping VAE decode
34-
- Support `generator` parameter for reproducibility
35-
- Use `self.progress_bar(timesteps)` for progress tracking
36-
- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
32+
- See [pipelines.md](pipelines.md) for pipeline conventions, patterns, and gotchas.
33+
34+
### Modular Pipelines
35+
36+
- See [modular.md](modular.md) for modular pipeline conventions, patterns, and gotchas.
3737

3838
## Skills
3939

.ai/models.md

Lines changed: 109 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.
1111

1212
## Common model conventions
1313

14-
- Models use `ModelMixin` with `register_to_config` for config serialization
14+
* Models use `ModelMixin` with `register_to_config` for config serialization.
15+
* When adding a new transformer (or reviewing one), skim `src/diffusers/models/transformers/transformer_flux.py`, `src/diffusers/models/transformers/transformer_flux2.py`, `src/diffusers/models/transformers/transformer_qwenimage.py`, and `src/diffusers/models/transformers/transformer_wan.py` first to establish the pattern. Most conventions (mixin set, file structure, naming, gradient-checkpointing implementation, `_no_split_modules` settings, etc.) are easiest to internalize by comparison rather than from a fixed list.
1516

1617
## Attention pattern
1718

@@ -55,22 +56,119 @@ class MyModelAttention(nn.Module, AttentionModuleMixin):
5556
return self.processor(self, hidden_states, attention_mask, **kwargs)
5657
```
5758

58-
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
59+
### Attention masks
5960

60-
## Gotchas
61+
What you pass as `attn_mask=` to `dispatch_attention_fn` determines which backends work:
6162

62-
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
63+
- **No mask needed → pass `None`, not an all-zero tensor.** A dense 4D additive float mask of all `0.0` does no math but still hard-raises on `flash` / `_flash_3` / `_sage` (see `attention_dispatch.py:2328, 2544, 3266`). Only materialize a mask when it carries information. This is the Flux / Flux2 / Wan pattern: no mask, works on every backend, relies on the model having been trained tolerating consistent padding.
64+
- **Padding mask → bool `(B, L)` or `(B, 1, 1, L)`.** Only pass when the batch actually contains different-length sequences (i.e. there is real padding). If all sequences are the same length, set the mask to `None` — many backends (flash, sage, aiter) raise `ValueError` on any non-None mask, and even SDPA-based backends pay unnecessary overhead processing a no-op mask. See `pipeline_qwenimage.py` `encode_prompt` for the pattern: `if mask.all(): mask = None`. When a mask is needed, use bool format — it stays compatible with the `*_varlen` kernels via `_normalize_attn_mask` (`attention_dispatch.py:639`), which reduces bool masks to `cu_seqlens`. Dense additive-float masks *cannot* be reduced this way and so lose the varlen path.
65+
- **Other mask types (structural, BlockMask, etc.)** — if the model requires a different mask pattern, figure out how to support as many backends as possible (e.g. use `window_size` kwarg for sliding window on flash, `BlockMask` for Flex) and document which backends are supported for that model.
66+
- **Don't declare `attention_mask` (or `encoder_hidden_states_mask`) in the forward signature if you ignore it.** "For API stability with other transformers" is not a reason; readers assume a declared param is honored, and downstream pipelines will pass padding masks that silently get dropped. Some existing models in the repo carry unused mask params for historical reasons — e.g. `QwenDoubleStreamAttnProcessor2_0.__call__` declares `encoder_hidden_states_mask` but never reads it (the joint mask is routed through `attention_mask` instead), and the block-level forward in `transformer_qwenimage.py` declares it but always receives `None`. This is a legacy behavior and should not be replicated in new models.
6367

64-
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
68+
## Model class attributes
69+
70+
Each `ModelMixin` subclass can declare class-level attributes that configure optimization features. Each attribute corresponds to a user-facing API — the attribute controls how that feature behaves for the model. When adding a new transformer, set all that apply — skim `transformer_flux.py`, `transformer_wan.py`, `transformer_qwenimage.py` for examples.
71+
72+
### `_no_split_modules`
73+
74+
**API:** `Model.from_pretrained(..., device_map="auto")` — called in `model_loading_utils.py:87` via `model._get_no_split_modules()`, which feeds the list to `accelerate`'s `infer_auto_device_map(no_split_module_classes=...)`.
75+
76+
Lists which `nn.Module` subclasses must stay on a single device (i.e. never have their children placed on different devices).
77+
78+
- **`None` (default)**`from_pretrained(..., device_map="auto")` raises `ValueError` (`modeling_utils.py:1863`).
79+
- **`[]`** — split anywhere you like.
80+
- **`["MyBlock"]`** — keep all `MyBlock` instances intact on one device.
81+
82+
**Why it's needed.** When `accelerate` splits a model across devices, it installs hooks on leaf modules that move inputs to the module's device before `forward` runs. Any inline operation (`+`, `*`, `torch.cat`) that combines tensors from different submodules has no hook — if those submodules landed on different devices, it crashes with "tensors on different devices". The fix is either: (a) list the parent module in `_no_split_modules` so all its children stay co-located, or (b) pack the operation into its own `nn.Module`. Inline ops on outputs from the **same** submodule call are fine since they're already on the same device.
83+
When deciding which modules to list, inspect `forward` methods at every level of the module tree — not just the top-level model, but also its submodules recursively. Any module with inline ops combining tensors from different children or stored parameters needs to be listed.
84+
85+
Every transformer in the repo declares it — new transformers should too. It's cheap and prevents a confusing error when users try `device_map="auto"`.
86+
87+
```python
88+
_no_split_modules = ["MyModelTransformerBlock"]
89+
```
6590

66-
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
91+
### `_repeated_blocks`
6792

68-
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
93+
**API:** `model.compile_repeated_blocks(*args, **kwargs)` — walks all submodules, compiles each one whose `__class__.__name__` matches an entry in this list (`modeling_utils.py:1552`). Arguments are forwarded to `torch.compile`.
94+
95+
Lists the class names of the repeated sub-modules (e.g. transformer blocks) for regional compilation instead of compiling the entire model. Must match the class `__name__` exactly.
96+
97+
```python
98+
# Flux: two block types
99+
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
100+
# Wan: one block type
101+
_repeated_blocks = ["WanTransformerBlock"]
102+
```
103+
104+
Typically these are the layers that run many times (e.g. the transformer blocks in the denoising loop), since those benefit most from compilation. If empty or not set, `compile_repeated_blocks()` raises `ValueError`.
105+
106+
### `_skip_layerwise_casting_patterns`
107+
108+
**API:** `model.enable_layerwise_casting(storage_dtype=..., compute_dtype=...)` — applies hooks that store weights in a low-precision dtype and cast to compute dtype on each forward. Modules matching these patterns are skipped (`modeling_utils.py:435`).
109+
110+
List of regex/substring patterns matching module names that should **stay in full precision**. Typically precision-sensitive layers: patch embeddings, positional embeddings, normalization layers.
111+
112+
```python
113+
# Common pattern — skip embeddings and norms:
114+
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
115+
# Flux pattern:
116+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
117+
```
69118

70-
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
119+
If `None`, no modules are skipped (everything gets cast). Modules in `_keep_in_fp32_modules` are also skipped automatically.
120+
121+
### `_keep_in_fp32_modules`
122+
123+
**API:** `Model.from_pretrained(..., torch_dtype=torch.bfloat16)` — during loading, modules matching these patterns are kept in `float32` even when the rest of the model is cast to the requested dtype (`modeling_utils.py:1160`). Also respected by `enable_layerwise_casting()`.
124+
125+
List of module name patterns for modules that are numerically unstable in lower precision — timestep embeddings, scale/shift tables, normalization parameters.
126+
127+
```python
128+
# Wan pattern:
129+
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
130+
```
131+
132+
If `None` (default), all modules follow the requested `torch_dtype`.
133+
134+
### `_cp_plan`
135+
136+
**API:** `model.enable_parallelism(config=parallel_config)` — when the config includes `context_parallel_config`, this plan is used by `apply_context_parallel()` to shard tensors across GPUs for sequence parallelism (`modeling_utils.py:1665`).
137+
138+
Dict describing how to partition the model's tensors for context parallelism. Maps parameter/activation names to their sharding strategy.
139+
140+
```python
141+
# Minimal example (see transformer_flux.py, transformer_wan.py for full plans):
142+
_cp_plan = {
143+
"": { ... }, # default sharding for unnamed tensors
144+
"rope": { ... }, # RoPE-specific sharding
145+
}
146+
```
147+
148+
If `None` (default), `enable_parallelism()` with `context_parallel_config` raises `ValueError` unless a `cp_plan` is passed explicitly as an argument. To derive a plan for a new model, study the mechanism in `hooks/context_parallel.py` and `_modeling_parallel.py`, compare existing plans in `transformer_flux.py` and `transformer_wan.py`, then test and adjust — correct plans depend on the model's data flow and require validation.
149+
150+
### `_supports_gradient_checkpointing`
151+
152+
**API:** `model.enable_gradient_checkpointing()` — walks submodules for a `gradient_checkpointing` attribute, flips it to `True`, and sets `_gradient_checkpointing_func` (`modeling_utils.py:285`).
153+
154+
Boolean gate. If `False` (default), calling that method raises `ValueError`. All transformers in the repo support this. To add support, just: (1) set the class attribute to `True`, (2) add `self.gradient_checkpointing = False` in `__init__`, (3) add `if torch.is_grad_enabled() and self.gradient_checkpointing:` branches in `forward` that call `self._gradient_checkpointing_func`. See gotcha #4.
155+
156+
## Gotchas
157+
158+
1. **Forgetting to register imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports — both the sub-package `__init__.py` and the top-level `src/diffusers/__init__.py` (which has `_import_structure` and `_lazy_modules`). Missing either causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
159+
160+
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
71161

72-
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
73162

74-
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
163+
3. **Capability flags without matching implementation.** for example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward.
164+
4. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.
75165

76-
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
166+
5. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
167+
- **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
168+
- **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
169+
```python
170+
is_mps = hidden_states.device.type == "mps"
171+
is_npu = hidden_states.device.type == "npu"
172+
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
173+
```
174+
See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model.

0 commit comments

Comments
 (0)