add partial_rope_qk_inplace #416
Conversation
Code Review
This pull request introduces a Triton-based in-place partial RoPE kernel for NPUs and a corresponding test suite. The review identifies a logic error in the head indexing for Grouped Query Attention (GQA) and suggests validating the rotary dimension before launching the kernel. It also notes that the kernel may fail to compile for non-power-of-two rotary dimensions and recommends expanding the test cases to cover GQA shapes.
```python
for g_id in range(groups):
    hq_id = hk_id + g_id
```
The calculation of hq_id is incorrect for Grouped Query Attention (GQA) where Hk > 1. It should be hk_id * groups + g_id to correctly map each query head to its corresponding key head. The current implementation would incorrectly reuse query head indices across different key heads (e.g., if Hk=2 and groups=2, hk_id=1 would incorrectly process hq_id=1, 2 instead of 2, 3).
Suggested change:

```diff
 for g_id in range(groups):
-    hq_id = hk_id + g_id
+    hq_id = hk_id * groups + g_id
```
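A quick plain-Python check makes the difference concrete (shapes chosen to match the suggested GQA test, Hq=8, Hk=2; this is an illustration, not code from the PR):

```python
Hq, Hk = 8, 2
groups = Hq // Hk
for hk_id in range(Hk):
    wrong = [hk_id + g_id for g_id in range(groups)]           # current mapping
    fixed = [hk_id * groups + g_id for g_id in range(groups)]  # suggested mapping
    print(f"key head {hk_id}: current -> {wrong}, suggested -> {fixed}")
# key head 0: current -> [0, 1, 2, 3], suggested -> [0, 1, 2, 3]
# key head 1: current -> [1, 2, 3, 4], suggested -> [4, 5, 6, 7]
```

The current mapping overlaps query heads 1-3 across both key heads and never touches heads 5-7; the suggested mapping partitions the query heads cleanly.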
```python
T, Hq, D = query.shape
_, Hk, _ = key.shape
assert Hq % Hk == 0
```
It is important to validate that rotary_dim is within the bounds of the head dimension D and that it is an even number. If rotary_dim > D, the kernel will perform out-of-bounds memory accesses on the query and key tensors.
Suggested change:

```diff
 T, Hq, D = query.shape
 _, Hk, _ = key.shape
 assert Hq % Hk == 0
+assert rotary_dim <= D
+assert rotary_dim % 2 == 0
```
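For context, a minimal sketch of how the validated preamble might read, with the rationale for each assert spelled out; the function signature is assumed from the shapes in the diff, not taken from the PR:

```python
def partial_rope_qk_inplace(query, key, cos, sin, rotary_dim):
    T, Hq, D = query.shape
    _, Hk, _ = key.shape
    assert Hq % Hk == 0         # every key head serves Hq // Hk query heads
    assert rotary_dim <= D      # rotating past D would read/write out of bounds
    assert rotary_dim % 2 == 0  # RoPE rotates (x1, x2) pairs of lanes
    ...
```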
```python
t_id = tl.program_id(0)
hk_id = tl.program_id(1)

d = tl.arange(0, D_ROPE // 2)
```
Triton's tl.arange requires the size to be a power of 2. If D_ROPE // 2 is not a power of 2 (e.g., for a rotary dimension of 96), this kernel will fail to compile. Consider using a block size that is the next power of 2 and applying a mask during loads and stores to ensure compatibility with non-power-of-2 rotary dimensions.
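A minimal sketch of that workaround, using illustrative names rather than the PR's kernel: pad the lane count to the next power of two on the host and mask the tail lanes inside the kernel.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _rope_half_copy(x_ptr, d_half, BLOCK: tl.constexpr):
    d = tl.arange(0, BLOCK)                        # BLOCK is a power of two
    mask = d < d_half                              # disable the padded tail lanes
    x = tl.load(x_ptr + d, mask=mask, other=0.0)
    tl.store(x_ptr + d, x, mask=mask)              # the real kernel would rotate here

rotary_dim = 96                                    # 96 // 2 = 48, not a power of two
x = torch.randn(rotary_dim // 2, device="cuda")    # device depends on the backend
BLOCK = triton.next_power_of_2(rotary_dim // 2)    # 48 -> 64
_rope_half_copy[(1,)](x, rotary_dim // 2, BLOCK=BLOCK)
```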
```python
shapes = [
    [32, 4, 1, 256, 64],  # partial
    [32, 4, 1, 64, 64]    # no partial
]
```
The current test cases only cover Hk = 1, which does not exercise the GQA logic and hides the bug in the hq_id calculation. Adding a test case with Hk > 1 (e.g., Hq=8, Hk=2) is necessary to ensure correctness for Grouped Query Attention.
Suggested change:

```diff
 shapes = [
     [32, 4, 1, 256, 64],  # partial
+    [32, 8, 2, 128, 64],  # GQA
     [32, 4, 1, 64, 64]    # no partial
 ]
```
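To make the GQA case easy to verify, the test could compare the kernel's in-place result against a small pure-PyTorch reference. A hedged sketch, assuming the half-split (GPT-NeoX style) rotary layout and cos/sin broadcastable to [T, 1, rotary_dim // 2]; the layout convention must be matched to whatever the kernel actually implements:

```python
import torch

def ref_partial_rope(x, cos, sin, rotary_dim):
    # x: [T, H, D]; only the first rotary_dim lanes rotate, the rest pass through
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)
```

Running the kernel on a clone of query/key for each entry in shapes and comparing with torch.testing.assert_close against this reference would exercise the hq_id fix directly.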