Skip to content

Commit c8eba43

Browse files
yiyixuxuclaude
andauthored
[agents docs] update models.md with class attributes and attention mask (#13665)
* [agents docs] update models.md with class attributes and attention mask guidance - Add "Model class attributes" section documenting _no_split_modules, _repeated_blocks, _skip_layerwise_casting_patterns, _keep_in_fp32_modules, _cp_plan, and _supports_gradient_checkpointing with their corresponding user-facing APIs and how they work - Improve attention mask guidance: recommend passing None when no real padding exists, document backend compatibility - Move _no_split_modules from gotchas to its own section with first-principles explanation of why it's needed (accelerate device hooks) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * update review-rules, ask to help identify unused code path --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ffd5da5 commit c8eba43

2 files changed

Lines changed: 102 additions & 7 deletions

File tree

.ai/models.md

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,109 @@ class MyModelAttention(nn.Module, AttentionModuleMixin):
6161
What you pass as `attn_mask=` to `dispatch_attention_fn` determines which backends work:
6262

6363
- **No mask needed → pass `None`, not an all-zero tensor.** A dense 4D additive float mask of all `0.0` does no math but still hard-raises on `flash` / `_flash_3` / `_sage` (see `attention_dispatch.py:2328, 2544, 3266`). Only materialize a mask when it carries information. This is the Flux / Flux2 / Wan pattern: no mask, works on every backend, relies on the model having been trained tolerating consistent padding.
64-
- **Padding mask → bool `(B, L)` or `(B, 1, 1, L)`.** Stays compatible with the `*_varlen` kernels via `_normalize_attn_mask` (`attention_dispatch.py:639`), which reduces bool masks to `cu_seqlens`. Dense additive-float masks *cannot* be reduced this way and so lose the varlen path. This is the Qwen pattern (`transformer_qwenimage.py:951`).
65-
- **Structural mask (causal, sliding-window, band-diagonal) → dense `(1, 1, L, L)` is unavoidable.** Row-varying patterns can't be expressed as `(B, L)`. Expect SDPA/Flex-only for these layers; consider Flex's `sliding_window_mask_mod` or FA3's native `window_size=` kwarg if backend flexibility matters. Consult `src/models/transformers/transformer_kandinsky.py` as a reference.
64+
- **Padding mask → bool `(B, L)` or `(B, 1, 1, L)`.** Only pass when the batch actually contains different-length sequences (i.e. there is real padding). If all sequences are the same length, set the mask to `None` — many backends (flash, sage, aiter) raise `ValueError` on any non-None mask, and even SDPA-based backends pay unnecessary overhead processing a no-op mask. See `pipeline_qwenimage.py` `encode_prompt` for the pattern: `if mask.all(): mask = None`. When a mask is needed, use bool format — it stays compatible with the `*_varlen` kernels via `_normalize_attn_mask` (`attention_dispatch.py:639`), which reduces bool masks to `cu_seqlens`. Dense additive-float masks *cannot* be reduced this way and so lose the varlen path.
65+
- **Other mask types (structural, BlockMask, etc.)** — if the model requires a different mask pattern, figure out how to support as many backends as possible (e.g. use `window_size` kwarg for sliding window on flash, `BlockMask` for Flex) and document which backends are supported for that model.
6666
- **Don't declare `attention_mask` (or `encoder_hidden_states_mask`) in the forward signature if you ignore it.** "For API stability with other transformers" is not a reason; readers assume a declared param is honored, and downstream pipelines will pass padding masks that silently get dropped. Some existing models in the repo carry unused mask params for historical reasons — e.g. `QwenDoubleStreamAttnProcessor2_0.__call__` declares `encoder_hidden_states_mask` but never reads it (the joint mask is routed through `attention_mask` instead), and the block-level forward in `transformer_qwenimage.py` declares it but always receives `None`. This is a legacy behavior and should not be replicated in new models.
6767

68+
## Model class attributes
69+
70+
Each `ModelMixin` subclass can declare class-level attributes that configure optimization features. Each attribute corresponds to a user-facing API — the attribute controls how that feature behaves for the model. When adding a new transformer, set all that apply — skim `transformer_flux.py`, `transformer_wan.py`, `transformer_qwenimage.py` for examples.
71+
72+
### `_no_split_modules`
73+
74+
**API:** `Model.from_pretrained(..., device_map="auto")` — called in `model_loading_utils.py:87` via `model._get_no_split_modules()`, which feeds the list to `accelerate`'s `infer_auto_device_map(no_split_module_classes=...)`.
75+
76+
Lists which `nn.Module` subclasses must stay on a single device (i.e. never have their children placed on different devices).
77+
78+
- **`None` (default)**`from_pretrained(..., device_map="auto")` raises `ValueError` (`modeling_utils.py:1863`).
79+
- **`[]`** — split anywhere you like.
80+
- **`["MyBlock"]`** — keep all `MyBlock` instances intact on one device.
81+
82+
**Why it's needed.** When `accelerate` splits a model across devices, it installs hooks on leaf modules that move inputs to the module's device before `forward` runs. Any inline operation (`+`, `*`, `torch.cat`) that combines tensors from different submodules has no hook — if those submodules landed on different devices, it crashes with "tensors on different devices". The fix is either: (a) list the parent module in `_no_split_modules` so all its children stay co-located, or (b) pack the operation into its own `nn.Module`. Inline ops on outputs from the **same** submodule call are fine since they're already on the same device.
83+
When deciding which modules to list, inspect `forward` methods at every level of the module tree — not just the top-level model, but also its submodules recursively. Any module with inline ops combining tensors from different children or stored parameters needs to be listed.
84+
85+
Every transformer in the repo declares it — new transformers should too. It's cheap and prevents a confusing error when users try `device_map="auto"`.
86+
87+
```python
88+
_no_split_modules = ["MyModelTransformerBlock"]
89+
```
90+
91+
### `_repeated_blocks`
92+
93+
**API:** `model.compile_repeated_blocks(*args, **kwargs)` — walks all submodules, compiles each one whose `__class__.__name__` matches an entry in this list (`modeling_utils.py:1552`). Arguments are forwarded to `torch.compile`.
94+
95+
Lists the class names of the repeated sub-modules (e.g. transformer blocks) for regional compilation instead of compiling the entire model. Must match the class `__name__` exactly.
96+
97+
```python
98+
# Flux: two block types
99+
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
100+
# Wan: one block type
101+
_repeated_blocks = ["WanTransformerBlock"]
102+
```
103+
104+
Typically these are the layers that run many times (e.g. the transformer blocks in the denoising loop), since those benefit most from compilation. If empty or not set, `compile_repeated_blocks()` raises `ValueError`.
105+
106+
### `_skip_layerwise_casting_patterns`
107+
108+
**API:** `model.enable_layerwise_casting(storage_dtype=..., compute_dtype=...)` — applies hooks that store weights in a low-precision dtype and cast to compute dtype on each forward. Modules matching these patterns are skipped (`modeling_utils.py:435`).
109+
110+
List of regex/substring patterns matching module names that should **stay in full precision**. Typically precision-sensitive layers: patch embeddings, positional embeddings, normalization layers.
111+
112+
```python
113+
# Common pattern — skip embeddings and norms:
114+
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
115+
# Flux pattern:
116+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
117+
```
118+
119+
If `None`, no modules are skipped (everything gets cast). Modules in `_keep_in_fp32_modules` are also skipped automatically.
120+
121+
### `_keep_in_fp32_modules`
122+
123+
**API:** `Model.from_pretrained(..., torch_dtype=torch.bfloat16)` — during loading, modules matching these patterns are kept in `float32` even when the rest of the model is cast to the requested dtype (`modeling_utils.py:1160`). Also respected by `enable_layerwise_casting()`.
124+
125+
List of module name patterns for modules that are numerically unstable in lower precision — timestep embeddings, scale/shift tables, normalization parameters.
126+
127+
```python
128+
# Wan pattern:
129+
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
130+
```
131+
132+
If `None` (default), all modules follow the requested `torch_dtype`.
133+
134+
### `_cp_plan`
135+
136+
**API:** `model.enable_parallelism(config=parallel_config)` — when the config includes `context_parallel_config`, this plan is used by `apply_context_parallel()` to shard tensors across GPUs for sequence parallelism (`modeling_utils.py:1665`).
137+
138+
Dict describing how to partition the model's tensors for context parallelism. Maps parameter/activation names to their sharding strategy.
139+
140+
```python
141+
# Minimal example (see transformer_flux.py, transformer_wan.py for full plans):
142+
_cp_plan = {
143+
"": { ... }, # default sharding for unnamed tensors
144+
"rope": { ... }, # RoPE-specific sharding
145+
}
146+
```
147+
148+
If `None` (default), `enable_parallelism()` with `context_parallel_config` raises `ValueError` unless a `cp_plan` is passed explicitly as an argument. To derive a plan for a new model, study the mechanism in `hooks/context_parallel.py` and `_modeling_parallel.py`, compare existing plans in `transformer_flux.py` and `transformer_wan.py`, then test and adjust — correct plans depend on the model's data flow and require validation.
149+
150+
### `_supports_gradient_checkpointing`
151+
152+
**API:** `model.enable_gradient_checkpointing()` — walks submodules for a `gradient_checkpointing` attribute, flips it to `True`, and sets `_gradient_checkpointing_func` (`modeling_utils.py:285`).
153+
154+
Boolean gate. If `False` (default), calling that method raises `ValueError`. All transformers in the repo support this. To add support, just: (1) set the class attribute to `True`, (2) add `self.gradient_checkpointing = False` in `__init__`, (3) add `if torch.is_grad_enabled() and self.gradient_checkpointing:` branches in `forward` that call `self._gradient_checkpointing_func`. See gotcha #4.
155+
68156
## Gotchas
69157

70158
1. **Forgetting to register imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports — both the sub-package `__init__.py` and the top-level `src/diffusers/__init__.py` (which has `_import_structure` and `_lazy_modules`). Missing either causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
71159

72160
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
73161

74-
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
75-
76-
4. **Capability flags without matching implementation.** `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward — wasting memory rather than saving it, and masking the bug behind a successful run. `_no_split_modules` similarly needs to name the actual block classes that must stay on one device, or `device_map` placement causes silent correctness bugs / OOM. Copy from a similar model and verify the corresponding logic is in place; for inference-only ports just drop the flag.
77162

78-
5. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.
163+
3. **Capability flags without matching implementation.** for example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward.
164+
4. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.
79165

80-
6. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
166+
5. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
81167
- **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
82168
- **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
83169
```python

.ai/review-rules.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,12 @@ Before reviewing, read and apply the guidelines in:
1515
Common mistakes are covered in the common-mistakes / gotcha sections in [AGENTS.md](AGENTS.md), [models.md](models.md), [pipelines.md](pipelines.md), and [modular.md](modular.md). Additionally, watch for below patterns that aren't covered there:
1616

1717
- **Ephemeral context.** Comments, docstrings, and files that only made sense to the current PR's author or reviewer don't help a future reader/user/developer. Examples: `# per reviewer comment on PR #NNNN`, `# as discussed in review`, `# TODO from offline chat`, debug printouts. Same for files: parity harnesses, comparison scripts, anything in `scripts/` with hardcoded developer paths or imports from the reference repo. State the *reason* so the comment stands alone, or drop it.
18+
19+
## Dead code analysis (new models)
20+
21+
When reviewing a PR that adds a new model, trace how the model is actually called from the pipeline to identify likely dead code. Include the results as a **suggestions / additional info** section in your review (not as blocking comments — the findings are advisory).
22+
23+
1. **Trace the call path.** Read the pipeline's `__call__` and follow every call into the model — which arguments are passed, which branches are taken, which helper methods are invoked.
24+
2. **Check the default model config.** Look at the default config values in the model's `__init__` (or any published config JSON). Identify code paths that are unreachable under those defaults — e.g. an `if self.config.use_foo:` branch where `use_foo` defaults to `False` and no published checkpoint sets it to `True`.
25+
3. **Flag unused parameters and methods.** Parameters declared in `forward` (or helper methods) but never passed by the pipeline, private methods never called, layers initialized but never used in `forward`.
26+
4. **Qualify findings.** The actual model config can differ from the defaults, so any dead code identified this way is *likely* dead — not certain. Frame findings accordingly: "Under the default config and the pipeline's call path, this code appears unreachable." The PR author may know of configs or use cases that exercise the path.

0 commit comments

Comments
 (0)