-
Notifications
You must be signed in to change notification settings - Fork 127
add partial_rope_qk_inplace #416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,121 @@ | ||||||||||||||||||
| import torch | ||||||||||||||||||
| import triton | ||||||||||||||||||
| import triton.language as tl | ||||||||||||||||||
| from sgl_kernel_npu.utils.triton_utils import get_device_properties | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @triton.jit | ||||||||||||||||||
| def partial_rope_qk_inplace_kernel( | ||||||||||||||||||
| query_ptr, | ||||||||||||||||||
| key_ptr, | ||||||||||||||||||
| cos_sin_ptr, | ||||||||||||||||||
| stride_qt, | ||||||||||||||||||
| stride_qh, | ||||||||||||||||||
| stride_qd, | ||||||||||||||||||
| stride_kt, | ||||||||||||||||||
| stride_kh, | ||||||||||||||||||
| stride_kd, | ||||||||||||||||||
| stride_ct, | ||||||||||||||||||
| stride_cd, | ||||||||||||||||||
| groups: tl.constexpr, | ||||||||||||||||||
| D_ROPE: tl.constexpr, | ||||||||||||||||||
| IS_NEOX_STYLE: tl.constexpr, | ||||||||||||||||||
| HQ_IN_GRID: tl.constexpr, | ||||||||||||||||||
| ): | ||||||||||||||||||
| t_id = tl.program_id(0) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| d = tl.arange(0, D_ROPE // 2) | ||||||||||||||||||
| if IS_NEOX_STYLE: | ||||||||||||||||||
| idx_even = d | ||||||||||||||||||
| idx_odd = d + D_ROPE // 2 | ||||||||||||||||||
| else: | ||||||||||||||||||
| idx_even = d * 2 | ||||||||||||||||||
| idx_odd = d * 2 + 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| # cos / sin | ||||||||||||||||||
| cos = tl.load(cos_sin_ptr + t_id * stride_ct + d * stride_cd) # (D_ROPE // 2,) | ||||||||||||||||||
| sin = tl.load(cos_sin_ptr + t_id * stride_ct + (d + D_ROPE // 2) * stride_cd) # (D_ROPE // 2,) | ||||||||||||||||||
|
|
||||||||||||||||||
| if not HQ_IN_GRID: | ||||||||||||||||||
| hk_id = tl.program_id(1) | ||||||||||||||||||
| # ================= Q ================= | ||||||||||||||||||
| for g_id in range(groups): | ||||||||||||||||||
| hq_id = hk_id + g_id | ||||||||||||||||||
| q_base = query_ptr + t_id * stride_qt + hq_id * stride_qh | ||||||||||||||||||
| q1 = tl.load(q_base + idx_even * stride_qd) | ||||||||||||||||||
| q2 = tl.load(q_base + idx_odd * stride_qd) | ||||||||||||||||||
| q_out1 = (q1 * cos) - (q2 * sin) | ||||||||||||||||||
| q_out2 = (q1 * sin) + (q2 * cos) | ||||||||||||||||||
| tl.store(q_base + idx_even * stride_qd, q_out1) | ||||||||||||||||||
| tl.store(q_base + idx_odd * stride_qd, q_out2) | ||||||||||||||||||
|
|
||||||||||||||||||
| # ================= K ================= | ||||||||||||||||||
| k_base = key_ptr + t_id * stride_kt + hk_id * stride_kh | ||||||||||||||||||
| k1 = tl.load(k_base + idx_even * stride_kd) | ||||||||||||||||||
| k2 = tl.load(k_base + idx_odd * stride_kd) | ||||||||||||||||||
|
|
||||||||||||||||||
| k_out1 = (k1 * cos) - (k2 * sin) | ||||||||||||||||||
| k_out2 = (k1 * sin) + (k2 * cos) | ||||||||||||||||||
|
|
||||||||||||||||||
| tl.store(k_base + idx_even * stride_kd, k_out1) | ||||||||||||||||||
| tl.store(k_base + idx_odd * stride_kd, k_out2) | ||||||||||||||||||
| else: | ||||||||||||||||||
| hq_id = tl.program_id(1) | ||||||||||||||||||
| # ================= Q ================= | ||||||||||||||||||
| q_base = query_ptr + t_id * stride_qt + hq_id * stride_qh | ||||||||||||||||||
| q1 = tl.load(q_base + idx_even * stride_qd) | ||||||||||||||||||
| q2 = tl.load(q_base + idx_odd * stride_qd) | ||||||||||||||||||
| q_out1 = (q1 * cos) - (q2 * sin) | ||||||||||||||||||
| q_out2 = (q1 * sin) + (q2 * cos) | ||||||||||||||||||
| tl.store(q_base + idx_even * stride_qd, q_out1) | ||||||||||||||||||
| tl.store(q_base + idx_odd * stride_qd, q_out2) | ||||||||||||||||||
|
|
||||||||||||||||||
| # ================= K ================= | ||||||||||||||||||
| if hq_id % groups == 0: | ||||||||||||||||||
| hk_id = hq_id // groups | ||||||||||||||||||
| k_base = key_ptr + t_id * stride_kt + hk_id * stride_kh | ||||||||||||||||||
| k1 = tl.load(k_base + idx_even * stride_kd) | ||||||||||||||||||
| k2 = tl.load(k_base + idx_odd * stride_kd) | ||||||||||||||||||
| k_out1 = (k1 * cos) - (k2 * sin) | ||||||||||||||||||
| k_out2 = (k1 * sin) + (k2 * cos) | ||||||||||||||||||
| tl.store(k_base + idx_even * stride_kd, k_out1) | ||||||||||||||||||
| tl.store(k_base + idx_odd * stride_kd, k_out2) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def partial_rope_qk_inplace( | ||||||||||||||||||
| query, # [T, Hq, D] | ||||||||||||||||||
| key, # [T, Hk, D] | ||||||||||||||||||
| cos_sin, # [T, rotary_dim] | ||||||||||||||||||
| rotary_dim, | ||||||||||||||||||
| is_neox_style=False, | ||||||||||||||||||
| ): | ||||||||||||||||||
| T, Hq, D = query.shape | ||||||||||||||||||
| _, Hk, _ = key.shape | ||||||||||||||||||
| assert Hq % Hk == 0 | ||||||||||||||||||
|
Comment on lines
+93
to
+95
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is important to validate that
Suggested change
|
||||||||||||||||||
| _, vec_cors = get_device_properties() | ||||||||||||||||||
| grid = (T, Hq) | ||||||||||||||||||
| HQ_IN_GRID = True | ||||||||||||||||||
| if T * Hk >= vec_cors: | ||||||||||||||||||
| grid = (T, Hk) | ||||||||||||||||||
| HQ_IN_GRID = False | ||||||||||||||||||
|
|
||||||||||||||||||
| partial_rope_qk_inplace_kernel[grid]( | ||||||||||||||||||
| query, | ||||||||||||||||||
| key, | ||||||||||||||||||
| cos_sin, | ||||||||||||||||||
| query.stride(0), | ||||||||||||||||||
| query.stride(1), | ||||||||||||||||||
| query.stride(2), | ||||||||||||||||||
| key.stride(0), | ||||||||||||||||||
| key.stride(1), | ||||||||||||||||||
| key.stride(2), | ||||||||||||||||||
| cos_sin.stride(0), | ||||||||||||||||||
| cos_sin.stride(1), | ||||||||||||||||||
| groups=Hq // Hk, | ||||||||||||||||||
| D_ROPE=rotary_dim, | ||||||||||||||||||
| IS_NEOX_STYLE=is_neox_style, | ||||||||||||||||||
| HQ_IN_GRID=HQ_IN_GRID, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| return query, key | ||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| import torch | ||
| from sgl_kernel_npu.norm.partial_rope_qk_inplace import partial_rope_qk_inplace | ||
|
|
||
|
|
||
| def apply_rotary_emb( | ||
| x: torch.Tensor, | ||
| cos: torch.Tensor, | ||
| sin: torch.Tensor, | ||
| is_neox_style=False, | ||
| ) -> torch.Tensor: | ||
| cos = cos.unsqueeze(-2) | ||
| sin = sin.unsqueeze(-2) | ||
| if is_neox_style: | ||
| x1, x2 = torch.chunk(x, 2, dim=-1) | ||
| else: | ||
| x1 = x[..., ::2] | ||
| x2 = x[..., 1::2] | ||
| o1 = x1 * cos - x2 * sin | ||
| o2 = x2 * cos + x1 * sin | ||
| if is_neox_style: | ||
| return torch.cat((o1, o2), dim=-1) | ||
| else: | ||
| return torch.stack((o1, o2), dim=-1).flatten(-2) | ||
|
|
||
|
|
||
| def rope_native(query, key, cos_sin, rotary_dim, is_neox_style=False): | ||
| head_dim = query.shape[-1] | ||
| cos, sin = cos_sin.chunk(2, dim=-1) | ||
| q_pe, q_nope = torch.split(query, [rotary_dim, head_dim - rotary_dim], dim=-1) | ||
| k_pe, k_nope = torch.split(key, [rotary_dim, head_dim - rotary_dim], dim=-1) | ||
| q_pe = apply_rotary_emb(q_pe, cos, sin, is_neox_style=is_neox_style) | ||
| k_pe = apply_rotary_emb(k_pe, cos, sin, is_neox_style=is_neox_style) | ||
| q = torch.cat((q_pe, q_nope), dim=-1) | ||
| k = torch.cat((k_pe, k_nope), dim=-1) | ||
| return q, k | ||
|
|
||
|
|
||
| def test_partial_rope_qk_inplace(): | ||
| dtype = torch.float32 | ||
| shapes = [ | ||
| [64, 4, 1, 256, 64], # partial, HQ_IN_GRID | ||
| [64, 4, 1, 64, 64], # no partial | ||
| [1, 4, 1, 256, 64], # HK_IN_GRID | ||
| [1, 4, 1, 64, 64], | ||
| ] | ||
|
Comment on lines
+40
to
+45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current test cases only cover shapes = [
[32, 4, 1, 256, 64], # partial
[32, 8, 2, 128, 64], # GQA
[32, 4, 1, 64, 64] # no partial
] |
||
| for T, Hq, Hk, D, D_ROPE in shapes: | ||
| for is_neox_style in [True, False]: | ||
| query = torch.randn((T, Hq, D), dtype=dtype, device="npu") | ||
| key = torch.randn((T, Hk, D), dtype=dtype, device="npu") | ||
| cos_sin = torch.randn((T, D_ROPE), dtype=dtype, device="npu") | ||
| _query = query.clone() | ||
| _key = key.clone() | ||
| # triton | ||
| res_q, res_k = partial_rope_qk_inplace(query, key, cos_sin, rotary_dim=D_ROPE, is_neox_style=is_neox_style) | ||
| # native | ||
| ans_q, ans_k = rope_native(_query, _key, cos_sin, rotary_dim=D_ROPE, is_neox_style=is_neox_style) | ||
| assert torch.allclose(res_q, ans_q) | ||
| assert torch.allclose(res_k, ans_k) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_partial_rope_qk_inplace() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Triton's
tl.arangerequires the size to be a power of 2. IfD_ROPE // 2is not a power of 2 (e.g., for a rotary dimension of 96), this kernel will fail to compile. Consider using a block size that is the next power of 2 and applying a mask during loads and stores to ensure compatibility with non-power-of-2 rotary dimensions.