Skip to content

Commit 9d54223

Browse files
committed
feat: add optional patch_positions parameter to forward method for improved flexibility
1 parent 6aeeec7 commit 9d54223

1 file changed

Lines changed: 11 additions & 7 deletions

File tree

onevision_encoder/modeling_onevision_encoder.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ def __init__(self, config: OneVisionEncoderConfig):
562562
def forward(
563563
self,
564564
pixel_values: torch.Tensor,
565+
patch_postions: Optional[torch.Tensor] = None,
565566
visible_indices: Optional[torch.Tensor] = None,
566567
output_attentions: Optional[bool] = None,
567568
output_hidden_states: Optional[bool] = None,
@@ -616,13 +617,16 @@ def forward(
616617
)
617618

618619
# 3. RoPE Construction
619-
freqs_full = self.video_rope(
620-
t=t_frames,
621-
h=height // self.config.patch_size,
622-
w=width // self.config.patch_size,
623-
device=pixel_values.device,
624-
)
625-
freqs_visible = freqs_full[visible_indices]
620+
if patch_postions is not None:
621+
freqs_visible = self.video_rope.forward_from_positions(patch_postions)
622+
else:
623+
freqs_full = self.video_rope(
624+
t=t_frames,
625+
h=height // self.config.patch_size,
626+
w=width // self.config.patch_size,
627+
device=pixel_values.device,
628+
)
629+
freqs_visible = freqs_full[visible_indices]
626630

627631
# Concatenate D/2 + D/2 -> D for applying rope
628632
freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)

0 commit comments

Comments
 (0)