Commit fff6c8c

Merge branch 'main' into fix-bucket-batch-sampler-cache-alignment
2 parents 3cabf56 + 0f1abc4 commit fff6c8c

325 files changed

Lines changed: 18327 additions & 3836 deletions

.ai/AGENTS.md

Lines changed: 9 additions & 9 deletions
@@ -4,10 +4,12 @@
 
 Strive to write code as simple and explicit as possible.
 
-- Minimize small helper/utility functions — inline the logic instead. A reader should be able to follow the full flow without jumping between functions.
-- No defensive code or unused code paths — do not add fallback paths, safety checks, or configuration options "just in case". When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
+- Prefer inlining small helper/utility functions over factoring them out — a reader should be able to follow the full flow without jumping between functions. If a private helper has only one caller, inlining it at the call site is usually the cleaner choice.
+- No defensive code, unused code paths, or legacy stubs — do not add fallback paths, safety checks, or configuration options "just in case"; do not carry unused method parameters "for API consistency", backwards-compatibility aliases for names that never shipped, or deprecation shims for code that was never released. When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
 - Do not guess user intent and silently correct behavior. Make the expected inputs clear in the docstring, and raise a concise error for unsupported cases rather than adding complex fallback logic.
 
+Before opening the PR, self-review against [review-rules.md](review-rules.md), which collects the most common mistakes we catch in review.
+
 ---
 
 ## Code formatting
@@ -27,13 +29,11 @@ Strive to write code as simple and explicit as possible.
 
 ### Pipelines & Schedulers
 
-- Pipelines inherit from `DiffusionPipeline`
-- Schedulers use `SchedulerMixin` with `ConfigMixin`
-- Use `@torch.no_grad()` on pipeline `__call__`
-- Support `output_type="latent"` for skipping VAE decode
-- Support `generator` parameter for reproducibility
-- Use `self.progress_bar(timesteps)` for progress tracking
-- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
+- See [pipelines.md](pipelines.md) for pipeline conventions, patterns, and gotchas.
+
+### Modular Pipelines
+
+- See [modular.md](modular.md) for modular pipeline conventions, patterns, and gotchas.
 
 ## Skills

.ai/models.md

Lines changed: 22 additions & 10 deletions
@@ -11,7 +11,8 @@ Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.
 
 ## Common model conventions
 
-- Models use `ModelMixin` with `register_to_config` for config serialization
+* Models use `ModelMixin` with `register_to_config` for config serialization.
+* When adding a new transformer (or reviewing one), skim `src/diffusers/models/transformers/transformer_flux.py`, `src/diffusers/models/transformers/transformer_flux2.py`, `src/diffusers/models/transformers/transformer_qwenimage.py`, and `src/diffusers/models/transformers/transformer_wan.py` first to establish the pattern. Most conventions (mixin set, file structure, naming, gradient-checkpointing implementation, `_no_split_modules` settings, etc.) are easiest to internalize by comparison rather than from a fixed list.
 
 ## Attention pattern

@@ -55,22 +56,33 @@ class MyModelAttention(nn.Module, AttentionModuleMixin):
         return self.processor(self, hidden_states, attention_mask, **kwargs)
 ```
 
-Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
+### Attention masks
+
+What you pass as `attn_mask=` to `dispatch_attention_fn` determines which backends work:
+
+- **No mask needed → pass `None`, not an all-zero tensor.** A dense 4D additive float mask of all `0.0` does no math but still hard-raises on `flash` / `_flash_3` / `_sage` (see `attention_dispatch.py:2328, 2544, 3266`). Only materialize a mask when it carries information. This is the Flux / Flux2 / Wan pattern: no mask, works on every backend, relies on the model having been trained tolerating consistent padding.
+- **Padding mask → bool `(B, L)` or `(B, 1, 1, L)`.** Stays compatible with the `*_varlen` kernels via `_normalize_attn_mask` (`attention_dispatch.py:639`), which reduces bool masks to `cu_seqlens`. Dense additive-float masks *cannot* be reduced this way and so lose the varlen path. This is the Qwen pattern (`transformer_qwenimage.py:951`).
+- **Structural mask (causal, sliding-window, band-diagonal) → dense `(1, 1, L, L)` is unavoidable.** Row-varying patterns can't be expressed as `(B, L)`. Expect SDPA/Flex-only for these layers; consider Flex's `sliding_window_mask_mod` or FA3's native `window_size=` kwarg if backend flexibility matters. Consult `src/models/transformers/transformer_kandinsky.py` as a reference.
+- **Don't declare `attention_mask` (or `encoder_hidden_states_mask`) in the forward signature if you ignore it.** "For API stability with other transformers" is not a reason; readers assume a declared param is honored, and downstream pipelines will pass padding masks that silently get dropped. Some existing models in the repo carry unused mask params for historical reasons — e.g. `QwenDoubleStreamAttnProcessor2_0.__call__` declares `encoder_hidden_states_mask` but never reads it (the joint mask is routed through `attention_mask` instead), and the block-level forward in `transformer_qwenimage.py` declares it but always receives `None`. This is a legacy behavior and should not be replicated in new models.
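
The first two mask rules above can be illustrated with plain PyTorch. This is a minimal sketch, not the code in `attention_dispatch.py`: `padding_mask_to_cu_seqlens` is a hypothetical helper written for this note, it assumes right-padded sequences, and `F.scaled_dot_product_attention` stands in for whatever backend `dispatch_attention_fn` would select.

```python
import torch
import torch.nn.functional as F


def padding_mask_to_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reduce a bool (B, L) padding mask to cumulative lengths.

    Assumes right-padding, i.e. each row is a run of True followed by a run of False.
    """
    seqlens = mask.sum(dim=-1, dtype=torch.int32)  # (B,) valid tokens per sample
    # Prepend 0 so entry i is the start offset of sample i in the packed sequence.
    return F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))


# A bool padding mask carries enough structure to be reduced to sequence lengths.
mask = torch.tensor([[True, True, True, False], [True, True, False, False]])
cu_seqlens = padding_mask_to_cu_seqlens(mask)

# An all-zero additive mask changes nothing numerically, so passing None is equivalent.
torch.manual_seed(0)
q = torch.randn(2, 4, 4, 8)  # (B, heads, L, head_dim)
k = torch.randn(2, 4, 4, 8)
v = torch.randn(2, 4, 4, 8)
zero_mask = torch.zeros(2, 1, 4, 4)
out_none = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
out_zero = F.scaled_dot_product_attention(q, k, v, attn_mask=zero_mask)
same = torch.allclose(out_none, out_zero, atol=1e-5)
```

A dense additive float mask has no such reduction: the lengths cannot be recovered without assuming which float values mean "masked", which is why the bool form is the one that keeps the varlen path available.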
 
 ## Gotchas
 
-1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
+1. **Forgetting to register imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports — both the sub-package `__init__.py` and the top-level `src/diffusers/__init__.py` (which has `_import_structure` and `_lazy_modules`). Missing either causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
 
 2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
 
 3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
 
-4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
-
-5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
-
-6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
+4. **Capability flags without matching implementation.** `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward — wasting memory rather than saving it, and masking the bug behind a successful run. `_no_split_modules` similarly needs to name the actual block classes that must stay on one device, or `device_map` placement causes silent correctness bugs / OOM. Copy from a similar model and verify the corresponding logic is in place; for inference-only ports just drop the flag.
 
-7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
+5. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.
 
-8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
+6. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
+   - **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
+   - **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
+     ```python
+     is_mps = hidden_states.device.type == "mps"
+     is_npu = hidden_states.device.type == "npu"
+     freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+     ```
+   See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model.
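
Gotchas 5 and 6 combine naturally in RoPE precompute code. The sketch below is only an illustration under stated assumptions: `rope_freqs` is a hypothetical helper (not a function from the repo), the `theta` default is illustrative, and it shows the device-gated fallback rather than the recommended default of plain `torch.float32`.

```python
import torch


def rope_freqs(positions: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: RoPE angles computed with a device-gated precompute dtype."""
    # float64 is unsupported on MPS and several NPU backends, so gate on device type.
    is_mps = positions.device.type == "mps"
    is_npu = positions.device.type == "npu"
    freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
    exponents = torch.arange(0, dim, 2, dtype=freqs_dtype, device=positions.device) / dim
    inv_freq = 1.0 / (theta**exponents)  # (dim // 2,)
    return torch.outer(positions.to(freqs_dtype), inv_freq)  # (L, dim // 2)


hidden_states = torch.randn(1, 6, 8, dtype=torch.bfloat16)
positions = torch.arange(hidden_states.shape[1], device=hidden_states.device)
angles = rope_freqs(positions, dim=8)

# Cast to the input tensor's dtype, never to a hardcoded dtype or a weight's dtype.
cos = angles.cos().to(hidden_states.dtype)
sin = angles.sin().to(hidden_states.dtype)
```

The high-precision dtype stays local to the precompute step; everything that touches activations follows `hidden_states.dtype`, so the same code should work whether the model runs in bfloat16, float16, or float32.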
