Skip to content

Commit 801cbd9

Browse files
Copilot and anxiangsir committed
Simplify decord_residual code for better readability
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent bef1e90 commit 801cbd9

1 file changed

Lines changed: 44 additions & 94 deletions

File tree

training/train.py

Lines changed: 44 additions & 94 deletions
Original file line number | Diff line number | Diff line change
@@ -589,116 +589,66 @@ def wrap_ddp(model):
589589
elif dataset_config.dali_type in ["decord_residual"]:
590590
head_input = list_data_batch[head_id]["videos"] # [bs, C, 64, H, W]
591591
list_batch_sizes.append(head_input.size(0))
592-
visible_indices = list_data_batch[head_id]["video_visible_indices"] # [bs, ?]
593-
visible_indices = visible_indices.long()
594-
595-
bs = visible_indices.shape[0]
596-
dev = visible_indices.device
592+
visible_indices = list_data_batch[head_id]["video_visible_indices"].long()
597593

594+
bs, dev = visible_indices.shape[0], visible_indices.device
598595
out = visible_indices[:, :args.target_num].clone()
599-
n1 = int(bs * 0.5)
600-
n2 = int(bs * 0.875)
596+
n1, n2 = int(bs * 0.5), int(bs * 0.875)
601597

602598
idx_range = torch.arange(bs, device=dev)
603-
mask_residual = idx_range < n1 # idx in [0, n1)
604-
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2) # idx in [n1, n2)
605-
mask_collage = idx_range >= n2 # idx in [n2, bs)
599+
mask_residual = idx_range < n1
600+
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2)
601+
mask_collage = idx_range >= n2
606602

607-
# For mask_residual: directly select first args.target_num patches
603+
# mask_residual: select first args.target_num patches
608604
if mask_residual.any():
609605
sel_a = visible_indices[mask_residual, :args.target_num]
610606
if sel_a.size(1) < args.target_num:
611-
pad = sel_a[:, -1:].repeat(1, args.target_num - sel_a.size(1))
612-
sel_a = torch.cat([sel_a, pad], dim=1)
607+
sel_a = torch.cat([sel_a, sel_a[:, -1:].expand(-1, args.target_num - sel_a.size(1))], dim=1)
613608
out[mask_residual] = sel_a
614609

615-
# For mask_frame_sampling: compute patch indices based on frame sampling
616-
SEQ = 8
617-
FRAMES = 64
610+
# mask_frame_sampling: sample 8 frames from 64, get all patches per frame
611+
SEQ, FRAMES = 8, 64
618612
if mask_frame_sampling.any():
619-
nB = visible_indices[mask_frame_sampling].size(0)
620-
avg = FRAMES // SEQ
621-
base = torch.arange(SEQ, device=dev) * avg
622-
offs = torch.randint(avg, (nB, SEQ), device=dev)
623-
frames = base + offs # [nB, 8]
624-
625-
per = torch.arange(args.num_tokens_per_frame, device=dev)
626-
pos = (frames.unsqueeze(-1) * args.num_tokens_per_frame + per).reshape(nB, -1) # [nB, 8*num_tokens_per_frame]
627-
sel_b = pos.to(visible_indices.dtype)
628-
629-
if sel_b.size(1) == args.target_num:
630-
out[mask_frame_sampling] = sel_b
631-
elif sel_b.size(1) > args.target_num:
632-
out[mask_frame_sampling] = sel_b[:, :args.target_num]
633-
else:
634-
pad = sel_b[:, -1:].repeat(1, args.target_num - sel_b.size(1))
635-
out[mask_frame_sampling] = torch.cat([sel_b, pad], dim=1)
613+
nB = mask_frame_sampling.sum().item()
614+
frames = torch.arange(SEQ, device=dev) * (FRAMES // SEQ) + torch.randint(FRAMES // SEQ, (nB, SEQ), device=dev)
615+
sel_b = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame, device=dev)).reshape(nB, -1)
616+
if sel_b.size(1) > args.target_num:
617+
sel_b = sel_b[:, :args.target_num]
618+
elif sel_b.size(1) < args.target_num:
619+
sel_b = torch.cat([sel_b, sel_b[:, -1:].expand(-1, args.target_num - sel_b.size(1))], dim=1)
620+
out[mask_frame_sampling] = sel_b
636621

