Skip to content

Commit d28ef91

Browse files
Copilot and anxiangsir committed
Add shape comments with example values (bs=16, target_num=2048)
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent 6cf3b8e commit d28ef91

1 file changed

Lines changed: 65 additions & 59 deletions

File tree

training/train.py

Lines changed: 65 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -580,113 +580,119 @@ def wrap_ddp(model):
580580
list_embedding.append(head_embedding)
581581

582582
elif dataset_config.dali_type in ["decord_residual"]:
583-
head_input = list_data_batch[head_id]["videos"] # [bs, C, 64, H, W]
583+
# Example: bs=16, target_num=2048 (8*256), num_tokens_per_frame=256, H=W=224, patch_size=16
584+
# Hp=Wp=14, patches_per_frame=196, T=64, total_patches=12544
585+
586+
head_input = list_data_batch[head_id]["videos"] # [16, 3, 64, 224, 224]
584587
list_batch_sizes.append(head_input.size(0))
585-
visible_indices = list_data_batch[head_id]["video_visible_indices"].long()
588+
visible_indices = list_data_batch[head_id]["video_visible_indices"].long() # [16, >=2048]
586589

587-
bs = visible_indices.shape[0]
588-
out = visible_indices[:, :args.target_num].clone()
589-
n1, n2 = int(bs * 0.5), int(bs * 0.875)
590+
bs = visible_indices.shape[0] # 16
591+
out = visible_indices[:, :args.target_num].clone() # [16, 2048]
592+
n1, n2 = int(bs * 0.5), int(bs * 0.875) # n1=8, n2=14
590593

591-
idx_range = torch.arange(bs).cuda()
592-
mask_residual = idx_range < n1
593-
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2)
594-
mask_collage = idx_range >= n2
594+
idx_range = torch.arange(bs).cuda() # [16]
595+
mask_residual = idx_range < n1 # [16], first 8 samples are True
596+
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2) # [16], samples 8-13 are True
597+
mask_collage = idx_range >= n2 # [16], samples 14-15 are True
595598

596599
# mask_residual: select first args.target_num patches
597600
if mask_residual.any():
598-
sel_a = visible_indices[mask_residual, :args.target_num]
601+
sel_a = visible_indices[mask_residual, :args.target_num] # [8, 2048]
599602
if sel_a.size(1) < args.target_num:
600-
sel_a = torch.cat([sel_a, sel_a[:, -1:].expand(-1, args.target_num - sel_a.size(1))], dim=1)
601-
out[mask_residual] = sel_a
603+
sel_a = torch.cat([sel_a, sel_a[:, -1:].expand(-1, args.target_num - sel_a.size(1))], dim=1) # [8, 2048]
604+
out[mask_residual] = sel_a # out[0:8] = sel_a
602605

603606
# mask_frame_sampling: sample 8 frames from 64, get all patches per frame
604607
FRAMES = 64
605608
if mask_frame_sampling.any():
606-
nB = mask_frame_sampling.sum().item()
607-
frames = torch.arange(args.actual_num_frames).cuda() * (FRAMES // args.actual_num_frames) + torch.randint(FRAMES // args.actual_num_frames, (nB, args.actual_num_frames)).cuda()
608-
sel_b = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame).cuda()).reshape(nB, -1)
609+
nB = mask_frame_sampling.sum().item() # 6
610+
# frames: sample 1 frame from each of 8 bins (each bin has 8 frames)
611+
frames = torch.arange(args.actual_num_frames).cuda() * (FRAMES // args.actual_num_frames) + torch.randint(FRAMES // args.actual_num_frames, (nB, args.actual_num_frames)).cuda() # [6, 8], values in [0,7], [8,15], ..., [56,63]
612+
# sel_b: for each frame, get all 256 patches
613+
sel_b = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame).cuda()).reshape(nB, -1) # [6, 8*256] = [6, 2048]
609614
if sel_b.size(1) > args.target_num:
610-
sel_b = sel_b[:, :args.target_num]
615+
sel_b = sel_b[:, :args.target_num] # [6, 2048]
611616
elif sel_b.size(1) < args.target_num:
612-
sel_b = torch.cat([sel_b, sel_b[:, -1:].expand(-1, args.target_num - sel_b.size(1))], dim=1)
613-
out[mask_frame_sampling] = sel_b
617+
sel_b = torch.cat([sel_b, sel_b[:, -1:].expand(-1, args.target_num - sel_b.size(1))], dim=1) # [6, 2048]
618+
out[mask_frame_sampling] = sel_b # out[8:14] = sel_b
614619

