Skip to content

Commit e153701

Browse files
committed
2 parents e604aae + ec5c0ae commit e153701

6 files changed

Lines changed: 38 additions & 43 deletions

File tree

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
</p>
88

99
<p align="center">
10-
<strong>Fully HEVC-Style Vision Transformer </strong>
10+
<strong>OneVision Encoder</strong>
1111
</p>
1212

1313
## 📖 Table of Contents
@@ -29,11 +29,11 @@ OneVision Encoder is a vision encoder designed for multimodal large language mod
2929
### Input Method Comparison
3030

3131
<table>
32-
<caption style="caption-side: top; text-align: center; font-weight: bold; margin-bottom: 10px;">Comparison of Frame Sampling Input vs Codec Input</caption>
32+
<caption style="caption-side: top; text-align: center; font-weight: bold; margin-bottom: 10px;">Frame Sampling Input vs Codec Input</caption>
3333
<tr>
3434
<td align="center">
3535
<img src="pages/images/example.gif" alt="Animated demonstration of traditional uniform frame sampling method for video processing" width="400"><br>
36-
<b>抽帧输入 (Frame Sampling Input)</b><br>
36+
<b>Frame Sampling Input</b><br>
3737
Traditional uniform frame sampling approach
3838
</td>
3939
<td align="center">

dataloader/ap_dataloader_dali_ip_mv.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def _mv_energy_norm(
5555
pct: float = 95.0,
5656
):
5757
"""Return (norm_HxW_float32_in_[0,1], scale_max_px). No gamma/colormap."""
58+
if not _HAS_CV2: # fix: check cv2 availability before use
59+
raise ImportError("cv2 is required for _mv_energy_norm but not available")
5860
vx = mvx.astype(np.float32) / float(mv_unit_div)
5961
vy = mvy.astype(np.float32) / float(mv_unit_div)
6062
mag = np.sqrt(vx * vx + vy * vy) # pixels
@@ -315,7 +317,7 @@ def __call__(self, sample_info):
315317
video_path, video_label = example_info
316318
try:
317319
combined_data, duration, frame_id_list = self.get_frame_id_list(video_path, self.sequence_length)
318-
except:
320+
except Exception: # fix: avoid bare except to allow KeyboardInterrupt/SystemExit to propagate
319321
video_path, video_label = self.replace_example_info
320322
combined_data, duration, frame_id_list = self.get_frame_id_list(video_path, self.sequence_length)
321323

