Skip to content

Commit 6cf3b8e

Browse files
Copilot authored and anxiangsir committed
Replace device=dev with .cuda() for single GPU visibility
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent ba255bd commit 6cf3b8e

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

training/train.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -584,11 +584,11 @@ def wrap_ddp(model):
584584
list_batch_sizes.append(head_input.size(0))
585585
visible_indices = list_data_batch[head_id]["video_visible_indices"].long()
586586

587-
bs, dev = visible_indices.shape[0], visible_indices.device
587+
bs = visible_indices.shape[0]
588588
out = visible_indices[:, :args.target_num].clone()
589589
n1, n2 = int(bs * 0.5), int(bs * 0.875)
590590

591-
idx_range = torch.arange(bs, device=dev)
591+
idx_range = torch.arange(bs).cuda()
592592
mask_residual = idx_range < n1
593593
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2)
594594
mask_collage = idx_range >= n2
@@ -604,8 +604,8 @@ def wrap_ddp(model):
604604
FRAMES = 64
605605
if mask_frame_sampling.any():
606606
nB = mask_frame_sampling.sum().item()
607-
frames = torch.arange(args.actual_num_frames, device=dev) * (FRAMES // args.actual_num_frames) + torch.randint(FRAMES // args.actual_num_frames, (nB, args.actual_num_frames), device=dev)
608-
sel_b = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame, device=dev)).reshape(nB, -1)
607+
frames = torch.arange(args.actual_num_frames).cuda() * (FRAMES // args.actual_num_frames) + torch.randint(FRAMES // args.actual_num_frames, (nB, args.actual_num_frames)).cuda()
608+
sel_b = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame).cuda()).reshape(nB, -1)
609609
if sel_b.size(1) > args.target_num:
610610
sel_b = sel_b[:, :args.target_num]
611611
elif sel_b.size(1) < args.target_num:
@@ -662,10 +662,10 @@ def wrap_ddp(model):
662662
Hf = head_subset.size(3)
663663
Wf = head_subset.size(4)
664664
avg = FRAMES // args.actual_num_frames # 8
665-
base = torch.arange(args.actual_num_frames, device=dev) * avg
666-
offs = torch.randint(avg, (nC, args.actual_num_frames), device=dev)
665+
base = torch.arange(args.actual_num_frames).cuda() * avg
666+
offs = torch.randint(avg, (nC, args.actual_num_frames)).cuda()
667667
frames_idx = (base.unsqueeze(0) + offs).long().clamp(max=FRAMES - 1) # [nC, actual_num_frames], 范围在 [0, 63]
668-
idx_expand = frames_idx.view(nC, 1, args.actual_num_frames, 1, 1).expand(-1, Cf, -1, Hf, Wf).to(head_subset.device)
668+
idx_expand = frames_idx.view(nC, 1, args.actual_num_frames, 1, 1).expand(-1, Cf, -1, Hf, Wf)
669669
sel_frames = torch.gather(head_subset, 2, idx_expand) # [nC, Cf, actual_num_frames, Hf, Wf]
670670
sel_frames = sel_frames.permute(0, 2, 1, 3, 4) # [nC, actual_num_frames, Cf, Hf, Wf]
671671
grid_rows = [sel_frames[:, i, :, :, :] for i in range(args.actual_num_frames)]
@@ -680,7 +680,7 @@ def wrap_ddp(model):
680680

681681
D = combined_head_output.size(1)
682682

683-
head_embedding_full = torch.zeros(bs, D, device=dev, dtype=torch.float32)
683+
head_embedding_full = torch.zeros(bs, D, dtype=torch.float32).cuda()
684684
if combined_mask.any():
685685
head_embedding_full[combined_idx] = combined_head_output
686686
if mask_collage.any():

0 commit comments

Comments (0)