Skip to content

Commit 3bb9c7c

Browse files
committed
fix: apr_02 beta
1 parent 28a5086 commit 3bb9c7c

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
961961
pos = ids.float()
962962
is_mps = ids.device.type == "mps"
963963
is_npu = ids.device.type == "npu"
964-
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
964+
is_neuron = ids.device.type == "neuron"
965+
freqs_dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64
965966
# Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
966967
for i in range(len(self.axes_dim)):
967968
cos, sin = get_1d_rotary_pos_embed(

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
deprecate,
3030
is_bs4_available,
3131
is_ftfy_available,
32+
is_torch_neuronx_available,
3233
is_torch_xla_available,
3334
logging,
3435
replace_example_docstring,
@@ -862,7 +863,7 @@ def __call__(
862863
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
863864

864865
# 4. Prepare timesteps
865-
if XLA_AVAILABLE:
866+
if XLA_AVAILABLE or is_torch_neuronx_available():
866867
timestep_device = "cpu"
867868
else:
868869
timestep_device = device
@@ -914,10 +915,11 @@ def __call__(
914915
# This would be a good case for the `match` statement (Python 3.10+)
915916
is_mps = latent_model_input.device.type == "mps"
916917
is_npu = latent_model_input.device.type == "npu"
918+
is_neuron = latent_model_input.device.type == "neuron"
917919
if isinstance(current_timestep, float):
918-
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
920+
dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64
919921
else:
920-
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
922+
dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64
921923
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
922924
elif len(current_timestep.shape) == 0:
923925
current_timestep = current_timestep[None].to(latent_model_input.device)

0 commit comments

Comments (0)