onevision_encoder/configuration_onevision_encoder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
attention_dropout=0.0,
7878
initializer_range=0.02,
7979
rope_theta=10000.0,
80+
rope_temporal_size=None,
8081
use_head=True,
8182
**kwargs,
8283
):
@@ -94,4 +95,5 @@ def __init__(
9495
self.attention_dropout = attention_dropout
9596
self.initializer_range = initializer_range
9697
self.rope_theta = rope_theta
98+
self.rope_temporal_size = rope_temporal_size # None=use actual frames, int=fixed size (legacy: 64)
9799
self.use_head = use_head

onevision_encoder/modeling_onevision_encoder.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,10 @@ def forward(
547547
# Determine video dimensions for RoPE
548548
# Note: pixel_values passed to embeddings can be 4D or 5D
549549
if pixel_values.dim() == 5:
550-
t_frames = 64
550+
# fix: use config.rope_temporal_size if set, otherwise use actual frames
551+
# legacy behavior was hardcoded t_frames=64 (for padded 64-frame videos)
552+
actual_frames = pixel_values.shape[2]
553+
t_frames = self.config.rope_temporal_size if self.config.rope_temporal_size else actual_frames
551554
height = pixel_values.shape[3]
552555
width = pixel_values.shape[4]
553556
else:
@@ -578,6 +581,14 @@ def forward(
578581
# 4. Pre-Norm & Encoder
579582
hidden_states = self.layernorm_pre(hidden_states)
580583

584+
# fix: gather hidden_states to match freqs_visible when using sparse visible_indices
585+
num_visible = visible_indices.shape[1]
586+
if num_visible != total_patches:
587+
# sparse mode: select only visible patches
588+
hidden_states = hidden_states.gather(
589+
1, visible_indices.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
590+
)
591+
581592
encoder_outputs = self.encoder(
582593
hidden_states,
583594
attention_mask=None,

tools/tools_for_hevc/hevc_feature_decoder_mv.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -64,36 +64,6 @@ def ffprobe(filename):
6464

6565

6666

67-
# ---------------- YUV plane parsers ----------------
68-
def _split_yuv420_planes(buf: bytes, H: int, W: int, layout: str):
69-
"""Return Y (H,W), U (H/2,W/2), V (H/2,W/2) for layout in {i420,yv12,nv12,nv21}."""
70-
nY = H*W
71-
nUV = (H//2)*(W//2)
72-
arr = np.frombuffer(buf, dtype=np.uint8)
73-
if layout in ("i420","yv12"):
74-
Y = arr[:nY].reshape(H, W)
75-
UV = arr[nY:]
76-
# planar U and V (each nUV)
77-
U_planar, V_planar = (UV[:nUV], UV[nUV:]) if layout=="i420" else (UV[nUV:], UV[:nUV])
78-
U = U_planar.reshape(H//2, W//2)
79-
V = V_planar.reshape(H//2, W//2)
80-
return Y, U, V
81-
elif layout in ("nv12","nv21"):
82-
Y = arr[:nY].reshape(H, W)
83-
UVint = arr[nY:].reshape(H//2, W) # interleaved per row: UVUV or VUVU
84-
U = np.empty((H//2, W//2), dtype=np.uint8)
85-
V = np.empty((H//2, W//2), dtype=np.uint8)
86-
if layout == "nv12": # UVUV...
87-
U[:] = UVint[:, 0::2]
88-
V[:] = UVint[:, 1::2]
89-
else: # nv21: VUVU...
90-
V[:] = UVint[:, 0::2]
91-
U[:] = UVint[:, 1::2]
92-
return Y, U, V
93-
else:
94-
raise ValueError(layout)
95-
96-
9767
# ---------------- YUV plane parsers ----------------
9868
def _split_yuv420_planes(buf: bytes, H: int, W: int, layout: str):
9969
"""Return Y (H,W), U (H/2,W/2), V (H/2,W/2) for layout in {i420,yv12,nv12,nv21}."""
@@ -564,9 +534,16 @@ def close(self):
564534
if self._proc is not None and self._proc.poll() is None:
565535
self._proc.stdin.close()
566536
self._proc.stdout.close()
567-
# self._proc.stderr.close()
537+
# stderr is redirected to DEVNULL, not a pipe
568538
self._terminate(0.2)
569539
self._proc = None
540+
# fix: close DEVNULL file handle to prevent resource leak
541+
if hasattr(self, 'DEVNULL') and self.DEVNULL is not None:
542+
try:
543+
self.DEVNULL.close()
544+
except Exception:
545+
pass
546+
self.DEVNULL = None
570547

571548
def _terminate(self, timeout=1.0):
572549
"""Terminate the sub process."""

training/train.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from training.lr_scheduler import PolynomialLRWarmup
2121
from onevision_encoder import OneVisionEncoderModel, OneVisionEncoderConfig
2222

23-
torch._dynamo.config.optimize_ddp = True
23+
# fix: removed conflicting line (was: True immediately overwritten by False)
2424
torch._dynamo.config.optimize_ddp = False
2525

2626
parser = argparse.ArgumentParser(description="Multi-dataset video training")
@@ -657,12 +657,15 @@ def wrap_ddp(model):
657657

658658
# 按 batch 固定划分:前50% residual, 中37.5% frame_sampling, 后12.5% collage
659659
n1 = int(bs * 0.5)
660-
n2 = int(bs * 0.375)
660+
# fix: n2 must be cumulative threshold, not standalone percentage
661+
# bug was: n2 = int(bs * 0.375) which gives n2=37 when bs=100
662+
# this caused mask_frame_sampling = (idx >= 50) & (idx < 37) to be always False
663+
n2 = int(bs * 0.875) # cumulative: 50% + 37.5% = 87.5%
661664

662665
idx_range = torch.arange(bs, device=dev)
663-
mask_residual = idx_range < n1
664-
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2)
665-
mask_collage = idx_range >= n2
666+
mask_residual = idx_range < n1 # idx in [0, n1)
667+
mask_frame_sampling = (idx_range >= n1) & (idx_range < n2) # idx in [n1, n2)
668+
mask_collage = idx_range >= n2 # idx in [n2, bs)
666669

667670
# ---------- residual(前50%): 生成 out 行 ----------
668671
if mask_residual.any():
@@ -831,8 +834,8 @@ def wrap_ddp(model):
831834
opt.step()
832835
opt.zero_grad()
833836

834-
# 学习率更新
835-
lr_scheduler.step()
837+
# fix: lr update should only happen after opt.step(), not every micro-batch
838+
lr_scheduler.step()
836839

837840
batch_end_callback(
838841
global_step=global_step,

0 commit comments

Comments
 (0)