Commit 1d61993

JingyaHuang, claude, and yiyixuxu authored
[Neuron] Add AWS Neuron (Trainium/Inferentia) as an officially supported device (#13289)
* draft:add neuron as a legit backend
* feat: neuron-specific changes in the pipeline
* tests: eager tests
* fix: style
* fix:apr_02 beta
* cleanup: remove tp part, for another pr
* fix: restore ring_anything to ContextParallelConfig after over-aggressive TP cleanup

  The previous cleanup commit removed TensorParallelConfig (correct) but also accidentally removed ring_anything from ContextParallelConfig (incorrect). ring_anything is a context-parallel feature referenced in context_parallel.py and must remain in the config.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* removal: style fix
* tests:sdxl + flux2
* tests: simplify
* tests: simplify
* fix style
* fix style
* fix style for doc-builder
* review: address comments
* review:apply suggestion for the fix of index_for_timtestep
* review: stronger guard on image slice
* Apply suggestions from code review

  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* fix: when set_begin_index not implemented for scheduler
* review:add maybe_adjust_dtype_for_device and apply to all models with downcasting needs
* fix: dependency
* Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* review: apply maybe_adjust_dtype_for_device in pixart pipe
* review: update .ai/models.md

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 7c12518 commit 1d61993

32 files changed

Lines changed: 309 additions & 165 deletions

.ai/models.md

Lines changed: 6 additions & 6 deletions
@@ -163,14 +163,14 @@ Boolean gate. If `False` (default), calling that method raises `ValueError`. All
 3. **Capability flags without matching implementation.** for example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward.
 4. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.

-5. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
+5. **`torch.float64` anywhere in the model.** MPS, NPU, and Neuron backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
    - **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
-   - **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
+   - **Only if float32 visibly degrades output, use the `maybe_adjust_dtype_for_device` helper** from `diffusers.utils.torch_utils`. It centralizes the device-specific dtype downcast (float64→float32, int64→int32) for all restricted backends (mps, npu, neuron):
      ```python
-     is_mps = hidden_states.device.type == "mps"
-     is_npu = hidden_states.device.type == "npu"
-     freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+     from diffusers.utils.torch_utils import maybe_adjust_dtype_for_device
+
+     freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)
      ```
-     See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model.
+     See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py`, and `pipeline_pixart_alpha.py` for reference usages. Never leave an unconditional `torch.float64` in the model.

 6. **Using `torch.empty`.** - Do not use `torch.empty` to initialize parameters. Use `torch.zeros` or `torch.ones`, instead.
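
The downcast rule that item 5 attributes to `maybe_adjust_dtype_for_device` can be illustrated with a minimal sketch. This is not the code from `diffusers.utils.torch_utils`: dtypes and device types are modelled as plain strings so the block runs without torch, and the names `RESTRICTED_DEVICE_TYPES`, `DOWNCAST`, and `adjust_dtype_for_device` are hypothetical. Only the mapping itself (float64→float32, int64→int32 on mps, npu, neuron) comes from the text above.

```python
# Hypothetical stand-in for the helper described in .ai/models.md.
# Device types and dtypes are plain strings here; the real helper
# presumably works with torch.device and torch.dtype objects.
RESTRICTED_DEVICE_TYPES = {"mps", "npu", "neuron"}
DOWNCAST = {"float64": "float32", "int64": "int32"}


def adjust_dtype_for_device(dtype: str, device_type: str) -> str:
    """Return a 32-bit dtype on restricted backends, else the dtype unchanged."""
    if device_type in RESTRICTED_DEVICE_TYPES:
        # Dtypes without a 64-bit restriction (e.g. "bfloat16") pass through.
        return DOWNCAST.get(dtype, dtype)
    return dtype
```

Keeping the rule in one table means a newly restricted backend needs one extra set entry rather than another `is_xxx` flag at every call site, which appears to be the motivation for the refactor in this commit.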

src/diffusers/models/controlnets/controlnet.py

Lines changed: 4 additions & 6 deletions
@@ -22,6 +22,7 @@
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import BaseOutput, apply_lora_scale, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -675,12 +676,9 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            is_npu = sample.device.type == "npu"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-            else:
-                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            dtype = maybe_adjust_dtype_for_device(
+                torch.float64 if isinstance(timestep, float) else torch.int64, sample.device
+            )
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
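
One way to read this hunk is as a behaviour-preserving refactor for existing backends plus one new case. The sketch below compares the removed branch with the added call across a few device types. It assumes the helper downcasts on mps, npu, and neuron as stated in the `.ai/models.md` change; dtypes are strings so the block runs without torch, and the function names are hypothetical.

```python
from itertools import product

# Assumption: the set of restricted backends named in the commit's docs change.
RESTRICTED = {"mps", "npu", "neuron"}
DOWNCAST = {"float64": "float32", "int64": "int32"}


def old_timestep_dtype(timestep, device_type: str) -> str:
    """Mirror of the removed branch: only mps and npu were special-cased."""
    restricted = device_type in ("mps", "npu")
    if isinstance(timestep, float):
        return "float32" if restricted else "float64"
    return "int32" if restricted else "int64"


def new_timestep_dtype(timestep, device_type: str) -> str:
    """Mirror of the added call, with a string-based stand-in for the helper."""
    requested = "float64" if isinstance(timestep, float) else "int64"
    return DOWNCAST[requested] if device_type in RESTRICTED else requested


def differing_cases():
    """List (timestep, device_type) pairs where old and new logic disagree."""
    cases = product([0.5, 10], ["cpu", "cuda", "mps", "npu", "neuron"])
    return [(t, d) for t, d in cases if old_timestep_dtype(t, d) != new_timestep_dtype(t, d)]
```

Under these assumptions the two versions disagree only on `neuron`, where the new path yields 32-bit dtypes.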

src/diffusers/models/controlnets/controlnet_sparsectrl.py

Lines changed: 4 additions & 6 deletions
@@ -22,6 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
 from ...utils import BaseOutput, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -604,12 +605,9 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            is_npu = sample.device.type == "npu"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-            else:
-                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            dtype = maybe_adjust_dtype_for_device(
+                torch.float64 if isinstance(timestep, float) else torch.int64, sample.device
+            )
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 4 additions & 6 deletions
@@ -19,6 +19,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -620,12 +621,9 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            is_npu = sample.device.type == "npu"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-            else:
-                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            dtype = maybe_adjust_dtype_for_device(
+                torch.float64 if isinstance(timestep, float) else torch.int64, sample.device
+            )
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_xs.py

Lines changed: 4 additions & 7 deletions
@@ -20,7 +20,7 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import BaseOutput, logging
-from ...utils.torch_utils import apply_freeu
+from ...utils.torch_utils import apply_freeu, maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -1014,12 +1014,9 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            is_npu = sample.device.type == "npu"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-            else:
-                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            dtype = maybe_adjust_dtype_for_device(
+                torch.float64 if isinstance(timestep, float) else torch.int64, sample.device
+            )
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from torch import nn

 from ..utils import deprecate
+from ..utils.torch_utils import maybe_adjust_dtype_for_device
 from .activations import FP32SiLU, get_activation
 from .attention_processor import Attention

@@ -346,7 +347,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin

     # Auto-detect appropriate dtype if not specified
     if dtype is None:
-        dtype = torch.float32 if pos.device.type == "mps" else torch.float64
+        dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device)

     omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
     omega /= embed_dim / 2.0
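
To get a feel for what the float32 fallback costs in this function, the sketch below approximately emulates single precision in pure Python by rounding every intermediate through `struct`. It assumes the standard sinusoidal frequency `1 / 10000**omega` follows the two `omega` lines shown in the hunk (that step is not visible in the diff), and the helper names `to_f32`, `sincos_angles`, and `max_sin_gap` are hypothetical.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sincos_angles(embed_dim: int, pos: float, single_precision: bool) -> list:
    """Angles pos * omega_i, optionally rounding each step to float32."""
    rnd = to_f32 if single_precision else (lambda v: v)
    angles = []
    for i in range(embed_dim // 2):
        omega = rnd(i / (embed_dim / 2.0))        # omega /= embed_dim / 2.0
        freq = rnd(1.0 / rnd(10000.0 ** omega))   # assumed standard sinusoidal frequency
        angles.append(rnd(pos * freq))
    return angles


def max_sin_gap(embed_dim: int, pos: float) -> float:
    """Largest |sin| difference between the float64 and emulated float32 paths."""
    a64 = sincos_angles(embed_dim, pos, single_precision=False)
    a32 = sincos_angles(embed_dim, pos, single_precision=True)
    return max(abs(math.sin(x) - math.sin(y)) for x, y in zip(a64, a32))
```

For `embed_dim=64` and a position of 1000, the gap under this emulation stays below `1e-3`, which is consistent with the `.ai/models.md` guidance that float32 is usually sufficient for inference.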

src/diffusers/models/transformers/transformer_anyflow.py

Lines changed: 3 additions & 6 deletions
@@ -28,6 +28,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import apply_lora_scale, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
@@ -41,9 +42,7 @@

 def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
     # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices.
-    is_mps = hidden_states.device.type == "mps"
-    is_npu = hidden_states.device.type == "npu"
-    rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+    rotary_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)
     x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2)))
     x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
     return x_out.type_as(hidden_states)
@@ -341,9 +340,7 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor:
         if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
             return self._freqs_cache[1]

-        is_mps = device.type == "mps"
-        is_npu = device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, device)

         h_dim = w_dim = 2 * (self.attention_head_dim // 6)
         t_dim = self.attention_head_dim - h_dim - w_dim
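
The `apply_rotary_emb` hunk above pairs up the last dimension, views each pair as a complex number, multiplies by `freqs`, and flattens back. A pure-Python sketch of that idea follows. It assumes each entry of `freqs` is a unit complex number `exp(i*theta)`, which the diff does not show, and `rotate_pairs` is a hypothetical name rather than anything in the repository.

```python
import cmath


def rotate_pairs(values, angles):
    """Rotate consecutive (even, odd) pairs of `values` by the matching angle."""
    if len(values) != 2 * len(angles):
        raise ValueError("expected two values per angle")
    out = []
    for k, theta in enumerate(angles):
        z = complex(values[2 * k], values[2 * k + 1])  # analogue of view_as_complex
        z *= cmath.exp(1j * theta)                     # multiply by a unit-modulus frequency
        out.extend([z.real, z.imag])                   # analogue of view_as_real + flatten
    return out
```

Because the multiplier has modulus 1, each pair keeps its length and only its direction changes; the dtype chosen by the helper controls the precision of that complex multiply (complex64 instead of complex128 when float32 is used).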

src/diffusers/models/transformers/transformer_anyflow_far.py

Lines changed: 3 additions & 6 deletions
@@ -30,6 +30,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import BaseOutput, apply_lora_scale, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
@@ -44,9 +45,7 @@
 # Copied from diffusers.models.transformers.transformer_anyflow.apply_rotary_emb
 def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
     # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices.
-    is_mps = hidden_states.device.type == "mps"
-    is_npu = hidden_states.device.type == "npu"
-    rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+    rotary_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)
     x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2)))
     x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
     return x_out.type_as(hidden_states)
@@ -650,9 +649,7 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor:
         if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
             return self._freqs_cache[1]

-        is_mps = device.type == "mps"
-        is_npu = device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, device)

         h_dim = w_dim = 2 * (self.attention_head_dim // 6)
         t_dim = self.attention_head_dim - h_dim - w_dim
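
The two context lines at the end of the `_build_freqs` hunk split the attention head dimension into three parts. The worked example below reproduces that arithmetic as a standalone function. `split_head_dim` is a hypothetical name, and reading `t`, `h`, `w` as temporal, height, and width axes is an assumption based on the variable names alone.

```python
def split_head_dim(attention_head_dim: int):
    """Split a head dimension into (t_dim, h_dim, w_dim) as in the hunk above."""
    # h and w each get an even share of roughly one third of the head dim.
    h_dim = w_dim = 2 * (attention_head_dim // 6)
    # t takes whatever remains, so the three parts always sum to the head dim.
    t_dim = attention_head_dim - h_dim - w_dim
    return t_dim, h_dim, w_dim
```

With a head dimension of 128, `128 // 6` is 21, so `h_dim = w_dim = 42` and `t_dim = 128 - 84 = 44`. All three parts are even whenever the head dimension is even, which is what pairing values for a rotary embedding needs.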

src/diffusers/models/transformers/transformer_bria.py

Lines changed: 3 additions & 5 deletions
@@ -9,7 +9,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import apply_lora_scale, logging
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
@@ -276,8 +276,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
@@ -344,8 +343,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
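
The loop in these hunks builds rotary cos/sin tables one positional axis at a time and collects them in `cos_out` / `sin_out`. A torch-free sketch of that pattern is given below. The argument list of the real `get_1d_rotary_pos_embed` call is truncated in the diff and is not reproduced here; `rotary_cos_sin_1d`, `multi_axis_cos_sin`, the base `theta=10000.0`, and the example axis sizes are all assumptions for illustration.

```python
import math


def rotary_cos_sin_1d(dim: int, pos: float, theta: float = 10000.0):
    """cos/sin of pos * theta**(-2k/dim) for k in range(dim // 2) (assumed formula)."""
    freqs = [1.0 / theta ** (2 * k / dim) for k in range(dim // 2)]
    return [math.cos(pos * f) for f in freqs], [math.sin(pos * f) for f in freqs]


def multi_axis_cos_sin(axes_dim, ids):
    """Concatenate per-axis tables, one axis size and one position id per axis."""
    cos_out, sin_out = [], []
    for dim, pos in zip(axes_dim, ids):
        cos, sin = rotary_cos_sin_1d(dim, float(pos))
        cos_out.extend(cos)
        sin_out.extend(sin)
    return cos_out, sin_out
```

The angles `pos * f` are where precision matters: large position ids multiplied by small frequencies are the values the 64-bit default protects, and the helper swaps that default for float32 on backends that cannot run float64.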

src/diffusers/models/transformers/transformer_bria_fibo.py

Lines changed: 2 additions & 3 deletions
@@ -25,7 +25,7 @@
     apply_lora_scale,
     logging,
 )
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -222,8 +222,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
