@@ -852,10 +852,16 @@ def forward(
852852
853853 # 2. Blocks
854854 block_id = 0
855- initial_encoder_hidden_states = torch .cat ([encoder_hidden_states [- 1 ], encoder_hidden_states [- 2 ]], dim = 1 )
855+ initial_encoder_hidden_states = torch .cat (
856+ [
857+ encoder_hidden_states [- 1 ].to (hidden_states .device ),
858+ encoder_hidden_states [- 2 ].to (hidden_states .device ),
859+ ],
860+ dim = 1 ,
861+ )
856862 initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states .shape [1 ]
857863 for bid , block in enumerate (self .double_stream_blocks ):
858- cur_llama31_encoder_hidden_states = encoder_hidden_states [block_id ]
864+ cur_llama31_encoder_hidden_states = encoder_hidden_states [block_id ]. to ( hidden_states . device )
859865 cur_encoder_hidden_states = torch .cat (
860866 [initial_encoder_hidden_states , cur_llama31_encoder_hidden_states ], dim = 1
861867 )
@@ -891,7 +897,7 @@ def forward(
891897 hidden_states_masks = torch .cat ([hidden_states_masks , encoder_attention_mask_ones ], dim = 1 )
892898
893899 for bid , block in enumerate (self .single_stream_blocks ):
894- cur_llama31_encoder_hidden_states = encoder_hidden_states [block_id ]
900+ cur_llama31_encoder_hidden_states = encoder_hidden_states [block_id ]. to ( hidden_states . device )
895901 hidden_states = torch .cat ([hidden_states , cur_llama31_encoder_hidden_states ], dim = 1 )
896902 if torch .is_grad_enabled () and self .gradient_checkpointing :
897903 hidden_states = self ._gradient_checkpointing_func (
0 commit comments