Skip to content

Commit 65aff37

Browse files
authored
fix device mismatch issue for HiDreamTransformerTests (#13766)
* fix device mismatch issue for HiDreamTransformerTests Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * refine code Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> --------- Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 907c0c2 commit 65aff37

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -852,10 +852,16 @@ def forward(
852852

853853
# 2. Blocks
854854
block_id = 0
855-
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
855+
initial_encoder_hidden_states = torch.cat(
856+
[
857+
encoder_hidden_states[-1].to(hidden_states.device),
858+
encoder_hidden_states[-2].to(hidden_states.device),
859+
],
860+
dim=1,
861+
)
856862
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
857863
for bid, block in enumerate(self.double_stream_blocks):
858-
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
864+
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].to(hidden_states.device)
859865
cur_encoder_hidden_states = torch.cat(
860866
[initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
861867
)
@@ -891,7 +897,7 @@ def forward(
891897
hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
892898

893899
for bid, block in enumerate(self.single_stream_blocks):
894-
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
900+
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].to(hidden_states.device)
895901
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
896902
if torch.is_grad_enabled() and self.gradient_checkpointing:
897903
hidden_states = self._gradient_checkpointing_func(

0 commit comments

Comments
 (0)