@@ -562,6 +562,7 @@ def __init__(self, config: OneVisionEncoderConfig):
562562 def forward (
563563 self ,
564564 pixel_values : torch .Tensor ,
565+ patch_postions : Optional [torch .Tensor ] = None ,
565566 visible_indices : Optional [torch .Tensor ] = None ,
566567 output_attentions : Optional [bool ] = None ,
567568 output_hidden_states : Optional [bool ] = None ,
@@ -616,13 +617,16 @@ def forward(
616617 )
617618
618619 # 3. RoPE Construction
619- freqs_full = self .video_rope (
620- t = t_frames ,
621- h = height // self .config .patch_size ,
622- w = width // self .config .patch_size ,
623- device = pixel_values .device ,
624- )
625- freqs_visible = freqs_full [visible_indices ]
620+ if patch_postions is not None :
621+ freqs_visible = self .video_rope .forward_from_positions (patch_postions )
622+ else :
623+ freqs_full = self .video_rope (
624+ t = t_frames ,
625+ h = height // self .config .patch_size ,
626+ w = width // self .config .patch_size ,
627+ device = pixel_values .device ,
628+ )
629+ freqs_visible = freqs_full [visible_indices ]
626630
627631 # Concatenate D/2 + D/2 -> D for applying rope
628632 freqs_visible = torch .cat ([freqs_visible , freqs_visible ], dim = - 1 )
0 commit comments