You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: .ai/models.md
+6-6Lines changed: 6 additions & 6 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -163,14 +163,14 @@ Boolean gate. If `False` (default), calling that method raises `ValueError`. All
163
163
3.**Capability flags without matching implementation.** for example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward.
164
164
4.**Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.
165
165
166
-
5.**`torch.float64` anywhere in the model.** MPSand several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
166
+
5.**`torch.float64` anywhere in the model.** MPS, NPU, and Neuron backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
167
167
-**Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
168
-
-**Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
168
+
-**Only if float32 visibly degrades output, use the `maybe_adjust_dtype_for_device` helper** from `diffusers.utils.torch_utils`. It centralizes the device-specific dtype downcast (float64→float32, int64→int32) for all restricted backends (mps, npu, neuron):
169
169
```python
170
-
is_mps = hidden_states.device.type =="mps"
171
-
is_npu = hidden_states.device.type =="npu"
172
-
freqs_dtype =torch.float32 if (is_mps or is_npu) elsetorch.float64
170
+
from diffusers.utils.torch_utils import maybe_adjust_dtype_for_device
See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py`for reference usages. Never leave an unconditional `torch.float64`in the model.
174
+
See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py`, and`pipeline_pixart_alpha.py`for reference usages. Never leave an unconditional `torch.float64`in the model.
175
175
176
176
6. **Using `torch.empty`.**- Do not use `torch.empty` to initialize parameters. Use `torch.zeros`or`torch.ones`, instead.
0 commit comments