
Commit 5c794b1

Commit message: updated

1 parent 5996db6 commit 5c794b1

14 files changed
Lines changed: 85 additions & 260 deletions

.gitignore

Lines changed: 7 additions & 0 deletions
```diff
@@ -498,3 +498,10 @@ cache/
 list_*
 tmp
 *.jsonmodel_factory/vit_aim_v2_packing_hf_old.py
+
+
+ckpts/
+ckpts/**
+ckpts
+.gitginore
+
```
README.md

Lines changed: 17 additions & 25 deletions
```diff
@@ -64,11 +64,15 @@ Coupled with global contrastive learning over a 2M-scale concept memory bank, On
 
 ### Video Processing Pipeline
 
-The visualization below illustrates four different video processing pipelines.
-(1) **Original Video**: a continuous 64-frame sequence that preserves the complete temporal context.
-(2) **Uniform Frame Sampling**: a conventional strategy that selects 4–8 evenly spaced frames; while simple and efficient, it is inherently lossy and fails to capture fine-grained inter-frame motion.
-(3) **Temporal Saliency Detection**: a global analysis of all 64 frames to identify regions rich in temporal information, including motion patterns, appearance variations, and semantic events.
-(4) **Codec-Style Patch Extraction**: selective extraction of the temporally salient patches in a zigzag order, achieving 75–98% compression while retaining critical temporal dynamics.
+The visualization below illustrates four different video processing pipelines.
+
+**1. Original Video**: a continuous 64-frame sequence that preserves the complete temporal context.
+
+**2. Uniform Frame Sampling**: a conventional strategy that selects 4–8 evenly spaced frames; while simple and efficient, it is inherently lossy and fails to capture fine-grained inter-frame motion.
+
+**3. Temporal Saliency Detection**: a global analysis of all 64 frames to identify regions rich in temporal information, including motion patterns, appearance variations, and semantic events.
+
+**4. Codec-Style Patch Extraction**: selective extraction of the temporally salient patches in a zigzag order, achieving 75–98% compression while retaining critical temporal dynamics.
 
 <div align="center">
 <table style="width: 100%; max-width: 1200px; table-layout: fixed;">
```
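Editor's note: steps (3) and (4) are concrete enough to sketch. The snippet below is not the repository's implementation, only a minimal Python illustration that scores each spatial patch position by temporal variance (a crude stand-in for the saliency detector) and keeps the top-K positions while visiting them in JPEG-style zigzag order. All names (`zigzag_indices`, `select_salient_patches`) are hypothetical.

```python
import torch

def zigzag_indices(h: int, w: int) -> list:
    # JPEG-style zigzag traversal of an h x w patch grid:
    # walk anti-diagonals, alternating direction on each one.
    return sorted(
        ((r, c) for r in range(h) for c in range(w)),
        key=lambda rc: (rc[0] + rc[1],
                        -rc[0] if (rc[0] + rc[1]) % 2 == 0 else rc[0]),
    )

def select_salient_patches(video: torch.Tensor, patch: int = 16, k_keep: int = 48):
    """video: [T, C, H, W] -> list of (row, col) patch positions to keep.

    Temporal variance stands in for the real temporal saliency detector,
    which also considers motion patterns and semantic events.
    """
    T, C, H, W = video.shape
    gh, gw = H // patch, W // patch
    # Cut every frame into a gh x gw grid: [T, C, gh, gw, patch, patch]
    patches = video.unfold(2, patch, patch).unfold(3, patch, patch)
    # Variance over time, averaged over channels and pixels -> [gh, gw]
    saliency = patches.var(dim=0).mean(dim=(0, 3, 4))
    order = zigzag_indices(gh, gw)
    scores = torch.stack([saliency[r, c] for r, c in order])
    keep = torch.topk(scores, min(k_keep, scores.numel())).indices
    return [order[i] for i in sorted(keep.tolist())]  # keep zigzag order
```

With a 224-pixel frame and 16-pixel patches (a 14 × 14 grid), keeping 48 of 196 positions is roughly 75% compression, the low end of the quoted range.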
````diff
@@ -272,25 +276,14 @@ Training configurations and hyperparameters will be documented soon. For now, pl
 To evaluate the encoder with uniform frame sampling, first navigate to the evaluation directory:
 
 ```bash
+pip install -e .
 cd eval_encoder
 ```
 
 Then run the following command:
 
 ```bash
-torchrun --nproc_per_node=8 --master_port=29507 attentive_probe.py \
-    --eval_freq 1 \
-    --default_lr_list 0.0001 \
-    --batch_size 32 \
-    --default_weight_decay 0 \
-    --dali_py_num_workers 8 \
-    --model_family llava_vit_sampling \
-    --dataset diving48 \
-    --num_frames 8 \
-    --model_weight lmms-lab-encoder/onevision-encoder-large \
-    --model_name hf_llava_vit_large_ln \
-    --embedding_size 1024 \
-    --frames_token_num 256
+bash eval_encoder/shells_eval_ap/eval_ov_encoder_large_16frames.sh
 ```
 
 **Sampling-Specific Parameters:**
````
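The shell script presumably wraps the torchrun invocation that used to be inlined above. For context, the uniform frame sampling being probed here reduces to picking evenly spaced frame indices; a one-function sketch (hypothetical name, not the repo's code):

```python
import torch

def uniform_sample_indices(total_frames: int, num_frames: int) -> torch.Tensor:
    # e.g. 64 frames sampled down to 8:
    # tensor([ 0,  9, 18, 27, 36, 45, 54, 63])
    return torch.linspace(0, total_frames - 1, steps=num_frames).round().long()
```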
````diff
@@ -313,22 +306,21 @@ torchrun --nproc_per_node=8 --master_port=29512 attentive_probe_codec.py \
     --batch_size 4 \
     --default_weight_decay 0 \
     --dali_py_num_workers 8 \
-    --model_family llava_vit_codec \
+    --model_family ov_encoder_codec \
     --dataset diving48 \
-    --num_frames 64 \
-    --model_weight lmms-lab/onevision-encoder-large \
-    --model_name hf_llava_vit_large_ln \
+    --model_weight lmms-lab-encoder/onevision-encoder-large \
+    --model_name ov_encoder_large \
     --embedding_size 1024 \
     --default_epoch 30 \
-    --data_root /path/to/your/data_attentive_probe/ \
-    --cache_dir /path/to/your/cache_residuals/ \
     --K_keep 2048 \
+    --num_frames 64 \
     --mv_compensate median
+
 ```
 
 **Codec-Specific Parameters:**
+- `K_keep`: Number of patches to keep.
 - `cache_dir`: Directory for cached codec patches. This is where the codec-selected patches will be stored/loaded.
-- `K_keep`: Number of patches to keep. For example, 256 patches per frame × 8 frames = 2048 total patches. Adjust based on your frame count and desired compression ratio.
 - `mv_compensate`: Motion vector compensation method (e.g., `median`).
 
 #### Shared Parameters
````
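A sizing note for `K_keep`: with `--num_frames 64` and `--K_keep 2048`, an average of 2048 / 64 = 32 patches survives per frame. Assuming the 256-patches-per-frame grid that the removed wording used as its example, that keeps 12.5% of all patches, i.e. 87.5% compression, within the 75–98% range quoted earlier in the README.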

dockerfile

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,10 +1,17 @@
-FROM pytorch/pytorch:2.7.0-cuda11.8-cudnn9-runtime
+FROM nvcr.io/nvidia/pytorch:25.04-py3
 
 # Set up environment variables
 ENV DEBIAN_FRONTEND=noninteractive \
     PYTHONUNBUFFERED=1 \
     PIP_NO_CACHE_DIR=1
 
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libgl1 \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+
 # Install system dependencies and ffmpeg in one layer
 RUN set -eux; \
     apt-get update && apt-get install -y --no-install-recommends \
```
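`libgl1` and `libglib2.0-0` are the shared libraries most often missing when `import cv2` fails inside slim images (`libGL.so.1` and `libgthread-2.0.so.0` respectively), which is presumably why they are installed here. After this change, a plain rebuild such as `docker build -t onevision-encoder .` (tag name illustrative) pulls the NGC PyTorch 25.04 base instead of the old CUDA 11.8 runtime image.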

eval_encoder/attentive_probe.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -8,7 +8,6 @@
 import torch
 import torch.nn.functional as F
 import torchmetrics
-from dataloader.ap_dataloader_dali import get_dali_dataloader
 from timm.loss import LabelSmoothingCrossEntropy
 from timm.models import create_model
 from timm.models.layers import trunc_normal_
@@ -19,7 +18,9 @@
 
 # Ensure custom models and layers are registered
 import model_factory
-from model_factory.layers import Siglip2MultiheadAttentionPoolingHead, Siglip2TransformerAttentionPoolingHead
+from dataloader.ap_dataloader_dali import get_dali_dataloader
+from model_factory.layers import (Siglip2MultiheadAttentionPoolingHead,
+                                  Siglip2TransformerAttentionPoolingHead)
 
 warnings.filterwarnings("ignore")
 
@@ -33,7 +34,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--dataset", default="ssv2")
 
     # Model
-    parser.add_argument("--model_family", default="llava_vit_sampling")
+    parser.add_argument("--model_family", default="chunk_wise_sampling")
     parser.add_argument("--model_name", default="ov_encoder_large")
     parser.add_argument("--model_weight", default="NULL")
     parser.add_argument("--num_frames", type=int, default=8)
@@ -42,7 +43,6 @@
     parser.add_argument("--tubelet_size", type=int, default=1)
    parser.add_argument("--embedding_size", type=int, default=768)
     parser.add_argument("--num_classes", type=int, default=0)
-    # ===> New: target frame number parameter <===
     parser.add_argument("--target_frames", type=int, default=64,
                         help="Target number of frames to interpolate to (default: 64)")
 
```
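A `--target_frames` option like this implies temporal resampling of clips to a fixed length. The call site is not part of this diff, so the following is only a sketch of the usual approach, interpolating along the time axis with `F.interpolate` (the name `resample_time` is hypothetical):

```python
import torch
import torch.nn.functional as F

def resample_time(videos: torch.Tensor, target_frames: int = 64) -> torch.Tensor:
    """[B, C, T, H, W] -> [B, C, target_frames, H, W]; H and W untouched."""
    B, C, T, H, W = videos.shape
    return F.interpolate(videos, size=(target_frames, H, W),
                         mode="trilinear", align_corners=False)
```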
```diff
@@ -155,7 +155,6 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
         "dinov2",
         "dinov3",
         "metaclip",
-        "llava_vit_si",
         "aimv2"
     ]
     if args.model_family in list_vit_single_image:
@@ -183,7 +182,7 @@
         else:
             raise ValueError("SigLIP2 only supports image input with 4 dimensions [B, C, H, W].")
 
-    elif args.model_family == "llava_vit_sampling":
+    elif args.model_family == "chunk_wise_sampling":
         with torch.cuda.amp.autocast(dtype=torch.bfloat16):
             with torch.no_grad():
                 bs, C, T, H, W = videos.shape
```
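Note that the renamed `chunk_wise_sampling` branch runs the encoder under `torch.cuda.amp.autocast(dtype=torch.bfloat16)` combined with `torch.no_grad()`, i.e. bf16 inference with no autograd graph, which keeps the memory cost of long clips down. On recent PyTorch, `torch.autocast("cuda", dtype=torch.bfloat16)` is the preferred spelling of the same context manager.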
```diff
@@ -410,7 +409,7 @@ def get_model(args: argparse.Namespace) -> nn.Module:
         return model
 
     model = create_model(args.model_name, pretrained=False)
-    if args.model_family in ["llava_vit_sampling"]:
+    if args.model_family in ["chunk_wise_sampling"]:
         state_dict = torch.load(args.model_weight, map_location="cpu")
         state_dict = {k.replace("_orig_mod.", "").replace("module.", ""): v for k, v in state_dict.items()}
         model.load_state_dict(state_dict, strict=True)
```
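The key-rewriting line handles two common checkpoint artifacts: `torch.compile` saves the wrapped module's parameters under an `_orig_mod.` prefix, and `DistributedDataParallel` saves them under `module.`. Stripping both lets `load_state_dict(..., strict=True)` accept checkpoints saved from either wrapper.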
