Skip to content

Commit c0ce246

Browse files
Copilot and anxiangsir committed
Implement new ViT interface: patchify video, select patches, reshape to video
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent bf6d733 commit c0ce246

1 file changed

Lines changed: 70 additions & 71 deletions

File tree

training/train.py

Lines changed: 70 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -578,14 +578,23 @@ def wrap_ddp(model):
578578
list_embedding.append(head_embedding)
579579

580580
elif dataset_config.dali_type in ["decord_residual"]:
581-
head_input = list_data_batch[head_id]["videos"]
581+
head_input = list_data_batch[head_id]["videos"] # [bs, C, 64, H, W]
582582
list_batch_sizes.append(head_input.size(0))
583-
visible_indices = list_data_batch[head_id]["video_visible_indices"] # [bs, ?],需要至少 args.total_indices 合法列
583+
visible_indices = list_data_batch[head_id]["video_visible_indices"] # [bs, ?]
584584
visible_indices = visible_indices.long()
585585

586586
bs = visible_indices.shape[0]
587587
dev = visible_indices.device
588588

589+
# Get patch_size from backbone config
590+
backbone_module = unwrap_module(backbone)
591+
if hasattr(backbone_module, 'config'):
592+
patch_size = backbone_module.config.patch_size
593+
elif hasattr(backbone_module, 'embeddings') and hasattr(backbone_module.embeddings, 'patch_size'):
594+
patch_size = backbone_module.embeddings.patch_size
595+
else:
596+
patch_size = 16 # default fallback
597+
589598
out = visible_indices[:, :args.target_num].clone()
590599
n1 = int(bs * 0.5)
591600
n2 = int(bs * 0.875)
@@ -595,61 +604,17 @@ def wrap_ddp(model):
595604
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2) # idx in [n1, n2)
596605
mask_collage = idx_range >= n2 # idx in [n2, bs)
597606

598-
# Storage for selected frames per sample
599-
SEQ = 8
600-
FRAMES = 64
601-
selected_frames_all = torch.zeros(bs, SEQ, device=dev, dtype=torch.long)
602-
607+
# For mask_residual: directly select first args.target_num patches
603608
if mask_residual.any():
604-
vis_a = visible_indices[mask_residual, :args.total_indices]
605-
must = vis_a[:, :args.must_num]
606-
candidates = vis_a[:, args.must_num:args.total_indices]
607-
k = max(0, args.target_num - args.must_num)
608-
k = min(k, candidates.size(1))
609-
if k > 0:
610-
scores = torch.rand(vis_a.size(0), candidates.size(1), device=dev)
611-
idx = scores.topk(k, dim=1).indices
612-
sampled = torch.gather(candidates, 1, idx)
613-
sel_a = torch.cat([must, sampled], dim=1)
614-
else:
615-
sel_a = must
609+
sel_a = visible_indices[mask_residual, :args.target_num]
616610
if sel_a.size(1) < args.target_num:
617611
pad = sel_a[:, -1:].repeat(1, args.target_num - sel_a.size(1))
618612
sel_a = torch.cat([sel_a, pad], dim=1)
619613
out[mask_residual] = sel_a
620614

621-
# Extract frame indices from sel_a for residual branch
622-
# sel_a contains patch indices, convert to frame indices
623-
nA = sel_a.size(0)
624-
# Get frame index for each patch
625-
frame_indices_per_patch = sel_a // args.num_tokens_per_frame # [nA, target_num]
626-
627-
# For each sample, get unique sorted frame indices
628-
# We need exactly SEQ frames
629-
# Vectorized approach: sort and take unique by checking consecutive differences
630-
sorted_frames, _ = torch.sort(frame_indices_per_patch, dim=1) # [nA, target_num]
631-
632-
# Find where consecutive values differ (mark the first of each unique value)
633-
diff_mask = torch.cat([
634-
torch.ones(nA, 1, device=dev, dtype=torch.bool),
635-
sorted_frames[:, 1:] != sorted_frames[:, :-1]
636-
], dim=1) # [nA, target_num]
637-
638-
# For each sample, collect the first SEQ unique frames
639-
residual_frames = torch.zeros(nA, SEQ, device=dev, dtype=torch.long)
640-
for i in range(nA):
641-
unique_positions = diff_mask[i].nonzero(as_tuple=True)[0]
642-
unique_values = sorted_frames[i, unique_positions]
643-
num_unique = unique_values.numel()
644-
if num_unique >= SEQ:
645-
residual_frames[i] = unique_values[:SEQ]
646-
elif num_unique > 0:
647-
residual_frames[i, :num_unique] = unique_values
648-
residual_frames[i, num_unique:] = unique_values[-1]
649-
# If num_unique == 0, residual_frames[i] stays at zeros (frame 0)
650-
selected_frames_all[mask_residual] = residual_frames
651-
652-
615+
# For mask_frame_sampling: compute patch indices based on frame sampling
616+
SEQ = 8
617+
FRAMES = 64
653618
if mask_frame_sampling.any():
654619
nB = visible_indices[mask_frame_sampling].size(0)
655620
avg = FRAMES // SEQ
@@ -669,35 +634,69 @@ def wrap_ddp(model):
669634
pad = sel_b[:, -1:].repeat(1, args.target_num - sel_b.size(1))
670635
out[mask_frame_sampling] = torch.cat([sel_b, pad], dim=1)
671636