615-
combined_mask = mask_residual | mask_frame_sampling
620+
combined_mask = mask_residual | mask_frame_sampling # [16], first 14 samples are True
616621
if combined_mask.any():
617-
combined_idx = combined_mask.nonzero(as_tuple=False).squeeze(1)
618-
video = head_input[combined_idx] # [n, C, T, H, W]
619-
vis_idx = out[combined_idx] # [n, target_num]
622+
combined_idx = combined_mask.nonzero(as_tuple=False).squeeze(1) # [14]
623+
video = head_input[combined_idx] # [14, 3, 64, 224, 224]
624+
vis_idx = out[combined_idx] # [14, 2048]
620625

621-
n, C, T, H, W = video.shape
622-
Hp, Wp = H // patch_size, W // patch_size
626+
n, C, T, H, W = video.shape # n=14, C=3, T=64, H=224, W=224
627+
Hp, Wp = H // patch_size, W // patch_size # Hp=14, Wp=14
623628

624629
# Patchify: [n, C, T, H, W] -> [n, C, T*Hp*Wp, p, p]
625-
patches = video.view(n, C, T, Hp, patch_size, Wp, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T * Hp * Wp, patch_size, patch_size)
630+
# [14, 3, 64, 224, 224] -> [14, 3, 64, 14, 16, 14, 16] -> [14, 3, 64, 14, 14, 16, 16] -> [14, 3, 12544, 16, 16]
631+
patches = video.view(n, C, T, Hp, patch_size, Wp, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T * Hp * Wp, patch_size, patch_size) # [14, 3, 12544, 16, 16]
626632

627633
# Select patches by vis_idx
628-
idx = vis_idx[:, None, :, None, None].expand(-1, C, -1, patch_size, patch_size)
629-
selected = torch.gather(patches, 2, idx) # [n, C, target_num, p, p]
634+
idx = vis_idx[:, None, :, None, None].expand(-1, C, -1, patch_size, patch_size) # [14, 3, 2048, 16, 16]
635+
selected = torch.gather(patches, 2, idx) # [14, 3, 2048, 16, 16]
630636

631637
# Unpatchify: [n, C, target_num, p, p] -> [n, C, T', H, W]
632-
T_new = args.target_num // (Hp * Wp)
638+
T_new = args.target_num // (Hp * Wp) # 2048 // 196 = 10
633639
if T_new == 0:
634640
T_new = 1
635-
num_patches = T_new * Hp * Wp
641+
num_patches = T_new * Hp * Wp # 10 * 196 = 1960
636642
if selected.size(2) > num_patches:
637-
selected = selected[:, :, :num_patches]
643+
selected = selected[:, :, :num_patches] # [14, 3, 1960, 16, 16]
638644
elif selected.size(2) < num_patches:
639-
selected = torch.cat([selected, selected[:, :, -1:].expand(-1, -1, num_patches - selected.size(2), -1, -1)], dim=2)
640-
combined_head_input = selected.view(n, C, T_new, Hp, Wp, patch_size, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T_new, H, W)
645+
selected = torch.cat([selected, selected[:, :, -1:].expand(-1, -1, num_patches - selected.size(2), -1, -1)], dim=2) # [14, 3, 1960, 16, 16]
646+
# [14, 3, 1960, 16, 16] -> [14, 3, 10, 14, 14, 16, 16] -> [14, 3, 10, 14, 16, 14, 16] -> [14, 3, 10, 224, 224]
647+
combined_head_input = selected.view(n, C, T_new, Hp, Wp, patch_size, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T_new, H, W) # [14, 3, 10, 224, 224]
641648

642649
with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
643-
combined_head_output = backbone_ddp_compiled(combined_head_input, vis_idx)
644-
combined_head_output = (combined_head_output.pooler_output if hasattr(combined_head_output, "pooler_output") else combined_head_output["head_output"]).float()
650+
combined_head_output = backbone_ddp_compiled(combined_head_input, vis_idx) # input: [14, 3, 10, 224, 224], vis_idx: [14, 2048]
651+
combined_head_output = (combined_head_output.pooler_output if hasattr(combined_head_output, "pooler_output") else combined_head_output["head_output"]).float() # [14, D]
645652

646653

647654
if mask_collage.any():
648-
coll_idx = torch.nonzero(mask_collage, as_tuple=False).squeeze(1)
649-
nC = coll_idx.numel()
650-
FRAMES = 64 # assume fixed 64 frames for head_subset
655+
coll_idx = torch.nonzero(mask_collage, as_tuple=False).squeeze(1) # [2]
656+
nC = coll_idx.numel() # 2
657+
FRAMES = 64
651658

