11from typing import Optional
22
3+ import numpy as np
34import torch
45from transformers import Qwen2_5_VLModel , Qwen3VLModel
56
@@ -11,98 +12,27 @@ def qwen3_vl_get_rope_index(
1112 video_grid_thw : Optional [torch .LongTensor ] = None ,
1213 attention_mask : Optional [torch .Tensor ] = None ,
1314) -> tuple [torch .Tensor , torch .Tensor ]:
14- """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
15+ """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.
1516
16- # Since we use timestamps to seperate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split
17- if video_grid_thw is not None :
18- video_grid_thw = torch .repeat_interleave (video_grid_thw , video_grid_thw [:, 0 ], dim = 0 )
19- video_grid_thw [:, 0 ] = 1
17+ Performance notes: the trivial Python-loop port of the upstream layout
18+ triggers O(B + N_vision) device syncs per call (input_ids.tolist(),
19+ repeat_interleave on a GPU repeats tensor, t.item()/h.item()/w.item()
20+ per image/video, llm_pos_ids_list[-1].max() per span, ...). On large
21+ multimodal sequences this can cost tens of ms per step.
2022
23+ This implementation pulls input_ids + grid_thw to host once at entry, then
24+ builds the (3, n_valid) position tensor for each row with numpy
25+ slice-assigns (avoiding per-token list.append + torch.tensor() conversion),
26+ and copies it back with a single H2D per row. The trivial no-vision
27+ branch is unchanged.
28+ """
2129 spatial_merge_size = self .config .vision_config .spatial_merge_size
2230 image_token_id = self .config .image_token_id
2331 video_token_id = self .config .video_token_id
2432 vision_start_token_id = self .config .vision_start_token_id
25- mrope_position_deltas = []
26- if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None ):
27- total_input_ids = input_ids
28- if attention_mask is None :
29- attention_mask = torch .ones_like (total_input_ids )
30- position_ids = torch .ones (
31- 3 ,
32- input_ids .shape [0 ],
33- input_ids .shape [1 ],
34- dtype = input_ids .dtype ,
35- device = input_ids .device ,
36- )
37- image_index , video_index = 0 , 0
38- attention_mask = attention_mask .to (total_input_ids .device )
39- for i , input_ids in enumerate (total_input_ids ):
40- input_ids = input_ids [attention_mask [i ] == 1 ]
41- image_nums , video_nums = 0 , 0
42- vision_start_indices = torch .argwhere (input_ids == vision_start_token_id ).squeeze (1 )
43- vision_tokens = input_ids [vision_start_indices + 1 ]
44- image_nums = (vision_tokens == image_token_id ).sum ()
45- video_nums = (vision_tokens == video_token_id ).sum ()
46- input_tokens = input_ids .tolist ()
47- llm_pos_ids_list : list = []
48- st = 0
49- remain_images , remain_videos = image_nums , video_nums
50- for _ in range (image_nums + video_nums ):
51- if image_token_id in input_tokens and remain_images > 0 :
52- ed_image = input_tokens .index (image_token_id , st )
53- else :
54- ed_image = len (input_tokens ) + 1
55- if video_token_id in input_tokens and remain_videos > 0 :
56- ed_video = input_tokens .index (video_token_id , st )
57- else :
58- ed_video = len (input_tokens ) + 1
59- if ed_image < ed_video :
60- t , h , w = (
61- image_grid_thw [image_index ][0 ],
62- image_grid_thw [image_index ][1 ],
63- image_grid_thw [image_index ][2 ],
64- )
65- image_index += 1
66- remain_images -= 1
67- ed = ed_image
6833
69- else :
70- t , h , w = (
71- video_grid_thw [video_index ][0 ],
72- video_grid_thw [video_index ][1 ],
73- video_grid_thw [video_index ][2 ],
74- )
75- video_index += 1
76- remain_videos -= 1
77- ed = ed_video
78- llm_grid_t , llm_grid_h , llm_grid_w = (
79- t .item (),
80- h .item () // spatial_merge_size ,
81- w .item () // spatial_merge_size ,
82- )
83- text_len = ed - st
84-
85- st_idx = llm_pos_ids_list [- 1 ].max () + 1 if len (llm_pos_ids_list ) > 0 else 0
86- llm_pos_ids_list .append (torch .arange (text_len ).view (1 , - 1 ).expand (3 , - 1 ) + st_idx )
87-
88- # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
89- t_index = torch .arange (llm_grid_t ).view (- 1 , 1 ).expand (- 1 , llm_grid_h * llm_grid_w ).flatten ()
90- h_index = torch .arange (llm_grid_h ).view (1 , - 1 , 1 ).expand (llm_grid_t , - 1 , llm_grid_w ).flatten ()
91- w_index = torch .arange (llm_grid_w ).view (1 , 1 , - 1 ).expand (llm_grid_t , llm_grid_h , - 1 ).flatten ()
92- llm_pos_ids_list .append (torch .stack ([t_index , h_index , w_index ]) + text_len + st_idx )
93- st = ed + llm_grid_t * llm_grid_h * llm_grid_w
94-
95- if st < len (input_tokens ):
96- st_idx = llm_pos_ids_list [- 1 ].max () + 1 if len (llm_pos_ids_list ) > 0 else 0
97- text_len = len (input_tokens ) - st
98- llm_pos_ids_list .append (torch .arange (text_len ).view (1 , - 1 ).expand (3 , - 1 ) + st_idx )
99-
100- llm_positions = torch .cat (llm_pos_ids_list , dim = 1 ).reshape (3 , - 1 )
101- position_ids [..., i , attention_mask [i ] == 1 ] = llm_positions .to (position_ids .device )
102- mrope_position_deltas .append (llm_positions .max () + 1 - len (total_input_ids [i ]))
103- mrope_position_deltas = torch .tensor (mrope_position_deltas , device = input_ids .device ).unsqueeze (1 )
104- return position_ids , mrope_position_deltas
105- else :
34+ has_vision = input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None )
35+ if not has_vision :
10636 if attention_mask is not None :
10737 position_ids = attention_mask .long ().cumsum (- 1 ) - 1
10838 position_ids .masked_fill_ (attention_mask == 0 , 1 )
@@ -120,9 +50,115 @@ def qwen3_vl_get_rope_index(
12050 device = input_ids .device ,
12151 dtype = input_ids .dtype ,
12252 )
123-
12453 return position_ids , mrope_position_deltas
12554
55+ device = input_ids .device
56+ dtype = input_ids .dtype
57+
58+ # One-shot device -> host pull. After this point everything is plain
59+ # numpy / Python int — no further device syncs inside the hot loop.
60+ input_ids_np = input_ids .detach ().cpu ().numpy ()
61+ if attention_mask is not None :
62+ attention_mask_np = attention_mask .detach ().to (input_ids .device ).cpu ().numpy ()
63+ else :
64+ attention_mask_np = np .ones_like (input_ids_np )
65+
66+ if image_grid_thw is not None :
67+ image_thw_np = image_grid_thw .detach ().cpu ().numpy ()
68+ else :
69+ image_thw_np = np .zeros ((0 , 3 ), dtype = np .int64 )
70+ if video_grid_thw is not None :
71+ # Mirror the upstream repeat_interleave on T axis + set T=1: each
72+ # frame of a video is treated as an independent (1, H, W) entry so
73+ # video timestamps carry the temporal position.
74+ video_thw_np = video_grid_thw .detach ().cpu ().numpy ()
75+ if video_thw_np .shape [0 ] > 0 :
76+ video_thw_np = np .repeat (video_thw_np , video_thw_np [:, 0 ], axis = 0 )
77+ video_thw_np [:, 0 ] = 1
78+ else :
79+ video_thw_np = np .zeros ((0 , 3 ), dtype = np .int64 )
80+
81+ B , S = input_ids .shape
82+ position_ids = torch .ones (3 , B , S , dtype = dtype , device = device )
83+ mrope_position_deltas : list [int ] = []
84+
85+ image_index = 0
86+ video_index = 0
87+ for i in range (B ):
88+ row = input_ids_np [i ]
89+ mask_row_bool = attention_mask_np [i ].astype (bool )
90+ valid_tokens = row [mask_row_bool ]
91+ n_valid = int (valid_tokens .shape [0 ])
92+
93+ # Vectorized scan: find all <vision_start> followed by image_pad or video_pad.
94+ vstart_positions = np .flatnonzero (valid_tokens == vision_start_token_id )
95+ vstart_positions = vstart_positions [vstart_positions + 1 < n_valid ]
96+ next_tokens = valid_tokens [vstart_positions + 1 ]
97+ is_image = next_tokens == image_token_id
98+ is_video = next_tokens == video_token_id
99+ keep = is_image | is_video
100+ vstart_positions = vstart_positions [keep ]
101+ is_video_kept = is_video [keep ]
102+ n_spans = int (vstart_positions .shape [0 ])
103+
104+ out = np .empty ((3 , n_valid ), dtype = np .int64 )
105+ st = 0
106+ last_max = - 1 # so st_idx = last_max + 1 starts at 0
107+ for s_idx in range (n_spans ):
108+ # span_start is the first vision token after <vision_start>.
109+ span_start = int (vstart_positions [s_idx ]) + 1
110+ if is_video_kept [s_idx ]:
111+ t , h , w = video_thw_np [video_index ]
112+ video_index += 1
113+ else :
114+ t , h , w = image_thw_np [image_index ]
115+ image_index += 1
116+ llm_grid_t = int (t )
117+ llm_grid_h = int (h ) // spatial_merge_size
118+ llm_grid_w = int (w ) // spatial_merge_size
119+ text_len = span_start - st
120+ st_idx = last_max + 1
121+ base = text_len + st_idx
122+
123+ if text_len > 0 :
124+ run = np .arange (st_idx , st_idx + text_len , dtype = np .int64 )
125+ out [0 , st : st + text_len ] = run
126+ out [1 , st : st + text_len ] = run
127+ out [2 , st : st + text_len ] = run
128+
129+ v_len = llm_grid_t * llm_grid_h * llm_grid_w
130+ v_start = st + text_len
131+ if v_len > 0 :
132+ t_axis = np .repeat (np .arange (llm_grid_t , dtype = np .int64 ), llm_grid_h * llm_grid_w ) + base
133+ h_axis = np .tile (np .repeat (np .arange (llm_grid_h , dtype = np .int64 ), llm_grid_w ), llm_grid_t ) + base
134+ w_axis = np .tile (np .arange (llm_grid_w , dtype = np .int64 ), llm_grid_t * llm_grid_h ) + base
135+ out [0 , v_start : v_start + v_len ] = t_axis
136+ out [1 , v_start : v_start + v_len ] = h_axis
137+ out [2 , v_start : v_start + v_len ] = w_axis
138+
139+ st = v_start + v_len
140+ last_max = base + max (llm_grid_t , llm_grid_h , llm_grid_w ) - 1
141+
142+ if st < n_valid :
143+ text_len = n_valid - st
144+ st_idx = last_max + 1
145+ run = np .arange (st_idx , st_idx + text_len , dtype = np .int64 )
146+ out [0 , st :n_valid ] = run
147+ out [1 , st :n_valid ] = run
148+ out [2 , st :n_valid ] = run
149+ last_max = st_idx + text_len - 1
150+
151+ # Single H2D for this row's positions.
152+ llm_positions = torch .from_numpy (out ).to (device = device , dtype = dtype , non_blocking = True )
153+ if attention_mask is not None :
154+ position_ids [:, i , attention_mask [i ] == 1 ] = llm_positions
155+ else :
156+ position_ids [:, i , :] = llm_positions
157+ mrope_position_deltas .append (last_max + 1 - int (S ))
158+
159+ mrope_position_deltas = torch .tensor (mrope_position_deltas , device = device , dtype = dtype ).unsqueeze (1 )
160+ return position_ids , mrope_position_deltas
161+
126162
127163def qwen2_5_vl_rope_index (
128164 self : Qwen2_5_VLModel ,
0 commit comments