672-
# Store frame indices for frame_sampling branch
673-
selected_frames_all[mask_frame_sampling] = frames
674-
675-
676637
combined_mask = mask_residual | mask_frame_sampling
677638
if combined_mask.any():
678639
combined_idx = torch.nonzero(combined_mask, as_tuple=False).squeeze(1)
679-
combined_head_input_full = head_input[combined_idx] # [n, C, 64, H, W]
680-
combined_out = out[combined_idx]
681-
combined_frames = selected_frames_all[combined_idx] # [n, SEQ]
682-
683-
# Select frames from the video based on selected_frames
684-
# combined_head_input_full: [n, C, 64, H, W]
685-
# combined_frames: [n, SEQ]
686-
nComb = combined_head_input_full.size(0)
687-
Cf = combined_head_input_full.size(1)
688-
Hf = combined_head_input_full.size(3)
689-
Wf = combined_head_input_full.size(4)
690-
691-
# Expand frame indices for gather: [n, 1, SEQ, 1, 1] -> [n, C, SEQ, H, W]
692-
frame_idx_expand = combined_frames.view(nComb, 1, SEQ, 1, 1).expand(-1, Cf, -1, Hf, Wf)
693-
combined_head_input = torch.gather(combined_head_input_full, 2, frame_idx_expand) # [n, C, SEQ, H, W]
640+
combined_video = head_input[combined_idx] # [n, C, 64, H, W]
641+
combined_out = out[combined_idx] # [n, target_num]
642+
643+
n_comb, C_vid, T_vid, H_vid, W_vid = combined_video.shape
644+
Hp = H_vid // patch_size # patches per row
645+
Wp = W_vid // patch_size # patches per col
646+
patches_per_frame = Hp * Wp
647+
total_patches = T_vid * patches_per_frame
648+
649+
# Convert video to patches: [n, C, T*Hp*Wp, patch_size, patch_size]
650+
# First reshape to [n, C, T, Hp, patch_size, Wp, patch_size]
651+
video_reshaped = combined_video.view(n_comb, C_vid, T_vid, Hp, patch_size, Wp, patch_size)
652+
# Permute to [n, C, T, Hp, Wp, patch_size, patch_size]
653+
video_reshaped = video_reshaped.permute(0, 1, 2, 3, 5, 4, 6)
654+
# Reshape to [n, C, T*Hp*Wp, patch_size, patch_size]
655+
video_patches = video_reshaped.reshape(n_comb, C_vid, total_patches, patch_size, patch_size)
656+
657+
# Select patches using combined_out (visible_indices): [n, target_num]
658+
# Expand combined_out for gathering: [n, target_num, 1, 1] -> [n, C, target_num, patch_size, patch_size]
659+
idx_expand = combined_out.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # [n, 1, target_num, 1, 1]
660+
idx_expand = idx_expand.expand(-1, C_vid, -1, patch_size, patch_size) # [n, C, target_num, patch_size, patch_size]
661+
selected_patches = torch.gather(video_patches, 2, idx_expand) # [n, C, target_num, patch_size, patch_size]
662+
663+
# Reshape selected patches back to video format [n, C, T', H', W']
664+
# We have target_num patches, need to figure out T', Hp', Wp'
665+
# For simplicity: T' = target_num // patches_per_frame, and use original Hp, Wp
666+
T_new = args.target_num // patches_per_frame
667+
expected_patches = T_new * patches_per_frame
668+
669+
# Handle case when target_num is not divisible by patches_per_frame
670+
if expected_patches != args.target_num:
671+
T_new = max(1, T_new)
672+
expected_patches = T_new * patches_per_frame
673+
# Truncate or pad selected_patches to match expected_patches
674+
if args.target_num > expected_patches:
675+
selected_patches = selected_patches[:, :, :expected_patches, :, :]
676+
else:
677+
# Pad with the last patch repeated
678+
pad_size = expected_patches - args.target_num
679+
pad_patches = selected_patches[:, :, -1:, :, :].repeat(1, 1, pad_size, 1, 1)
680+
selected_patches = torch.cat([selected_patches, pad_patches], dim=2)
681+
682+
# Reshape: [n, C, expected_patches, patch_size, patch_size] -> [n, C, T_new, Hp, Wp, patch_size, patch_size]
683+
# Then -> [n, C, T_new, Hp*patch_size, Wp*patch_size] = [n, C, T_new, H, W]
684+
H_new = Hp * patch_size
685+
W_new = Wp * patch_size
686+
687+
# First reshape to [n, C, T_new, Hp, Wp, patch_size, patch_size]
688+
selected_reshaped = selected_patches.view(n_comb, C_vid, T_new, Hp, Wp, patch_size, patch_size)
689+
# Permute to [n, C, T_new, Hp, patch_size, Wp, patch_size]
690+
selected_reshaped = selected_reshaped.permute(0, 1, 2, 3, 5, 4, 6)
691+
# Reshape to [n, C, T_new, H_new, W_new]
692+
combined_head_input = selected_reshaped.reshape(n_comb, C_vid, T_new, H_new, W_new)
694693

695694
with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
696695
combined_head_output = backbone_ddp_compiled(combined_head_input, combined_out)
697696
if hasattr(combined_head_output, "pooler_output"):
698697
combined_head_output = combined_head_output.pooler_output
699698
else:
700-
combined_head_output = combined_head_output["head_output"]
699+
combined_head_output = combined_head_output["head_output"]
701700

702701
combined_head_output = combined_head_output.float()
703702

0 commit comments

Comments (0)