637622
combined_mask = mask_residual | mask_frame_sampling
638623
if combined_mask.any():
639-
combined_idx = torch.nonzero(combined_mask, as_tuple=False).squeeze(1)
640-
combined_video = head_input[combined_idx] # [n, C, 64, H, W]
641-
combined_out = out[combined_idx] # [n, target_num]
642-
643-
n_comb, C_vid, T_vid, H_vid, W_vid = combined_video.shape
644-
Hp = H_vid // patch_size # patches per row
645-
Wp = W_vid // patch_size # patches per col
646-
patches_per_frame = Hp * Wp
647-
total_patches = T_vid * patches_per_frame
648-
649-
# Convert video to patches: [n, C, T*Hp*Wp, patch_size, patch_size]
650-
# First reshape to [n, C, T, Hp, patch_size, Wp, patch_size]
651-
video_reshaped = combined_video.view(n_comb, C_vid, T_vid, Hp, patch_size, Wp, patch_size)
652-
# Permute to [n, C, T, Hp, Wp, patch_size, patch_size]
653-
video_reshaped = video_reshaped.permute(0, 1, 2, 3, 5, 4, 6)
654-
# Reshape to [n, C, T*Hp*Wp, patch_size, patch_size]
655-
video_patches = video_reshaped.reshape(n_comb, C_vid, total_patches, patch_size, patch_size)
656-
657-
# Select patches using combined_out (visible_indices): [n, target_num]
658-
# Expand combined_out for gathering: [n, target_num, 1, 1] -> [n, C, target_num, patch_size, patch_size]
659-
idx_expand = combined_out.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # [n, 1, target_num, 1, 1]
660-
idx_expand = idx_expand.expand(-1, C_vid, -1, patch_size, patch_size) # [n, C, target_num, patch_size, patch_size]
661-
selected_patches = torch.gather(video_patches, 2, idx_expand) # [n, C, target_num, patch_size, patch_size]
662-
663-
# Reshape selected patches back to video format [n, C, T', H', W']
664-
# We have target_num patches, need to figure out T', Hp', Wp'
665-
# For simplicity: T' = target_num // patches_per_frame, and use original Hp, Wp
666-
T_new = args.target_num // patches_per_frame
667-
expected_patches = T_new * patches_per_frame
668-
669-
# Handle case when target_num is not divisible by patches_per_frame
670-
if expected_patches != args.target_num:
671-
T_new = max(1, T_new)
672-
expected_patches = T_new * patches_per_frame
673-
# Truncate or pad selected_patches to match expected_patches
674-
if args.target_num > expected_patches:
675-
selected_patches = selected_patches[:, :, :expected_patches, :, :]
676-
else:
677-
# Pad with the last patch repeated
678-
pad_size = expected_patches - args.target_num
679-
pad_patches = selected_patches[:, :, -1:, :, :].repeat(1, 1, pad_size, 1, 1)
680-
selected_patches = torch.cat([selected_patches, pad_patches], dim=2)
681-
682-
# Reshape: [n, C, expected_patches, patch_size, patch_size] -> [n, C, T_new, Hp, Wp, patch_size, patch_size]
683-
# Then -> [n, C, T_new, Hp*patch_size, Wp*patch_size] = [n, C, T_new, H, W]
684-
H_new = Hp * patch_size
685-
W_new = Wp * patch_size
686-
687-
# First reshape to [n, C, T_new, Hp, Wp, patch_size, patch_size]
688-
selected_reshaped = selected_patches.view(n_comb, C_vid, T_new, Hp, Wp, patch_size, patch_size)
689-
# Permute to [n, C, T_new, Hp, patch_size, Wp, patch_size]
690-
selected_reshaped = selected_reshaped.permute(0, 1, 2, 3, 5, 4, 6)
691-
# Reshape to [n, C, T_new, H_new, W_new]
692-
combined_head_input = selected_reshaped.reshape(n_comb, C_vid, T_new, H_new, W_new)
624+
combined_idx = combined_mask.nonzero(as_tuple=False).squeeze(1)
625+
video = head_input[combined_idx] # [n, C, T, H, W]
626+
vis_idx = out[combined_idx] # [n, target_num]
627+
628+
n, C, T, H, W = video.shape
629+
Hp, Wp = H // patch_size, W // patch_size
630+
631+
# Patchify: [n, C, T, H, W] -> [n, C, T*Hp*Wp, p, p]
632+
patches = video.view(n, C, T, Hp, patch_size, Wp, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T * Hp * Wp, patch_size, patch_size)
633+
634+
# Select patches by vis_idx
635+
idx = vis_idx[:, None, :, None, None].expand(-1, C, -1, patch_size, patch_size)
636+
selected = torch.gather(patches, 2, idx) # [n, C, target_num, p, p]
637+
638+
# Unpatchify: [n, C, target_num, p, p] -> [n, C, T', H, W]
639+
T_new = args.target_num // (Hp * Wp)
640+
if T_new == 0:
641+
T_new = 1
642+
num_patches = T_new * Hp * Wp
643+
if selected.size(2) > num_patches:
644+
selected = selected[:, :, :num_patches]
645+
elif selected.size(2) < num_patches:
646+
selected = torch.cat([selected, selected[:, :, -1:].expand(-1, -1, num_patches - selected.size(2), -1, -1)], dim=2)
647+
combined_head_input = selected.view(n, C, T_new, Hp, Wp, patch_size, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T_new, H, W)
693648

694649
with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
695-
combined_head_output = backbone_ddp_compiled(combined_head_input, combined_out)
696-
if hasattr(combined_head_output, "pooler_output"):
697-
combined_head_output = combined_head_output.pooler_output
698-
else:
699-
combined_head_output = combined_head_output["head_output"]
700-
701-
combined_head_output = combined_head_output.float()
650+
combined_head_output = backbone_ddp_compiled(combined_head_input, vis_idx)
651+
combined_head_output = (combined_head_output.pooler_output if hasattr(combined_head_output, "pooler_output") else combined_head_output["head_output"]).float()
702652

703653

704654
if mask_collage.any():

0 commit comments

Comments (0)