`.ai/models.md`: 93 additions, 7 deletions
@@ -61,23 +61,109 @@ class MyModelAttention(nn.Module, AttentionModuleMixin):
What you pass as `attn_mask=` to `dispatch_attention_fn` determines which backends work:

- **No mask needed → pass `None`, not an all-zero tensor.** A dense 4D additive float mask of all `0.0` does no math but still hard-raises on `flash` / `_flash_3` / `_sage` (see `attention_dispatch.py:2328, 2544, 3266`). Only materialize a mask when it carries information. This is the Flux / Flux2 / Wan pattern: no mask, works on every backend, and relies on the model having been trained to tolerate consistent padding.

- **Padding mask → bool `(B, L)` or `(B, 1, 1, L)`.** Only pass a mask when the batch actually contains different-length sequences (i.e. there is real padding). If all sequences are the same length, set the mask to `None` — many backends (flash, sage, aiter) raise `ValueError` on any non-None mask, and even SDPA-based backends pay unnecessary overhead processing a no-op mask. See `encode_prompt` in `pipeline_qwenimage.py` for the pattern: `if mask.all(): mask = None`. When a mask is needed, use the bool format — it stays compatible with the `*_varlen` kernels via `_normalize_attn_mask` (`attention_dispatch.py:639`), which reduces bool masks to `cu_seqlens`. Dense additive-float masks *cannot* be reduced this way and so lose the varlen path. This is the Qwen pattern (`transformer_qwenimage.py:951`); see the sketch after this list.

- **Structural and other mask types (causal, sliding-window, band-diagonal, `BlockMask`, etc.).** Row-varying patterns can't be expressed as `(B, L)`, so a dense `(1, 1, L, L)` mask is unavoidable and those layers are typically SDPA/Flex-only. If the model requires such a pattern, figure out how to support as many backends as possible (e.g. the `window_size` kwarg for sliding windows on flash, `BlockMask` for Flex) and document which backends are supported for that model. Consult `src/models/transformers/transformer_kandinsky.py` as a reference.

- **Don't declare `attention_mask` (or `encoder_hidden_states_mask`) in the forward signature if you ignore it.** "For API stability with other transformers" is not a reason; readers assume a declared param is honored, and downstream pipelines will pass padding masks that silently get dropped. Some existing models in the repo carry unused mask params for historical reasons — e.g. `QwenDoubleStreamAttnProcessor2_0.__call__` declares `encoder_hidden_states_mask` but never reads it (the joint mask is routed through `attention_mask` instead), and the block-level forward in `transformer_qwenimage.py` declares it but always receives `None`. This is legacy behavior and should not be replicated in new models.
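A minimal sketch of that padding-mask handling (the helper name is illustrative; the real check lives in the pipeline's `encode_prompt`):

```python
from typing import Optional

import torch


def prepare_attn_mask(mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # `mask` is a bool (B, L) tensor; True marks real tokens, False marks padding.
    if mask is None or bool(mask.all()):
        return None  # same-length batch: every backend can run mask-free
    return mask.bool()  # keep it bool so `_normalize_attn_mask` can reduce it to cu_seqlens
```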
## Model class attributes
Each `ModelMixin` subclass can declare class-level attributes that configure optimization features. Each attribute corresponds to a user-facing API — the attribute controls how that feature behaves for the model. When adding a new transformer, set all that apply — skim `transformer_flux.py`, `transformer_wan.py`, `transformer_qwenimage.py` for examples.
### `_no_split_modules`
**API:** `Model.from_pretrained(..., device_map="auto")` — called in `model_loading_utils.py:87` via `model._get_no_split_modules()`, which feeds the list to `accelerate`'s `infer_auto_device_map(no_split_module_classes=...)`.
Lists which `nn.Module` subclasses must stay on a single device (i.e. never have their children placed on different devices).
-**`["MyBlock"]`** — keep all `MyBlock` instances intact on one device.
**Why it's needed.** When `accelerate` splits a model across devices, it installs hooks on leaf modules that move inputs to the module's device before `forward` runs. Any inline operation (`+`, `*`, `torch.cat`) that combines tensors from different submodules has no hook — if those submodules landed on different devices, it crashes with "tensors on different devices". The fix is either: (a) list the parent module in `_no_split_modules` so all its children stay co-located, or (b) pack the operation into its own `nn.Module`. Inline ops on outputs from the **same** submodule call are fine since they're already on the same device.
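For illustration, a hypothetical block whose `forward` combines the outputs of two children with an inline op, exactly the kind of parent that belongs in `_no_split_modules`:

```python
import torch
import torch.nn as nn


class MyModelBlock(nn.Module):  # hypothetical block, a stand-in for a real transformer block
    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.Linear(dim, dim)  # placeholder for the attention sub-module
        self.mlp = nn.Linear(dim, dim)   # placeholder for the feed-forward sub-module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The inline `+` combines outputs from two different children. If accelerate placed
        # `self.attn` and `self.mlp` on different devices, this line would crash, so the
        # parent class name goes in `_no_split_modules` (or the op gets its own nn.Module).
        return x + self.attn(x) + self.mlp(x)
```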
When deciding which modules to list, inspect `forward` methods at every level of the module tree — not just the top-level model, but also its submodules recursively. Any module with inline ops combining tensors from different children or stored parameters needs to be listed.
Every transformer in the repo declares it — new transformers should too. It's cheap and prevents a confusing error when users try `device_map="auto"`.
```python
_no_split_modules = ["MyModelTransformerBlock"]
```
### `_repeated_blocks`
**API:** `model.compile_repeated_blocks(*args, **kwargs)` — walks all submodules, compiles each one whose `__class__.__name__` matches an entry in this list (`modeling_utils.py:1552`). Arguments are forwarded to `torch.compile`.
Lists the class names of the repeated sub-modules (e.g. transformer blocks) for regional compilation instead of compiling the entire model. Must match the class `__name__` exactly.
Typically these are the layers that run many times (e.g. the transformer blocks in the denoising loop), since those benefit most from compilation. If empty or not set, `compile_repeated_blocks()` raises `ValueError`.
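A minimal sketch, assuming the repeated block is named `MyModelTransformerBlock`:

```python
# Illustrative only — the entry must match the repeated block class's __name__ exactly.
_repeated_blocks = ["MyModelTransformerBlock"]

# Usage: extra arguments are forwarded to torch.compile for each matched block, e.g.
#   model.compile_repeated_blocks(fullgraph=True)
```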
### `_skip_layerwise_casting_patterns`
**API:** `model.enable_layerwise_casting(storage_dtype=..., compute_dtype=...)` — applies hooks that store weights in a low-precision dtype and cast to compute dtype on each forward. Modules matching these patterns are skipped (`modeling_utils.py:435`).
List of regex/substring patterns matching module names that should **stay in full precision**. Typically precision-sensitive layers: patch embeddings, positional embeddings, normalization layers.
If `None`, no modules are skipped (everything gets cast). Modules in `_keep_in_fp32_modules` are also skipped automatically.
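A minimal sketch; the pattern strings are illustrative, not prescriptive:

```python
# Keep precision-sensitive modules out of low-precision storage (illustrative patterns):
_skip_layerwise_casting_patterns = ["patch_embed", "pos_embed", "norm"]
```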
### `_keep_in_fp32_modules`
**API:** `Model.from_pretrained(..., torch_dtype=torch.bfloat16)` — during loading, modules matching these patterns are kept in `float32` even when the rest of the model is cast to the requested dtype (`modeling_utils.py:1160`). Also respected by `enable_layerwise_casting()`.
List of module name patterns for modules that are numerically unstable in lower precision — timestep embeddings, scale/shift tables, normalization parameters.
If `None` (default), all modules follow the requested `torch_dtype`.
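A minimal sketch with illustrative module-name patterns:

```python
# Modules kept in float32 regardless of the requested torch_dtype (illustrative names):
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table"]
```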
### `_cp_plan`
**API:** `model.enable_parallelism(config=parallel_config)` — when the config includes `context_parallel_config`, this plan is used by `apply_context_parallel()` to shard tensors across GPUs for sequence parallelism (`modeling_utils.py:1665`).
Dict describing how to partition the model's tensors for context parallelism. Maps parameter/activation names to their sharding strategy.
```python
# Minimal example (see transformer_flux.py, transformer_wan.py for full plans):
_cp_plan = {
    "": { ... },      # default sharding for unnamed tensors
    "rope": { ... },  # RoPE-specific sharding
}
```
If `None` (default), `enable_parallelism()` with `context_parallel_config` raises `ValueError` unless a `cp_plan` is passed explicitly as an argument. To derive a plan for a new model, study the mechanism in `hooks/context_parallel.py` and `_modeling_parallel.py`, compare existing plans in `transformer_flux.py` and `transformer_wan.py`, then test and adjust — correct plans depend on the model's data flow and require validation.
149
+
150
+
### `_supports_gradient_checkpointing`
**API:** `model.enable_gradient_checkpointing()` — walks submodules for a `gradient_checkpointing` attribute, flips it to `True`, and sets `_gradient_checkpointing_func` (`modeling_utils.py:285`).
Boolean gate. If `False` (default), calling that method raises `ValueError`. All transformers in the repo support this. To add support, just: (1) set the class attribute to `True`, (2) add `self.gradient_checkpointing = False` in `__init__`, (3) add `if torch.is_grad_enabled() and self.gradient_checkpointing:` branches in `forward` that call `self._gradient_checkpointing_func`. See gotcha #4.
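A sketch of those three steps, with illustrative class and module names:

```python
import torch
import torch.nn as nn

from diffusers import ModelMixin


class MyModelTransformer(ModelMixin):  # hypothetical model
    _supports_gradient_checkpointing = True  # (1) class attribute

    def __init__(self, dim: int = 64, num_layers: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.gradient_checkpointing = False  # (2) instance flag

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # (3) route each block through the func set by enable_gradient_checkpointing()
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states
```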
## Gotchas
1. **Forgetting to register imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports — both the sub-package `__init__.py` and the top-level `src/diffusers/__init__.py` (which has `_import_structure` and `_lazy_modules`). Missing either causes an `ImportError` that only shows up when users try `from diffusers import YourNewClass`.

2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.

3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.

4. **Capability flags without matching implementation.** For example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward — wasting memory rather than saving it, and masking the bug behind a successful run. `_no_split_modules` similarly needs to name the actual block classes that must stay on one device, or `device_map` placement causes silent correctness bugs / OOM. Copy from a similar model and verify the corresponding logic is in place; for inference-only ports just drop the flag.

5. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.

6. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:

   - **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
   - **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
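For illustration only, a hypothetical device-gated sketch (the repo's actual pattern may differ):

```python
import torch


def freqs_dtype(device: torch.device) -> torch.dtype:
    # Hypothetical helper: use float64 where supported, float32 on backends that lack it
    # (e.g. MPS, some NPUs).
    return torch.float32 if device.type in ("mps", "npu") else torch.float64
```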
`.ai/review-rules.md`: 9 additions, 0 deletions
@@ -15,3 +15,12 @@ Before reviewing, read and apply the guidelines in:
Common mistakes are covered in the common-mistakes / gotcha sections in [AGENTS.md](AGENTS.md), [models.md](models.md), [pipelines.md](pipelines.md), and [modular.md](modular.md). Additionally, watch for the patterns below, which aren't covered there:
- **Ephemeral context.** Comments, docstrings, and files that only made sense to the current PR's author or reviewer don't help a future reader/user/developer. Examples: `# per reviewer comment on PR #NNNN`, `# as discussed in review`, `# TODO from offline chat`, debug printouts. Same for files: parity harnesses, comparison scripts, anything in `scripts/` with hardcoded developer paths or imports from the reference repo. State the *reason* so the comment stands alone, or drop it.
## Dead code analysis (new models)
When reviewing a PR that adds a new model, trace how the model is actually called from the pipeline to identify likely dead code. Include the results as a **suggestions / additional info** section in your review (not as blocking comments — the findings are advisory).
1. **Trace the call path.** Read the pipeline's `__call__` and follow every call into the model — which arguments are passed, which branches are taken, which helper methods are invoked.
2. **Check the default model config.** Look at the default config values in the model's `__init__` (or any published config JSON). Identify code paths that are unreachable under those defaults — e.g. an `if self.config.use_foo:` branch where `use_foo` defaults to `False` and no published checkpoint sets it to `True` (see the sketch after this list).
3. **Flag unused parameters and methods.** Parameters declared in `forward` (or helper methods) but never passed by the pipeline, private methods never called, layers initialized but never used in `forward`.
4. **Qualify findings.** The actual model config can differ from the defaults, so any dead code identified this way is *likely* dead — not certain. Frame findings accordingly: "Under the default config and the pipeline's call path, this code appears unreachable." The PR author may know of configs or use cases that exercise the path.
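For illustration, a hypothetical branch that this analysis would flag as likely dead under the default config:

```python
from typing import Optional

import torch
import torch.nn as nn


class MyTransformer(nn.Module):  # hypothetical model
    def __init__(self, use_foo: bool = False):  # no published checkpoint sets use_foo=True
        super().__init__()
        self.use_foo = use_foo
        self.foo_proj = nn.Linear(64, 64)  # initialized but never exercised under the default

    def forward(self, x: torch.Tensor, foo_states: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.use_foo:  # unreachable under the default config and the pipeline's call path
            x = x + self.foo_proj(foo_states)
        return x
```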