Skip to content

Commit e7cdc94

Browse files
authored
perf(rope): vectorize qwen3_vl_get_rope_index, drop per-token sync (#184)
1 parent d0617d5 commit e7cdc94

1 file changed

Lines changed: 122 additions & 86 deletions

File tree

  • src/lmms_engine/models/common_ops

src/lmms_engine/models/common_ops/rope.py

Lines changed: 122 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22

3+
import numpy as np
34
import torch
45
from transformers import Qwen2_5_VLModel, Qwen3VLModel
56

@@ -11,98 +12,27 @@ def qwen3_vl_get_rope_index(
1112
video_grid_thw: Optional[torch.LongTensor] = None,
1213
attention_mask: Optional[torch.Tensor] = None,
1314
) -> tuple[torch.Tensor, torch.Tensor]:
14-
"""Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
15+
"""Different from the original implementation, Qwen3VL uses timestamps rather than absolute time position ids.
1516
16-
# Since we use timestamps to separate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split
17-
if video_grid_thw is not None:
18-
video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
19-
video_grid_thw[:, 0] = 1
17+
Performance notes: the trivial Python-loop port of the upstream layout
18+
triggers O(B + N_vision) device syncs per call (input_ids.tolist(),
19+
repeat_interleave on a GPU repeats tensor, t.item()/h.item()/w.item()
20+
per image/video, llm_pos_ids_list[-1].max() per span, ...). On large
21+
multimodal sequences this can cost tens of ms per step.
2022
23+
This implementation pulls input_ids + grid_thw to host once at entry, then
24+
builds the (3, n_valid) position tensor for each row with numpy
25+
slice-assigns (avoiding per-token list.append + torch.tensor() conversion),
26+
and copies it back with a single H2D per row. The trivial no-vision
27+
branch is unchanged.
28+
"""
2129
spatial_merge_size = self.config.vision_config.spatial_merge_size
2230
image_token_id = self.config.image_token_id
2331
video_token_id = self.config.video_token_id
2432
vision_start_token_id = self.config.vision_start_token_id
25-
mrope_position_deltas = []
26-
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
27-
total_input_ids = input_ids
28-
if attention_mask is None:
29-
attention_mask = torch.ones_like(total_input_ids)
30-
position_ids = torch.ones(
31-
3,
32-
input_ids.shape[0],
33-
input_ids.shape[1],
34-
dtype=input_ids.dtype,
35-
device=input_ids.device,
36-
)
37-
image_index, video_index = 0, 0
38-
attention_mask = attention_mask.to(total_input_ids.device)
39-
for i, input_ids in enumerate(total_input_ids):
40-
input_ids = input_ids[attention_mask[i] == 1]
41-
image_nums, video_nums = 0, 0
42-
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
43-
vision_tokens = input_ids[vision_start_indices + 1]
44-
image_nums = (vision_tokens == image_token_id).sum()
45-
video_nums = (vision_tokens == video_token_id).sum()
46-
input_tokens = input_ids.tolist()
47-
llm_pos_ids_list: list = []
48-
st = 0
49-
remain_images, remain_videos = image_nums, video_nums
50-
for _ in range(image_nums + video_nums):
51-
if image_token_id in input_tokens and remain_images > 0:
52-
ed_image = input_tokens.index(image_token_id, st)
53-
else:
54-
ed_image = len(input_tokens) + 1
55-
if video_token_id in input_tokens and remain_videos > 0:
56-
ed_video = input_tokens.index(video_token_id, st)
57-
else:
58-
ed_video = len(input_tokens) + 1
59-
if ed_image < ed_video:
60-
t, h, w = (
61-
image_grid_thw[image_index][0],
62-
image_grid_thw[image_index][1],
63-
image_grid_thw[image_index][2],
64-
)
65-
image_index += 1
66-
remain_images -= 1
67-
ed = ed_image
6833

69-
else:
70-
t, h, w = (
71-
video_grid_thw[video_index][0],
72-
video_grid_thw[video_index][1],
73-
video_grid_thw[video_index][2],
74-
)
75-
video_index += 1
76-
remain_videos -= 1
77-
ed = ed_video
78-
llm_grid_t, llm_grid_h, llm_grid_w = (
79-
t.item(),
80-
h.item() // spatial_merge_size,
81-
w.item() // spatial_merge_size,
82-
)
83-
text_len = ed - st
84-
85-
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
86-
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
87-
88-
# t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
89-
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
90-
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
91-
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
92-
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
93-
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
94-
95-
if st < len(input_tokens):
96-
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
97-
text_len = len(input_tokens) - st
98-
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
99-
100-
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
101-
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
102-
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
103-
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
104-
return position_ids, mrope_position_deltas
105-
else:
34+
has_vision = input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None)
35+
if not has_vision:
10636
if attention_mask is not None:
10737
position_ids = attention_mask.long().cumsum(-1) - 1
10838
position_ids.masked_fill_(attention_mask == 0, 1)
@@ -120,9 +50,115 @@ def qwen3_vl_get_rope_index(
12050
device=input_ids.device,
12151
dtype=input_ids.dtype,
12252
)
123-
12453
return position_ids, mrope_position_deltas
12554

55+
device = input_ids.device
56+
dtype = input_ids.dtype
57+
58+
# One-shot device -> host pull. After this point everything is plain
59+
# numpy / Python int — no further device syncs inside the hot loop.
60+
input_ids_np = input_ids.detach().cpu().numpy()
61+
if attention_mask is not None:
62+
attention_mask_np = attention_mask.detach().to(input_ids.device).cpu().numpy()
63+
else:
64+
attention_mask_np = np.ones_like(input_ids_np)
65+
66+
if image_grid_thw is not None:
67+
image_thw_np = image_grid_thw.detach().cpu().numpy()
68+
else:
69+
image_thw_np = np.zeros((0, 3), dtype=np.int64)
70+
if video_grid_thw is not None:
71+
# Mirror the upstream repeat_interleave on T axis + set T=1: each
72+
# frame of a video is treated as an independent (1, H, W) entry so
73+
# video timestamps carry the temporal position.
74+
video_thw_np = video_grid_thw.detach().cpu().numpy()
75+
if video_thw_np.shape[0] > 0:
76+
video_thw_np = np.repeat(video_thw_np, video_thw_np[:, 0], axis=0)
77+
video_thw_np[:, 0] = 1
78+
else:
79+
video_thw_np = np.zeros((0, 3), dtype=np.int64)
80+
81+
B, S = input_ids.shape
82+
position_ids = torch.ones(3, B, S, dtype=dtype, device=device)
83+
mrope_position_deltas: list[int] = []
84+
85+
image_index = 0
86+
video_index = 0
87+
for i in range(B):
88+
row = input_ids_np[i]
89+
mask_row_bool = attention_mask_np[i].astype(bool)
90+
valid_tokens = row[mask_row_bool]
91+
n_valid = int(valid_tokens.shape[0])
92+
93+
# Vectorized scan: find all <vision_start> followed by image_pad or video_pad.
94+
vstart_positions = np.flatnonzero(valid_tokens == vision_start_token_id)
95+
vstart_positions = vstart_positions[vstart_positions + 1 < n_valid]
96+
next_tokens = valid_tokens[vstart_positions + 1]
97+
is_image = next_tokens == image_token_id
98+
is_video = next_tokens == video_token_id
99+
keep = is_image | is_video
100+
vstart_positions = vstart_positions[keep]
101+
is_video_kept = is_video[keep]
102+
n_spans = int(vstart_positions.shape[0])
103+
104+
out = np.empty((3, n_valid), dtype=np.int64)
105+
st = 0
106+
last_max = -1 # so st_idx = last_max + 1 starts at 0
107+
for s_idx in range(n_spans):
108+
# span_start is the first vision token after <vision_start>.
109+
span_start = int(vstart_positions[s_idx]) + 1
110+
if is_video_kept[s_idx]:
111+
t, h, w = video_thw_np[video_index]
112+
video_index += 1
113+
else:
114+
t, h, w = image_thw_np[image_index]
115+
image_index += 1
116+
llm_grid_t = int(t)
117+
llm_grid_h = int(h) // spatial_merge_size
118+
llm_grid_w = int(w) // spatial_merge_size
119+
text_len = span_start - st
120+
st_idx = last_max + 1
121+
base = text_len + st_idx
122+
123+
if text_len > 0:
124+
run = np.arange(st_idx, st_idx + text_len, dtype=np.int64)
125+
out[0, st : st + text_len] = run
126+
out[1, st : st + text_len] = run
127+
out[2, st : st + text_len] = run
128+
129+
v_len = llm_grid_t * llm_grid_h * llm_grid_w
130+
v_start = st + text_len
131+
if v_len > 0:
132+
t_axis = np.repeat(np.arange(llm_grid_t, dtype=np.int64), llm_grid_h * llm_grid_w) + base
133+
h_axis = np.tile(np.repeat(np.arange(llm_grid_h, dtype=np.int64), llm_grid_w), llm_grid_t) + base
134+
w_axis = np.tile(np.arange(llm_grid_w, dtype=np.int64), llm_grid_t * llm_grid_h) + base
135+
out[0, v_start : v_start + v_len] = t_axis
136+
out[1, v_start : v_start + v_len] = h_axis
137+
out[2, v_start : v_start + v_len] = w_axis
138+
139+
st = v_start + v_len
140+
last_max = base + max(llm_grid_t, llm_grid_h, llm_grid_w) - 1
141+
142+
if st < n_valid:
143+
text_len = n_valid - st
144+
st_idx = last_max + 1
145+
run = np.arange(st_idx, st_idx + text_len, dtype=np.int64)
146+
out[0, st:n_valid] = run
147+
out[1, st:n_valid] = run
148+
out[2, st:n_valid] = run
149+
last_max = st_idx + text_len - 1
150+
151+
# Single H2D for this row's positions.
152+
llm_positions = torch.from_numpy(out).to(device=device, dtype=dtype, non_blocking=True)
153+
if attention_mask is not None:
154+
position_ids[:, i, attention_mask[i] == 1] = llm_positions
155+
else:
156+
position_ids[:, i, :] = llm_positions
157+
mrope_position_deltas.append(last_max + 1 - int(S))
158+
159+
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=device, dtype=dtype).unsqueeze(1)
160+
return position_ids, mrope_position_deltas
161+
126162

127163
def qwen2_5_vl_rope_index(
128164
self: Qwen2_5_VLModel,

0 commit comments

Comments
 (0)