Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions python/sgl_kernel_npu/sgl_kernel_npu/norm/partial_rope_qk_inplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import torch
import triton
import triton.language as tl
from sgl_kernel_npu.utils.triton_utils import get_device_properties


@triton.jit
def partial_rope_qk_inplace_kernel(
    query_ptr,
    key_ptr,
    cos_sin_ptr,
    stride_qt,
    stride_qh,
    stride_qd,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_ct,
    stride_cd,
    groups: tl.constexpr,
    D_ROPE: tl.constexpr,
    IS_NEOX_STYLE: tl.constexpr,
    HQ_IN_GRID: tl.constexpr,
):
    """Rotate the first D_ROPE dims of query/key in-place with RoPE.

    cos_sin is laid out per token as [cos(0..D_ROPE/2) | sin(0..D_ROPE/2)].

    Grid layout (chosen by the host wrapper):
      * HQ_IN_GRID=True:  grid = (T, Hq). Each program rotates one query head;
        the program whose hq_id is the first of its group also rotates the
        group's shared key head, so every key head is written exactly once.
      * HQ_IN_GRID=False: grid = (T, Hk). Each program rotates one key head
        and loops over the `groups` query heads that share it.

    NOTE(review): tl.arange requires D_ROPE // 2 to be a power of two; the
    host wrapper is expected to enforce this.
    """
    t_id = tl.program_id(0)

    d = tl.arange(0, D_ROPE // 2)

    # Lane pairing: NeoX rotates across the two halves of the rotary dims;
    # GPT-J style rotates interleaved (even, odd) lanes.
    if IS_NEOX_STYLE:
        idx_even = d
        idx_odd = d + D_ROPE // 2
    else:
        idx_even = d * 2
        idx_odd = d * 2 + 1

    # Per-token cos / sin, each of shape (D_ROPE // 2,).
    cos = tl.load(cos_sin_ptr + t_id * stride_ct + d * stride_cd)
    sin = tl.load(cos_sin_ptr + t_id * stride_ct + (d + D_ROPE // 2) * stride_cd)

    if not HQ_IN_GRID:
        hk_id = tl.program_id(1)
        # ================= Q =================
        for g_id in range(groups):
            # BUGFIX: was `hk_id + g_id`, which addresses the wrong query heads
            # whenever Hk > 1. Must invert the mapping hk_id = hq_id // groups.
            hq_id = hk_id * groups + g_id
            q_base = query_ptr + t_id * stride_qt + hq_id * stride_qh
            q1 = tl.load(q_base + idx_even * stride_qd)
            q2 = tl.load(q_base + idx_odd * stride_qd)
            q_out1 = (q1 * cos) - (q2 * sin)
            q_out2 = (q1 * sin) + (q2 * cos)
            tl.store(q_base + idx_even * stride_qd, q_out1)
            tl.store(q_base + idx_odd * stride_qd, q_out2)

        # ================= K =================
        k_base = key_ptr + t_id * stride_kt + hk_id * stride_kh
        k1 = tl.load(k_base + idx_even * stride_kd)
        k2 = tl.load(k_base + idx_odd * stride_kd)

        k_out1 = (k1 * cos) - (k2 * sin)
        k_out2 = (k1 * sin) + (k2 * cos)

        tl.store(k_base + idx_even * stride_kd, k_out1)
        tl.store(k_base + idx_odd * stride_kd, k_out2)
    else:
        hq_id = tl.program_id(1)
        # ================= Q =================
        q_base = query_ptr + t_id * stride_qt + hq_id * stride_qh
        q1 = tl.load(q_base + idx_even * stride_qd)
        q2 = tl.load(q_base + idx_odd * stride_qd)
        q_out1 = (q1 * cos) - (q2 * sin)
        q_out2 = (q1 * sin) + (q2 * cos)
        tl.store(q_base + idx_even * stride_qd, q_out1)
        tl.store(q_base + idx_odd * stride_qd, q_out2)

        # ================= K =================
        # Only the first query head of each group rotates its key head, so
        # each key head is written exactly once.
        if hq_id % groups == 0:
            hk_id = hq_id // groups
            k_base = key_ptr + t_id * stride_kt + hk_id * stride_kh
            k1 = tl.load(k_base + idx_even * stride_kd)
            k2 = tl.load(k_base + idx_odd * stride_kd)
            k_out1 = (k1 * cos) - (k2 * sin)
            k_out2 = (k1 * sin) + (k2 * cos)
            tl.store(k_base + idx_even * stride_kd, k_out1)
            tl.store(k_base + idx_odd * stride_kd, k_out2)


def partial_rope_qk_inplace(
    query,  # [T, Hq, D]
    key,  # [T, Hk, D]
    cos_sin,  # [T, rotary_dim], per token: [cos | sin] halves
    rotary_dim,
    is_neox_style=False,
):
    """Apply rotary position embedding to the first `rotary_dim` dims of
    query and key, in-place.

    Args:
        query: [T, Hq, D] tensor, mutated in place.
        key: [T, Hk, D] tensor, mutated in place; Hq must be a multiple of Hk.
        cos_sin: [T, rotary_dim] tensor holding concatenated cos / sin halves.
        rotary_dim: number of leading head dims to rotate. Must be even,
            <= D, and rotary_dim // 2 must be a power of two (tl.arange
            constraint in the kernel).
        is_neox_style: True for NeoX half-split rotation, False for GPT-J
            interleaved rotation.

    Returns:
        (query, key) — the same tensors, rotated in place.
    """
    T, Hq, D = query.shape
    _, Hk, _ = key.shape
    assert Hq % Hk == 0, "Hq must be a multiple of Hk"
    # Bounds checks: an out-of-range rotary_dim would make the kernel read
    # and write past the head dimension.
    assert 0 < rotary_dim <= D, "rotary_dim must lie within the head dim"
    assert rotary_dim % 2 == 0, "rotary_dim must be even"
    half = rotary_dim // 2
    assert half & (half - 1) == 0, "rotary_dim // 2 must be a power of two"

    _, vec_cors = get_device_properties()
    # Use one program per (token, key head) when that already saturates the
    # vector cores; otherwise spread work over every query head.
    grid = (T, Hq)
    HQ_IN_GRID = True
    if T * Hk >= vec_cors:
        grid = (T, Hk)
        HQ_IN_GRID = False

    partial_rope_qk_inplace_kernel[grid](
        query,
        key,
        cos_sin,
        query.stride(0),
        query.stride(1),
        query.stride(2),
        key.stride(0),
        key.stride(1),
        key.stride(2),
        cos_sin.stride(0),
        cos_sin.stride(1),
        groups=Hq // Hk,
        D_ROPE=rotary_dim,
        IS_NEOX_STYLE=is_neox_style,
        HQ_IN_GRID=HQ_IN_GRID,
    )

    return query, key
62 changes: 62 additions & 0 deletions tests/python/sgl_kernel_npu/test_partial_rope_qk_inplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
from sgl_kernel_npu.norm.partial_rope_qk_inplace import partial_rope_qk_inplace


def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style=False,
) -> torch.Tensor:
    """Rotate lane pairs of `x` by the per-position angles in cos / sin.

    cos and sin have shape [T, rot_dim // 2] and broadcast over the head
    axis. NeoX style pairs lane i with lane i + rot_dim // 2; otherwise
    lanes are paired as interleaved (even, odd).
    """
    cos_b, sin_b = cos.unsqueeze(-2), sin.unsqueeze(-2)
    if is_neox_style:
        first, second = torch.chunk(x, 2, dim=-1)
    else:
        first, second = x[..., 0::2], x[..., 1::2]
    rot_first = first * cos_b - second * sin_b
    rot_second = first * sin_b + second * cos_b
    if is_neox_style:
        return torch.cat((rot_first, rot_second), dim=-1)
    return torch.stack((rot_first, rot_second), dim=-1).flatten(-2)


def rope_native(query, key, cos_sin, rotary_dim, is_neox_style=False):
    """Eager reference: rotate the first `rotary_dim` head dims of query and
    key with RoPE and pass the remaining dims through unchanged."""
    cos, sin = cos_sin.chunk(2, dim=-1)

    def _rotate_partial(t):
        # Split off the rotary slice, rotate it, then re-attach the rest.
        rot, passthrough = t[..., :rotary_dim], t[..., rotary_dim:]
        rot = apply_rotary_emb(rot, cos, sin, is_neox_style=is_neox_style)
        return torch.cat((rot, passthrough), dim=-1)

    return _rotate_partial(query), _rotate_partial(key)


def test_partial_rope_qk_inplace():
    """Compare the Triton kernel against the eager PyTorch reference.

    Covers both rotation styles, partial and full rotation, both grid
    dispatch paths (large T vs T == 1), and — crucially — GQA shapes with
    Hk > 1, which exercise the kernel's query-head / key-head mapping.
    """
    dtype = torch.float32
    # [T, Hq, Hk, D, D_ROPE]
    shapes = [
        [64, 4, 1, 256, 64],  # partial rotation, large T (key-head grid path)
        [64, 4, 1, 64, 64],  # full rotation
        [64, 8, 2, 128, 64],  # GQA: Hk > 1 exercises the hq_id mapping
        [1, 4, 1, 256, 64],  # T = 1 (query-head grid path)
        [1, 4, 1, 64, 64],
        [1, 8, 2, 128, 64],  # GQA on the query-head grid path
    ]
    for T, Hq, Hk, D, D_ROPE in shapes:
        for is_neox_style in [True, False]:
            query = torch.randn((T, Hq, D), dtype=dtype, device="npu")
            key = torch.randn((T, Hk, D), dtype=dtype, device="npu")
            cos_sin = torch.randn((T, D_ROPE), dtype=dtype, device="npu")
            _query = query.clone()
            _key = key.clone()
            # triton kernel (mutates query / key in place)
            res_q, res_k = partial_rope_qk_inplace(
                query, key, cos_sin, rotary_dim=D_ROPE, is_neox_style=is_neox_style
            )
            # eager reference on the untouched clones
            ans_q, ans_k = rope_native(
                _query, _key, cos_sin, rotary_dim=D_ROPE, is_neox_style=is_neox_style
            )
            assert torch.allclose(res_q, ans_q)
            assert torch.allclose(res_k, ans_k)


# Allow running the test directly (outside pytest).
if __name__ == "__main__":
    test_partial_rope_qk_inplace()
Loading