Skip to content

Commit 3489cf2

Browse files
[None][perf] Adopt vLLM's fused Triton bilinear pos-embed kernel for Qwen3-VL
Port ``_bilinear_pos_embed_kernel`` + ``_triton_pos_embed_interpolate`` from vLLM's qwen3_vl.py (commit 21eb2c33) and route ``Qwen3VisionModel.fast_pos_embed_interpolate`` through it when Triton is available.

The kernel fuses the per-(t, h, w) bilinear interpolation of the learned position embedding with the spatial-merge reorder, replacing the prior PyTorch chain of ``torch.linspace`` + ``torch.meshgrid`` + embedding gather + scatter + ``permute`` with a single launch. The existing PyTorch implementation is retained as ``_pos_embed_interpolate_native`` and is used when Triton is unavailable or the embedding weight is not on CUDA.

Microbench (H200, bf16, NUM_GRID=48, HIDDEN=1152, vLLM-realistic std=0.25 weights):

| (t, h, w)     | tokens | native μs | triton μs | speedup |
|---------------|-------:|----------:|----------:|--------:|
| (1, 16, 16)   |    256 |     171.7 |      18.7 |   9.16x |
| (1, 32, 32)   |   1024 |     170.6 |      19.0 |   9.00x |
| (1, 48, 48)   |   2304 |     175.2 |      19.3 |   9.08x |
| (1, 64, 64)   |   4096 |     173.3 |      19.1 |   9.09x |
| (1, 100, 100) |  10000 |     289.9 |      21.7 |  13.39x |
| (1, 128, 128) |  16384 |     437.3 |      21.6 |  20.25x |

Accuracy: a new unit test ``tests/unittest/_torch/modeling/test_vit_bilinear_pos_embed.py`` mirrors vLLM's reference test (``tests/kernels/core/test_vit_bilinear_pos_embed.py``) across 8 representative (t, h, w) grids and bf16 / fp32. Triton matches native within ``atol = rtol = 1e-2`` (bf16) and ``atol = 5e-5, rtol = 1e-5`` (fp32) -- the same tolerances vLLM uses. Differences come from the precomputed scalar h/w_scale in the Triton path vs ``torch.linspace`` in the native path and stay at single-ULP level.

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
1 parent d8fb364 commit 3489cf2

2 files changed

Lines changed: 326 additions & 63 deletions

File tree

tensorrt_llm/_torch/models/modeling_qwen3vl.py

Lines changed: 147 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import torch
66
import torch.nn as nn
7+
import triton
8+
import triton.language as tl
79
from PIL import Image
810
from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
911
from transformers.activations import ACT2FN as HF_ACT2FN
@@ -609,6 +611,140 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
609611
return hidden_states
610612

611613