652-
head_subset = head_input[coll_idx] # [nC, C, 64, H, W] (must hold)
659+
head_subset = head_input[coll_idx] # [2, 3, 64, 224, 224]
653660

654-
# 检查形状
655661
if head_subset.dim() != 5 or head_subset.size(2) != FRAMES:
656662
raise RuntimeError(
657663
f"collage branch expects head_subset shape [nC, C, {FRAMES}, H, W], got {tuple(head_subset.shape)}"
658664
)
659665

660-
nC = head_subset.size(0)
661-
Cf = head_subset.size(1)
662-
Hf = head_subset.size(3)
663-
Wf = head_subset.size(4)
664-
avg = FRAMES // args.actual_num_frames # 8
665-
base = torch.arange(args.actual_num_frames).cuda() * avg
666-
offs = torch.randint(avg, (nC, args.actual_num_frames)).cuda()
667-
frames_idx = (base.unsqueeze(0) + offs).long().clamp(max=FRAMES - 1) # [nC, actual_num_frames], 范围在 [0, 63]
668-
idx_expand = frames_idx.view(nC, 1, args.actual_num_frames, 1, 1).expand(-1, Cf, -1, Hf, Wf)
669-
sel_frames = torch.gather(head_subset, 2, idx_expand) # [nC, Cf, actual_num_frames, Hf, Wf]
670-
sel_frames = sel_frames.permute(0, 2, 1, 3, 4) # [nC, actual_num_frames, Cf, Hf, Wf]
671-
grid_rows = [sel_frames[:, i, :, :, :] for i in range(args.actual_num_frames)]
672-
grid = torch.cat(grid_rows, dim=-2) # [nC, Cf, Hf*args.actual_num_frames, Wf]
666+
nC = head_subset.size(0) # 2
667+
Cf = head_subset.size(1) # 3
668+
Hf = head_subset.size(3) # 224
669+
Wf = head_subset.size(4) # 224
670+
avg = FRAMES // args.actual_num_frames # 64 // 8 = 8
671+
base = torch.arange(args.actual_num_frames).cuda() * avg # [0, 8, 16, 24, 32, 40, 48, 56]
672+
offs = torch.randint(avg, (nC, args.actual_num_frames)).cuda() # [2, 8], values in [0, 7]
673+
frames_idx = (base.unsqueeze(0) + offs).long().clamp(max=FRAMES - 1) # [2, 8], values in [0, 63]
674+
idx_expand = frames_idx.view(nC, 1, args.actual_num_frames, 1, 1).expand(-1, Cf, -1, Hf, Wf) # [2, 3, 8, 224, 224]
675+
sel_frames = torch.gather(head_subset, 2, idx_expand) # [2, 3, 8, 224, 224]
676+
sel_frames = sel_frames.permute(0, 2, 1, 3, 4) # [2, 8, 3, 224, 224]
677+
grid_rows = [sel_frames[:, i, :, :, :] for i in range(args.actual_num_frames)] # 8 x [2, 3, 224, 224]
678+
grid = torch.cat(grid_rows, dim=-2) # [2, 3, 1792, 224] (1792 = 224 * 8)
673679
with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
674-
collage_head_output = backbone_ddp_compiled(grid)
680+
collage_head_output = backbone_ddp_compiled(grid) # input: [2, 3, 1792, 224]
675681
if hasattr(collage_head_output, "pooler_output"):
676682
collage_head_output = collage_head_output.pooler_output
677683
else:
678684
collage_head_output = collage_head_output["head_output"]
679-
collage_head_output = collage_head_output.float()
685+
collage_head_output = collage_head_output.float() # [2, D]
680686

681-
D = combined_head_output.size(1)
687+
D = combined_head_output.size(1) # embedding dimension
682688

683-
head_embedding_full = torch.zeros(bs, D, dtype=torch.float32).cuda()
689+
head_embedding_full = torch.zeros(bs, D, dtype=torch.float32).cuda() # [16, D]
684690
if combined_mask.any():
685-
head_embedding_full[combined_idx] = combined_head_output
691+
head_embedding_full[combined_idx] = combined_head_output # head_embedding_full[0:14] = [14, D]
686692
if mask_collage.any():
687-
head_embedding_full[coll_idx] = collage_head_output
693+
head_embedding_full[coll_idx] = collage_head_output # head_embedding_full[14:16] = [2, D]
688694

689-
list_embedding.append(head_embedding_full)
695+
list_embedding.append(head_embedding_full) # [16, D]
690696

691697
elif dataset_config.dali_type in ["origin", "ocr"]:
692698
head_input = list_data_batch[head_id]["pixel_values"]

0 commit comments

Comments (0)