|
4 | 4 |
|
5 | 5 | import torch |
6 | 6 | import torch.nn as nn |
| 7 | +import triton |
| 8 | +import triton.language as tl |
7 | 9 | from PIL import Image |
8 | 10 | from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel |
9 | 11 | from transformers.activations import ACT2FN as HF_ACT2FN |
@@ -609,6 +611,140 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
609 | 611 | return hidden_states |
610 | 612 |
|
611 | 613 |
|
| 614 | +# --------------------------------------------------------------------------- |
| 615 | +# Fused bilinear position-embedding interpolation for the Qwen3-VL vision |
| 616 | +# tower. Ported from vLLM |
| 617 | +# (vllm/model_executor/models/qwen3_vl.py @ 21eb2c33). |
| 618 | +# |
| 619 | +# ``fast_pos_embed_interpolate`` resamples the learned grid of positional |
| 620 | +# embeddings onto each (t, h, w) image grid using bilinear interpolation and |
| 621 | +# then reorders the spatial axis to match the spatial-merge layout used by |
| 622 | +# the rest of the vision tower. The Triton path fuses the |
| 623 | +# bilinear-interp + spatial-merge reorder into a single kernel so the |
| 624 | +# embedding gather, the 4 corner reads, and the permute are all one fused |
| 625 | +# pass instead of separate ops. |
| 626 | +# --------------------------------------------------------------------------- |
| 627 | + |
| 628 | + |
@triton.jit
def _bilinear_pos_embed_kernel(
    embed_ptr,
    output_ptr,
    H,
    W,
    h_scale,
    w_scale,
    NUM_GRID: tl.constexpr,
    M_SIZE: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Fused bilinear pos-embed interpolation with spatial-merge reorder.

    Each program writes one output row of ``HIDDEN_DIM`` elements. The
    program id is reduced modulo ``H * W``, so every run of ``H * W``
    consecutive output rows holds the same values.

    Args:
        embed_ptr: Embedding table, addressed as dense rows of
            ``HIDDEN_DIM`` elements at row index ``y * NUM_GRID + x``.
        output_ptr: Output buffer, addressed as dense rows of
            ``HIDDEN_DIM`` elements (one row per program).
        H: Height of the target grid.
        W: Width of the target grid.
        h_scale: Factor mapping a target row index to a fractional row of
            the embedding grid.
        w_scale: Factor mapping a target column index to a fractional
            column of the embedding grid.
        NUM_GRID: Side length of the embedding grid; neighbour indices are
            clamped to ``NUM_GRID - 1``.
        M_SIZE: Side length of a spatial-merge block.
        HIDDEN_DIM: Number of elements per embedding row.
        BLOCK_D: Number of elements processed per loop iteration.
    """
    pid = tl.program_id(0)
    total_spatial = H * W
    spatial_idx = pid % total_spatial

    # Decode the merge-ordered index: every M_SIZE * M_SIZE consecutive
    # outputs form one block, blocks are laid out row-major across
    # W // M_SIZE block columns, and positions inside a block are row-major.
    num_blocks_w = W // M_SIZE
    block_idx = spatial_idx // (M_SIZE * M_SIZE)
    local_idx = spatial_idx % (M_SIZE * M_SIZE)
    br = block_idx // num_blocks_w
    bc = block_idx % num_blocks_w
    lr = local_idx // M_SIZE
    lc = local_idx % M_SIZE
    row = br * M_SIZE + lr
    col = bc * M_SIZE + lc

    # Fractional source coordinates on the embedding grid.
    h_frac = row.to(tl.float32) * h_scale
    w_frac = col.to(tl.float32) * w_scale

    # Lower neighbours via floor; upper neighbours clamped to the last cell.
    hf = tl.math.floor(h_frac).to(tl.int32)
    wf = tl.math.floor(w_frac).to(tl.int32)
    hc = tl.minimum(hf + 1, NUM_GRID - 1)
    wc = tl.minimum(wf + 1, NUM_GRID - 1)

    # Bilinear weights. w11 is reused so dh * dw is computed only once:
    #   w00 = (1 - dh) * (1 - dw), w01 = (1 - dh) * dw,
    #   w10 = dh * (1 - dw),       w11 = dh * dw
    dh = h_frac - hf.to(tl.float32)
    dw = w_frac - wf.to(tl.float32)
    w11 = dh * dw
    w10 = dh - w11
    w01 = dw - w11
    w00 = 1.0 - dh - w01

    off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM
    off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM
    off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM
    off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM
    # The program id is 32-bit; widen it before scaling so the output row
    # offset cannot wrap once the output holds 2**31 or more elements.
    out_off = pid.to(tl.int64) * HIDDEN_DIM

    # Cast the weights to the output element type before the
    # multiply-accumulate, mirroring the PyTorch implementation this kernel
    # replaces, which cast its weights to the embedding dtype.
    out_dtype = output_ptr.dtype.element_ty
    w00_c = w00.to(out_dtype)
    w01_c = w01.to(out_dtype)
    w10_c = w10.to(out_dtype)
    w11_c = w11.to(out_dtype)

    for d in tl.range(0, HIDDEN_DIM, BLOCK_D):
        cols = d + tl.arange(0, BLOCK_D)
        # Lanes at or beyond HIDDEN_DIM are masked off for loads and the store.
        mask = cols < HIDDEN_DIM

        e00 = tl.load(embed_ptr + off00 + cols, mask=mask)
        e01 = tl.load(embed_ptr + off01 + cols, mask=mask)
        e10 = tl.load(embed_ptr + off10 + cols, mask=mask)
        e11 = tl.load(embed_ptr + off11 + cols, mask=mask)

        val = w00_c * e00 + w01_c * e01 + w10_c * e10 + w11_c * e11

        tl.store(output_ptr + out_off + cols, val, mask=mask)
| 699 | + |
| 700 | + |
def _triton_pos_embed_interpolate(
    embed_weight: torch.Tensor,
    t: int,
    h: int,
    w: int,
    num_grid_per_side: int,
    m_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Launch the fused Triton kernel for one (t, h, w) grid.

    Args:
        embed_weight: 2-D embedding table; ``shape[1]`` is taken as the
            hidden dimension.
        t: Number of frames; the ``h * w`` block of rows is produced ``t``
            times.
        h: Target grid height; must be divisible by ``m_size``.
        w: Target grid width; must be divisible by ``m_size``.
        num_grid_per_side: Side length of the embedding grid.
        m_size: Side length of a spatial-merge block.
        dtype: Element type of the returned tensor.

    Returns:
        A tensor of shape ``(t * h * w, hidden_dim)`` with the
        bilinearly-interpolated position embeddings already in
        spatial-merge order.

    Raises:
        ValueError: If ``h`` or ``w`` is not divisible by ``m_size``.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if h % m_size != 0 or w % m_size != 0:
        raise ValueError(f"h={h} and w={w} must be divisible by m_size={m_size}")

    hidden_dim = embed_weight.shape[1]
    total_out = t * h * w
    output = torch.empty(
        total_out,
        hidden_dim,
        device=embed_weight.device,
        dtype=dtype,
    )
    if total_out == 0:
        # Nothing to compute: skip kernel compilation and launch entirely.
        return output

    # The kernel addresses the table as dense rows of ``hidden_dim``
    # elements; ``contiguous()`` is a no-op when that already holds.
    embed_weight = embed_weight.contiguous()

    # Map target index i in [0, n - 1] onto [0, num_grid_per_side - 1];
    # a single row or column maps to coordinate 0.
    h_scale = float(num_grid_per_side - 1) / float(h - 1) if h > 1 else 0.0
    w_scale = float(num_grid_per_side - 1) / float(w - 1) if w > 1 else 0.0

    BLOCK_D = triton.next_power_of_2(hidden_dim)

    # One program per output row.
    _bilinear_pos_embed_kernel[(total_out,)](
        embed_weight,
        output,
        h,
        w,
        h_scale,
        w_scale,
        num_grid_per_side,
        m_size,
        hidden_dim,
        BLOCK_D,
    )
    return output
| 746 | + |
| 747 | + |
612 | 748 | class Qwen3VisionModel(torch.nn.Module): |
613 | 749 | def __init__(self, model_config: ModelConfig[PretrainedConfig]): |
614 | 750 | super().__init__() |
@@ -698,70 +834,19 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: |
698 | 834 |
|
699 | 835 | # Adopted from https://github.com/vllm-project/vllm/blob/21eb2c3372fb6447ef36bee44ff7af79a330ffec/vllm/model_executor/models/qwen3_vl.py#L470
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
    """Interpolate the position embeddings for every ``(t, h, w)`` grid.

    Each grid is handed to ``_triton_pos_embed_interpolate`` together with
    this module's embedding table, grid side length, and spatial-merge
    size; the per-grid results are concatenated along dim 0 in input order.

    Args:
        grid_thw: One ``(t, h, w)`` triple per grid.

    Returns:
        The per-grid results concatenated along dim 0, in the dtype of the
        embedding table.
    """
    table = self.pos_embed.weight
    grid_side = self.num_grid_per_side
    merge_size = self.spatial_merge_size
    out_dtype = table.dtype

    per_grid = []
    for t, h, w in grid_thw:
        per_grid.append(
            _triton_pos_embed_interpolate(
                table, t, h, w, grid_side, merge_size, out_dtype
            )
        )
    return torch.cat(per_grid, dim=0)
765 | 850 |
|
766 | 851 | def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata): |
767 | 852 | # NOTE: The single prompt is divided into multiple seq_lens, so pretending have many batch_sizes. |
|
0 commit comments