614+
# ---------------------------------------------------------------------------
615+
# Fused bilinear position-embedding interpolation for the Qwen3-VL vision
616+
# tower. Ported from vLLM
617+
# (vllm/model_executor/models/qwen3_vl.py @ 21eb2c33).
618+
#
619+
# ``fast_pos_embed_interpolate`` resamples the learned grid of positional
620+
# embeddings onto each (t, h, w) image grid using bilinear interpolation and
621+
# then reorders the spatial axis to match the spatial-merge layout used by
622+
# the rest of the vision tower. The Triton path fuses the
623+
# bilinear-interp + spatial-merge reorder into a single kernel so the
624+
# embedding gather, the 4 corner reads, and the permute are all one fused
625+
# pass instead of separate ops.
626+
# ---------------------------------------------------------------------------
627+
628+
629+
@triton.jit
def _bilinear_pos_embed_kernel(
    embed_ptr,
    output_ptr,
    H,
    W,
    h_scale,
    w_scale,
    NUM_GRID: tl.constexpr,
    M_SIZE: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Fused bilinear pos-embed interpolation with spatial-merge reorder.

    Each program writes one output row of ``HIDDEN_DIM`` elements. The
    program id is the output row index; ``pid % (H * W)`` selects the spatial
    position, so every frame of a (t, H, W) grid receives the same values.

    Args:
        embed_ptr: Embedding table, addressed as a row-major
            ``(NUM_GRID * NUM_GRID, HIDDEN_DIM)`` array with row stride
            ``HIDDEN_DIM``.
        output_ptr: Output buffer, addressed as rows of ``HIDDEN_DIM``
            elements; its element type is the accumulation dtype.
        H, W: Target grid height and width.
        h_scale, w_scale: Multipliers mapping a target row/column index to a
            fractional coordinate in the source grid.
        NUM_GRID: Side length of the source embedding grid.
        M_SIZE: Spatial-merge block size.
        HIDDEN_DIM: Elements per embedding row.
        BLOCK_D: Number of elements processed per loop iteration.
    """
    pid = tl.program_id(0)
    total_spatial = H * W
    spatial_idx = pid % total_spatial

    # Output rows are laid out as (block_row, block_col, local_row, local_col);
    # decode that order back to the (row, col) position in the H x W grid.
    num_blocks_w = W // M_SIZE
    block_idx = spatial_idx // (M_SIZE * M_SIZE)
    local_idx = spatial_idx % (M_SIZE * M_SIZE)
    br = block_idx // num_blocks_w
    bc = block_idx % num_blocks_w
    lr = local_idx // M_SIZE
    lc = local_idx % M_SIZE
    row = br * M_SIZE + lr
    col = bc * M_SIZE + lc

    # Fractional source-grid coordinates of this output position.
    h_frac = row.to(tl.float32) * h_scale
    w_frac = col.to(tl.float32) * w_scale

    # Floor / ceil neighbours, with the ceil clamped to the last grid cell.
    hf = tl.math.floor(h_frac).to(tl.int32)
    wf = tl.math.floor(w_frac).to(tl.int32)
    hc = tl.minimum(hf + 1, NUM_GRID - 1)
    wc = tl.minimum(wf + 1, NUM_GRID - 1)

    # Bilinear weights; w11 is reused so dh * dw is computed only once.
    dh = h_frac - hf.to(tl.float32)
    dw = w_frac - wf.to(tl.float32)
    w11 = dh * dw
    w10 = dh - w11
    w01 = dw - w11
    w00 = 1.0 - dh - w01

    off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM
    off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM
    off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM
    off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM
    # Widen before multiplying: a 32-bit ``pid * HIDDEN_DIM`` wraps once
    # t * H * W * HIDDEN_DIM reaches 2**31 (e.g. long videos), which would
    # make the store below write to the wrong address.
    out_off = pid.to(tl.int64) * HIDDEN_DIM

    # Cast weights to output dtype so the multiply-accumulate stays in
    # the same precision as the reference PyTorch implementation used in
    # the unit test.
    out_dtype = output_ptr.dtype.element_ty
    w00_c = w00.to(out_dtype)
    w01_c = w01.to(out_dtype)
    w10_c = w10.to(out_dtype)
    w11_c = w11.to(out_dtype)

    for d in tl.range(0, HIDDEN_DIM, BLOCK_D):
        cols = d + tl.arange(0, BLOCK_D)
        # BLOCK_D may exceed HIDDEN_DIM; mask off the out-of-range tail.
        mask = cols < HIDDEN_DIM

        e00 = tl.load(embed_ptr + off00 + cols, mask=mask)
        e01 = tl.load(embed_ptr + off01 + cols, mask=mask)
        e10 = tl.load(embed_ptr + off10 + cols, mask=mask)
        e11 = tl.load(embed_ptr + off11 + cols, mask=mask)

        val = w00_c * e00 + w01_c * e01 + w10_c * e10 + w11_c * e11

        tl.store(output_ptr + out_off + cols, val, mask=mask)
699+
700+
701+
def _triton_pos_embed_interpolate(
    embed_weight: torch.Tensor,
    t: int,
    h: int,
    w: int,
    num_grid_per_side: int,
    m_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Launch the fused Triton kernel for one (t, h, w) grid.

    Args:
        embed_weight: 2-D embedding table; ``shape[1]`` is the hidden size.
        t: Number of frames; the spatial result is produced once per frame.
        h: Target grid height; must be divisible by ``m_size``.
        w: Target grid width; must be divisible by ``m_size``.
        num_grid_per_side: Side length of the source embedding grid.
        m_size: Spatial-merge block size.
        dtype: Element type of the returned tensor.

    Returns:
        A tensor of shape ``(t * h * w, hidden_dim)`` on ``embed_weight``'s
        device with the bilinearly-interpolated position embeddings already
        in spatial-merge order.

    Raises:
        AssertionError: If ``h`` or ``w`` is not divisible by ``m_size``.
    """
    assert h % m_size == 0 and w % m_size == 0, (
        f"h={h} and w={w} must be divisible by m_size={m_size}"
    )
    hidden_dim = embed_weight.shape[1]
    total_out = t * h * w
    output = torch.empty(
        total_out,
        hidden_dim,
        device=embed_weight.device,
        dtype=dtype,
    )
    if total_out == 0:
        # Nothing to compute: skip kernel compilation and launch entirely.
        return output

    # The kernel addresses the table with a fixed row stride of hidden_dim,
    # so a strided view would be read incorrectly. This is a no-op when the
    # weight is already contiguous.
    embed_weight = embed_weight.contiguous()

    # Map target indices 0..h-1 (resp. 0..w-1) onto 0..num_grid_per_side-1;
    # a single row/column always samples source index 0.
    h_scale = float(num_grid_per_side - 1) / float(h - 1) if h > 1 else 0.0
    w_scale = float(num_grid_per_side - 1) / float(w - 1) if w > 1 else 0.0

    BLOCK_D = triton.next_power_of_2(hidden_dim)

    # One program per output row.
    _bilinear_pos_embed_kernel[(total_out,)](
        embed_weight,
        output,
        h,
        w,
        h_scale,
        w_scale,
        num_grid_per_side,
        m_size,
        hidden_dim,
        BLOCK_D,
    )
    return output
746+
747+
612748
class Qwen3VisionModel(torch.nn.Module):
613749
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
614750
super().__init__()
@@ -698,70 +834,19 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
698834

699835
# Adopted from https://github.com/vllm-project/vllm/blob/21eb2c3372fb6447ef36bee44ff7af79a330ffec/vllm/model_executor/models/qwen3_vl.py#L470
700836
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
    """Interpolate the learned position embeddings onto each (t, h, w) grid.

    For every grid the ``num_grid_per_side`` x ``num_grid_per_side``
    embedding table is bilinearly resampled to ``h`` x ``w``, reordered into
    spatial-merge blocks of ``spatial_merge_size`` x ``spatial_merge_size``,
    and repeated for each of the ``t`` frames.

    When the embedding weight lives on a CUDA device the fused Triton kernel
    is used; otherwise the computation falls back to plain PyTorch ops so
    the method keeps working for CPU-resident weights.

    Args:
        grid_thw: One ``(t, h, w)`` triple per image or video.

    Returns:
        A tensor of shape ``(sum(t * h * w), hidden_dim)`` holding the
        per-grid results concatenated along dim 0.
    """
    embed_weight = self.pos_embed.weight
    num_grid_per_side = self.num_grid_per_side
    m_size = self.spatial_merge_size
    dtype = embed_weight.dtype

    if embed_weight.is_cuda:
        return torch.cat(
            [
                _triton_pos_embed_interpolate(
                    embed_weight, t, h, w, num_grid_per_side, m_size, dtype
                )
                for t, h, w in grid_thw
            ],
            dim=0,
        )

    # Native PyTorch fallback: the Triton kernel can only run on CUDA tensors.
    hidden_dim = embed_weight.shape[1]
    device = embed_weight.device
    outputs = []
    for t, h, w in grid_thw:
        h_idxs = torch.linspace(
            0, num_grid_per_side - 1, h, dtype=torch.float32, device=device
        )
        w_idxs = torch.linspace(
            0, num_grid_per_side - 1, w, dtype=torch.float32, device=device
        )

        # Indices are non-negative, so integer truncation acts as floor.
        h_floor = h_idxs.to(torch.long)
        w_floor = w_idxs.to(torch.long)
        h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
        w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

        dh = h_idxs - h_floor
        dw = w_idxs - w_floor

        # Broadcast the per-row / per-column quantities over the h x w grid.
        dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
        h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
        h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")

        # Equivalent to the textbook bilinear weights
        #   w00 = (1 - dh) * (1 - dw),  w01 = (1 - dh) * dw,
        #   w10 = dh * (1 - dw),        w11 = dh * dw,
        # but w11 is reused so dh * dw is computed only once.
        w11 = dh_grid * dw_grid
        w10 = dh_grid - w11
        w01 = dw_grid - w11
        w00 = 1 - dh_grid - w01

        h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
        w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
        h_grid_idx = h_grid * num_grid_per_side

        indices = (h_grid_idx + w_grid).reshape(4, -1)
        weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
        weights = weights.to(dtype=dtype)

        # Gather the four corner embeddings and blend them.
        embeds = self.pos_embed(indices)
        embeds *= weights
        combined = embeds.sum(dim=0)

        # Reorder (row, col) into (block_row, block_col, local_row, local_col).
        combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
        combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
        # Every frame shares the same spatial embeddings.
        repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
        outputs.append(repeated)

    return torch.cat(outputs, dim=0)
765850

766851
def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata):
767852
# NOTE: The single prompt is divided into multiple seq_lens, so pretending have many batch_sizes.

0 commit comments

Comments
 (0)