Commit ba255bd ("updated")
1 parent 801cbd9

2 files changed: 24 additions & 23 deletions

README.md (9 additions, 0 deletions)

@@ -230,6 +230,15 @@ bash shells/ov_encoder_base_stage1_si.sh
 
 ### Single Node Stage-2 Video Continued Pretraining
 
+Download the Stage-1 checkpoint from HuggingFace:
+
+```bash
+git clone https://huggingface.co/lmms-lab-encoder/onevision-encoder-large-si
+```
+
+Download the pretraining data and prepare the data directory as per the instructions in `data/README.md`.
+
+
 More documentation will be added soon.
 
 ```bash
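If git-lfs is not set up, the same checkpoint can also be fetched through the huggingface_hub Python API. A minimal sketch, assuming `huggingface_hub` is installed; the local directory name is a hypothetical choice:

```python
# Sketch: fetch the Stage-1 checkpoint via the huggingface_hub API instead of git.
# Assumes `pip install huggingface_hub`; local_dir is a hypothetical path.
from huggingface_hub import snapshot_download

path = snapshot_download(
    repo_id="lmms-lab-encoder/onevision-encoder-large-si",
    local_dir="checkpoints/onevision-encoder-large-si",  # hypothetical location
)
print(f"Checkpoint files downloaded to {path}")
```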

training/train.py (15 additions, 23 deletions)

@@ -304,14 +304,6 @@ def _expand(name, v):
 
     if args.finetune_backbone:
         backbone.requires_grad_(True)
-    else:
-        backbone.requires_grad_(False)
-        backbone_module = unwrap_module(backbone)
-        if hasattr(backbone_module, "head"):
-            for p in backbone_module.head.parameters():
-                p.requires_grad = True
-        else:
-            raise RuntimeError()
 
     backbone_parameters = filter(lambda p: p.requires_grad, backbone.parameters())
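For context, the deleted `else` branch implemented the usual freeze-all-but-head pattern. A self-contained sketch of that pattern; `ToyBackbone` is a made-up stand-in, since the real backbone and `unwrap_module` are defined elsewhere:

```python
# Sketch of the deleted branch: freeze the whole backbone, then re-enable
# only the head. ToyBackbone is a stand-in; the real code first unwraps DDP
# with unwrap_module() and raises if the module has no `head` attribute.
import torch.nn as nn

class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)
        self.head = nn.Linear(16, 4)

backbone = ToyBackbone()
backbone.requires_grad_(False)            # freeze everything
for p in backbone.head.parameters():      # ...except the head
    p.requires_grad = True

print([n for n, p in backbone.named_parameters() if p.requires_grad])
# ['head.weight', 'head.bias']
```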

@@ -347,7 +339,7 @@ def _expand(name, v):
         )
 
         partial_fc.train().cuda()
-        # list_module_pfc.append(torch.compile(partial_fc))
+
         list_module_pfc.append(partial_fc)
         dict_pfc_modules[head_name] = partial_fc
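The dropped line was commented-out code that would have wrapped each `partial_fc` in `torch.compile`. For reference, a generic sketch of that wrapping, not the project's actual setup:

```python
# Generic sketch of wrapping a module in torch.compile (PyTorch >= 2.0);
# this mirrors the commented-out idea, not the project's actual setup.
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
compiled = torch.compile(layer)       # optimized callable wrapping `layer`
out = compiled(torch.randn(2, 8))     # first call triggers compilation
print(out.shape)  # torch.Size([2, 8])
```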

@@ -481,6 +473,7 @@ def wrap_ddp(model):
             shard_id=dataset_config.shard_id,
             num_shards=dataset_config.num_shards
         )
+
     elif dataset_config.dali_type == "ocr":
         if args.debug:
             from dataloader.data_v2_ocr import SyntheticDataIter
@@ -495,8 +488,8 @@ def wrap_ddp(model):
             image_size=args.image_size,
             workers=args.workers,
             shard_id=dataset_config.shard_id,
-            num_shards=dataset_config.num_shards
-        )
+            num_shards=dataset_config.num_shards)
+
     else:
         raise ValueError(
             f"dataset_config.dali_type {dataset_config.dali_type} not supported!"
@@ -608,10 +601,10 @@ def wrap_ddp(model):
         out[mask_residual] = sel_a
 
         # mask_frame_sampling: sample 8 frames from 64, get all patches per frame
-        SEQ, FRAMES = 8, 64
+        FRAMES = 64
        if mask_frame_sampling.any():
             nB = mask_frame_sampling.sum().item()
-            frames = torch.arange(SEQ, device=dev) * (FRAMES // SEQ) + torch.randint(FRAMES // SEQ, (nB, SEQ), device=dev)
+            frames = torch.arange(args.actual_num_frames, device=dev) * (FRAMES // args.actual_num_frames) + torch.randint(FRAMES // args.actual_num_frames, (nB, args.actual_num_frames), device=dev)
             sel_b = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame, device=dev)).reshape(nB, -1)
             if sel_b.size(1) > args.target_num:
                 sel_b = sel_b[:, :args.target_num]
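The rewritten sampling line replaces the hard-coded `SEQ = 8` with `args.actual_num_frames`: the 64 frames are split into equal windows and one random frame index is drawn per window (stratified sampling). A standalone sketch of the index math with toy values:

```python
# Sketch of the stratified sampling above: split FRAMES into equal windows
# and draw one random frame per window. Toy values; `num_samples` plays the
# role of args.actual_num_frames.
import torch

FRAMES, num_samples, batch = 64, 8, 2
stride = FRAMES // num_samples                       # 8 frames per window
base = torch.arange(num_samples) * stride            # window starts: 0, 8, ..., 56
offs = torch.randint(stride, (batch, num_samples))   # random offset inside each window
frames = base + offs                                 # [batch, num_samples], values in [0, 63]
print(frames)
```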
@@ -654,7 +647,6 @@ def wrap_ddp(model):
         if mask_collage.any():
             coll_idx = torch.nonzero(mask_collage, as_tuple=False).squeeze(1)
             nC = coll_idx.numel()
-            SEQ = 8
             FRAMES = 64  # assume fixed 64 frames for head_subset
 
             head_subset = head_input[coll_idx]  # [nC, C, 64, H, W] (must hold)
@@ -669,15 +661,15 @@ def wrap_ddp(model):
             Cf = head_subset.size(1)
             Hf = head_subset.size(3)
             Wf = head_subset.size(4)
-            avg = FRAMES // SEQ  # 8
-            base = torch.arange(SEQ, device=dev) * avg
-            offs = torch.randint(avg, (nC, SEQ), device=dev)
-            frames_idx = (base.unsqueeze(0) + offs).long().clamp(max=FRAMES - 1)  # [nC, SEQ], values in [0, 63]
-            idx_expand = frames_idx.view(nC, 1, SEQ, 1, 1).expand(-1, Cf, -1, Hf, Wf).to(head_subset.device)
-            sel_frames = torch.gather(head_subset, 2, idx_expand)  # [nC, Cf, SEQ, Hf, Wf]
-            sel_frames = sel_frames.permute(0, 2, 1, 3, 4)  # [nC, SEQ, Cf, Hf, Wf]
-            grid_rows = [sel_frames[:, i, :, :, :] for i in range(SEQ)]
-            grid = torch.cat(grid_rows, dim=-2)  # [nC, Cf, Hf*SEQ, Wf]
+            avg = FRAMES // args.actual_num_frames  # 8 when actual_num_frames == 8
+            base = torch.arange(args.actual_num_frames, device=dev) * avg
+            offs = torch.randint(avg, (nC, args.actual_num_frames), device=dev)
+            frames_idx = (base.unsqueeze(0) + offs).long().clamp(max=FRAMES - 1)  # [nC, actual_num_frames], values in [0, 63]
+            idx_expand = frames_idx.view(nC, 1, args.actual_num_frames, 1, 1).expand(-1, Cf, -1, Hf, Wf).to(head_subset.device)
+            sel_frames = torch.gather(head_subset, 2, idx_expand)  # [nC, Cf, actual_num_frames, Hf, Wf]
+            sel_frames = sel_frames.permute(0, 2, 1, 3, 4)  # [nC, actual_num_frames, Cf, Hf, Wf]
+            grid_rows = [sel_frames[:, i, :, :, :] for i in range(args.actual_num_frames)]
+            grid = torch.cat(grid_rows, dim=-2)  # [nC, Cf, Hf*args.actual_num_frames, Wf]
             with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
                 collage_head_output = backbone_ddp_compiled(grid)
             if hasattr(collage_head_output, "pooler_output"):
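The collage branch gathers the sampled frames and stacks them along the height axis into one tall image per clip, so a single backbone forward processes all sampled frames. A toy sketch of that gather-and-stack with made-up shapes:

```python
# Sketch of the collage construction: gather sampled frames from a
# [B, C, T, H, W] clip and concatenate them along height. Toy shapes.
import torch

B, C, T, H, W = 2, 3, 64, 4, 4
num_samples = 8                                       # stands in for args.actual_num_frames
clip = torch.randn(B, C, T, H, W)
frames_idx = torch.randint(T, (B, num_samples))       # [B, num_samples]
idx = frames_idx.view(B, 1, num_samples, 1, 1).expand(-1, C, -1, H, W)
sel = torch.gather(clip, 2, idx)                      # [B, C, num_samples, H, W]
sel = sel.permute(0, 2, 1, 3, 4)                      # [B, num_samples, C, H, W]
grid = torch.cat([sel[:, i] for i in range(num_samples)], dim=-2)
print(grid.shape)                                     # torch.Size([2, 3, 32, 4])
```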
