File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -591,7 +591,14 @@ def forward(
591591
592592 # 3. Modulation and residual connection
593593 hidden_states = torch .cat ([attn_output , mlp_hidden_states ], dim = 2 )
594- hidden_states = gate * self .proj_out (hidden_states )
594+ print (f"DEBUG: Before proj_out - hidden_states device: { hidden_states .device } " )
595+ print (f"DEBUG: Before proj_out - gate device: { gate .device } " )
596+ # Check proj_out layer's device (assuming it's a nn.Module with parameters)
597+ proj_out_device = next (self .proj_out .parameters ()).device if list (self .proj_out .parameters ()) else "No parameters"
598+ print (f"DEBUG: Before proj_out - self.proj_out device: { proj_out_device } " )
599+ proj_out_result = self .proj_out (hidden_states )
600+ print (f"DEBUG: After proj_out - proj_out_result device: { proj_out_result .device } " )
601+ hidden_states = gate * proj_out_result # Error likely occurs here
595602 hidden_states = hidden_states + residual
596603
597604 hidden_states , encoder_hidden_states = (
You can’t perform that action at this time.
0 commit comments