diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/norm/partial_rope_qk_inplace.py b/python/sgl_kernel_npu/sgl_kernel_npu/norm/partial_rope_qk_inplace.py
new file mode 100644
index 000000000..6ac59add7
--- /dev/null
+++ b/python/sgl_kernel_npu/sgl_kernel_npu/norm/partial_rope_qk_inplace.py
@@ -0,0 +1,148 @@
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel_npu.utils.triton_utils import get_device_properties
+
+
+@triton.jit
+def partial_rope_qk_inplace_kernel(
+    query_ptr,
+    key_ptr,
+    cos_sin_ptr,
+    stride_qt,
+    stride_qh,
+    stride_qd,
+    stride_kt,
+    stride_kh,
+    stride_kd,
+    stride_ct,
+    stride_cd,
+    groups: tl.constexpr,
+    D_ROPE: tl.constexpr,
+    IS_NEOX_STYLE: tl.constexpr,
+    HQ_IN_GRID: tl.constexpr,
+):
+    """Rotate the first ``D_ROPE`` features of query/key heads in place.
+
+    Program axis 0 selects the token. Program axis 1 selects a query head
+    when ``HQ_IN_GRID`` is true; otherwise it selects a key head, and that
+    program also rotates the ``groups`` query heads mapped to it.
+    """
+    t_id = tl.program_id(0)
+
+    # Offsets of the two members of each rotated pair within a head.
+    d = tl.arange(0, D_ROPE // 2)
+    if IS_NEOX_STYLE:
+        idx_even = d
+        idx_odd = d + D_ROPE // 2
+    else:
+        idx_even = d * 2
+        idx_odd = d * 2 + 1
+
+    # cos / sin: each row stores the cos half followed by the sin half.
+    cos = tl.load(cos_sin_ptr + t_id * stride_ct + d * stride_cd)  # (D_ROPE // 2,)
+    sin = tl.load(cos_sin_ptr + t_id * stride_ct + (d + D_ROPE // 2) * stride_cd)  # (D_ROPE // 2,)
+
+    if not HQ_IN_GRID:
+        hk_id = tl.program_id(1)
+        # ================= Q =================
+        for g_id in range(groups):
+            # Key head hk_id serves query heads hk_id * groups + [0, groups),
+            # matching hk_id = hq_id // groups in the other branch.
+            hq_id = hk_id * groups + g_id
+            q_base = query_ptr + t_id * stride_qt + hq_id * stride_qh
+            q1 = tl.load(q_base + idx_even * stride_qd)
+            q2 = tl.load(q_base + idx_odd * stride_qd)
+            q_out1 = (q1 * cos) - (q2 * sin)
+            q_out2 = (q1 * sin) + (q2 * cos)
+            tl.store(q_base + idx_even * stride_qd, q_out1)
+            tl.store(q_base + idx_odd * stride_qd, q_out2)
+
+        # ================= K =================
+        k_base = key_ptr + t_id * stride_kt + hk_id * stride_kh
+        k1 = tl.load(k_base + idx_even * stride_kd)
+        k2 = tl.load(k_base + idx_odd * stride_kd)
+
+        k_out1 = (k1 * cos) - (k2 * sin)
+        k_out2 = (k1 * sin) + (k2 * cos)
+
+        tl.store(k_base + idx_even * stride_kd, k_out1)
+        tl.store(k_base + idx_odd * stride_kd, k_out2)
+    else:
+        hq_id = tl.program_id(1)
+        # ================= Q =================
+        q_base = query_ptr + t_id * stride_qt + hq_id * stride_qh
+        q1 = tl.load(q_base + idx_even * stride_qd)
+        q2 = tl.load(q_base + idx_odd * stride_qd)
+        q_out1 = (q1 * cos) - (q2 * sin)
+        q_out2 = (q1 * sin) + (q2 * cos)
+        tl.store(q_base + idx_even * stride_qd, q_out1)
+        tl.store(q_base + idx_odd * stride_qd, q_out2)
+
+        # ================= K =================
+        # Only the first query head of a group rotates the shared key head.
+        if hq_id % groups == 0:
+            hk_id = hq_id // groups
+            k_base = key_ptr + t_id * stride_kt + hk_id * stride_kh
+            k1 = tl.load(k_base + idx_even * stride_kd)
+            k2 = tl.load(k_base + idx_odd * stride_kd)
+            k_out1 = (k1 * cos) - (k2 * sin)
+            k_out2 = (k1 * sin) + (k2 * cos)
+            tl.store(k_base + idx_even * stride_kd, k_out1)
+            tl.store(k_base + idx_odd * stride_kd, k_out2)
+
+
+def partial_rope_qk_inplace(
+    query,  # [T, Hq, D]
+    key,  # [T, Hk, D]
+    cos_sin,  # [T, rotary_dim]
+    rotary_dim,
+    is_neox_style=False,
+):
+    """Apply partial rotary embedding to ``query`` and ``key`` in place.
+
+    Args:
+        query: Tensor unpacked as ``[T, Hq, D]``.
+        key: Tensor unpacked as ``[T, Hk, D]``; ``Hq`` must be a multiple
+            of ``Hk``.
+        cos_sin: Tensor indexed as ``[T, rotary_dim]``; each row holds the
+            cos half followed by the sin half.
+        rotary_dim: Number of leading features of each head to rotate.
+        is_neox_style: If true, feature ``i`` is paired with feature
+            ``i + rotary_dim // 2``; otherwise ``2 * i`` with ``2 * i + 1``.
+
+    Returns:
+        The ``query`` and ``key`` tensors that were passed in.
+    """
+    T, Hq, _ = query.shape
+    _, Hk, _ = key.shape
+    assert Hq % Hk == 0
+    # NOTE(review): second value is presumably the vector-core count.
+    _, vec_cores = get_device_properties()
+    # Use one program per (token, key head) when that alone yields at least
+    # vec_cores programs; otherwise one program per (token, query head).
+    grid = (T, Hq)
+    HQ_IN_GRID = True
+    if T * Hk >= vec_cores:
+        grid = (T, Hk)
+        HQ_IN_GRID = False
+
+    partial_rope_qk_inplace_kernel[grid](
+        query,
+        key,
+        cos_sin,
+        query.stride(0),
+        query.stride(1),
+        query.stride(2),
+        key.stride(0),
+        key.stride(1),
+        key.stride(2),
+        cos_sin.stride(0),
+        cos_sin.stride(1),
+        groups=Hq // Hk,
+        D_ROPE=rotary_dim,
+        IS_NEOX_STYLE=is_neox_style,
+        HQ_IN_GRID=HQ_IN_GRID,
+    )
+
+    return query, key
diff --git a/tests/python/sgl_kernel_npu/test_partial_rope_qk_inplace.py b/tests/python/sgl_kernel_npu/test_partial_rope_qk_inplace.py
new file mode 100644
index 000000000..7f42c6d86
--- /dev/null
+++ b/tests/python/sgl_kernel_npu/test_partial_rope_qk_inplace.py
@@ -0,0 +1,69 @@
+import torch
+from sgl_kernel_npu.norm.partial_rope_qk_inplace import partial_rope_qk_inplace
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    is_neox_style=False,
+) -> torch.Tensor:
+    """Reference rotation of ``x`` by ``cos``/``sin``, broadcast over heads."""
+    cos = cos.unsqueeze(-2)
+    sin = sin.unsqueeze(-2)
+    if is_neox_style:
+        x1, x2 = torch.chunk(x, 2, dim=-1)
+    else:
+        x1 = x[..., ::2]
+        x2 = x[..., 1::2]
+    o1 = x1 * cos - x2 * sin
+    o2 = x2 * cos + x1 * sin
+    if is_neox_style:
+        return torch.cat((o1, o2), dim=-1)
+    else:
+        return torch.stack((o1, o2), dim=-1).flatten(-2)
+
+
+def rope_native(query, key, cos_sin, rotary_dim, is_neox_style=False):
+    """Out-of-place PyTorch reference for ``partial_rope_qk_inplace``."""
+    head_dim = query.shape[-1]
+    cos, sin = cos_sin.chunk(2, dim=-1)
+    q_pe, q_nope = torch.split(query, [rotary_dim, head_dim - rotary_dim], dim=-1)
+    k_pe, k_nope = torch.split(key, [rotary_dim, head_dim - rotary_dim], dim=-1)
+    q_pe = apply_rotary_emb(q_pe, cos, sin, is_neox_style=is_neox_style)
+    k_pe = apply_rotary_emb(k_pe, cos, sin, is_neox_style=is_neox_style)
+    q = torch.cat((q_pe, q_nope), dim=-1)
+    k = torch.cat((k_pe, k_nope), dim=-1)
+    return q, k
+
+
+def test_partial_rope_qk_inplace():
+    """Compare the Triton kernel with the PyTorch reference on NPU."""
+    dtype = torch.float32
+    # [T, Hq, Hk, D, D_ROPE]; the kernel's grid layout depends on T * Hk
+    # versus the device core count, so both large and small T are covered.
+    shapes = [
+        [64, 4, 1, 256, 64],  # many tokens, partial rotation
+        [64, 4, 1, 64, 64],  # many tokens, full rotation
+        [1, 4, 1, 256, 64],  # single token, partial rotation
+        [1, 4, 1, 64, 64],  # single token, full rotation
+        [64, 8, 2, 256, 64],  # Hk > 1: exercises the query/key head mapping
+        [1, 8, 2, 256, 64],
+    ]
+    for T, Hq, Hk, D, D_ROPE in shapes:
+        for is_neox_style in [True, False]:
+            query = torch.randn((T, Hq, D), dtype=dtype, device="npu")
+            key = torch.randn((T, Hk, D), dtype=dtype, device="npu")
+            cos_sin = torch.randn((T, D_ROPE), dtype=dtype, device="npu")
+            _query = query.clone()
+            _key = key.clone()
+            # triton
+            res_q, res_k = partial_rope_qk_inplace(query, key, cos_sin, rotary_dim=D_ROPE, is_neox_style=is_neox_style)
+            # native
+            ans_q, ans_k = rope_native(_query, _key, cos_sin, rotary_dim=D_ROPE, is_neox_style=is_neox_style)
+            assert torch.allclose(res_q, ans_q)
+            assert torch.allclose(res_k, ans_k)
+
+
+if __name__ == "__main__":
+    test_partial_rope_qk_inplace()