@@ -584,11 +584,11 @@ def wrap_ddp(model):
584584 list_batch_sizes .append (head_input .size (0 ))
585585 visible_indices = list_data_batch [head_id ]["video_visible_indices" ].long ()
586586
587- bs , dev = visible_indices .shape [0 ], visible_indices . device
587+ bs = visible_indices .shape [0 ]
588588 out = visible_indices [:, :args .target_num ].clone ()
589589 n1 , n2 = int (bs * 0.5 ), int (bs * 0.875 )
590590
591- idx_range = torch .arange (bs , device = dev )
591+ idx_range = torch .arange (bs ). cuda ( )
592592 mask_residual = idx_range < n1
593593 mask_frame_sampling = (idx_range >= n1 ) & (idx_range < n2 )
594594 mask_collage = idx_range >= n2
@@ -604,8 +604,8 @@ def wrap_ddp(model):
604604 FRAMES = 64
605605 if mask_frame_sampling .any ():
606606 nB = mask_frame_sampling .sum ().item ()
607- frames = torch .arange (args .actual_num_frames , device = dev ) * (FRAMES // args .actual_num_frames ) + torch .randint (FRAMES // args .actual_num_frames , (nB , args .actual_num_frames ), device = dev )
608- sel_b = (frames .unsqueeze (- 1 ) * args .num_tokens_per_frame + torch .arange (args .num_tokens_per_frame , device = dev )).reshape (nB , - 1 )
607+ frames = torch .arange (args .actual_num_frames ). cuda () * (FRAMES // args .actual_num_frames ) + torch .randint (FRAMES // args .actual_num_frames , (nB , args .actual_num_frames )). cuda ( )
608+ sel_b = (frames .unsqueeze (- 1 ) * args .num_tokens_per_frame + torch .arange (args .num_tokens_per_frame ). cuda ( )).reshape (nB , - 1 )
609609 if sel_b .size (1 ) > args .target_num :
610610 sel_b = sel_b [:, :args .target_num ]
611611 elif sel_b .size (1 ) < args .target_num :
@@ -662,10 +662,10 @@ def wrap_ddp(model):
662662 Hf = head_subset .size (3 )
663663 Wf = head_subset .size (4 )
664664 avg = FRAMES // args .actual_num_frames # 8
665- base = torch .arange (args .actual_num_frames , device = dev ) * avg
666- offs = torch .randint (avg , (nC , args .actual_num_frames ), device = dev )
665+ base = torch .arange (args .actual_num_frames ). cuda ( ) * avg
666+ offs = torch .randint (avg , (nC , args .actual_num_frames )). cuda ( )
667667 frames_idx = (base .unsqueeze (0 ) + offs ).long ().clamp (max = FRAMES - 1 ) # [nC, actual_num_frames], 范围在 [0, 63]
668- idx_expand = frames_idx .view (nC , 1 , args .actual_num_frames , 1 , 1 ).expand (- 1 , Cf , - 1 , Hf , Wf ). to ( head_subset . device )
668+ idx_expand = frames_idx .view (nC , 1 , args .actual_num_frames , 1 , 1 ).expand (- 1 , Cf , - 1 , Hf , Wf )
669669 sel_frames = torch .gather (head_subset , 2 , idx_expand ) # [nC, Cf, actual_num_frames, Hf, Wf]
670670 sel_frames = sel_frames .permute (0 , 2 , 1 , 3 , 4 ) # [nC, actual_num_frames, Cf, Hf, Wf]
671671 grid_rows = [sel_frames [:, i , :, :, :] for i in range (args .actual_num_frames )]
@@ -680,7 +680,7 @@ def wrap_ddp(model):
680680
681681 D = combined_head_output .size (1 )
682682
683- head_embedding_full = torch .zeros (bs , D , device = dev , dtype = torch .float32 )
683+ head_embedding_full = torch .zeros (bs , D , dtype = torch .float32 ). cuda ( )
684684 if combined_mask .any ():
685685 head_embedding_full [combined_idx ] = combined_head_output
686686 if mask_collage .any ():
0 commit comments