
add partial_rope_qk_inplace #416

Open
xiaobaicxy wants to merge 2 commits into sgl-project:main from xiaobaicxy:main

Conversation

@xiaobaicxy
Contributor

No description provided.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a Triton-based in-place partial RoPE kernel for NPUs and a corresponding test suite. The review feedback identifies a logic error in the head indexing for Grouped Query Attention (GQA) and suggests adding input validation for the rotary dimension. Additionally, the reviewer notes that the kernel may fail for non-power-of-two dimensions and recommends expanding the test cases to include GQA scenarios.

Comment on lines +39 to +40
for g_id in range(groups):
    hq_id = hk_id + g_id


high

The calculation of hq_id is incorrect for Grouped Query Attention (GQA) where Hk > 1. It should be hk_id * groups + g_id to correctly map each query head to its corresponding key head. The current implementation would incorrectly reuse query head indices across different key heads (e.g., if Hk=2 and groups=2, hk_id=1 would incorrectly process hq_id=1, 2 instead of 2, 3).

Suggested change
     for g_id in range(groups):
-        hq_id = hk_id + g_id
+        hq_id = hk_id * groups + g_id
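
As a quick stand-alone illustration (plain Python, not part of the kernel), here is how the two mappings diverge once Hk > 1, using the Hk=2, groups=2 example above:

    Hk, groups = 2, 2
    for hk_id in range(Hk):
        current = [hk_id + g for g in range(groups)]          # current kernel
        fixed = [hk_id * groups + g for g in range(groups)]   # suggested fix
        print(hk_id, current, fixed)
    # hk_id=0 -> current [0, 1], fixed [0, 1]
    # hk_id=1 -> current [1, 2], fixed [2, 3]
    # With the current mapping, query head 1 is rotated twice (in place) and
    # query heads 2 and 3 are never rotated at all.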

Comment on lines +68 to +70
T, Hq, D = query.shape
_, Hk, _ = key.shape
assert Hq % Hk == 0


high

It is important to validate that rotary_dim is within the bounds of the head dimension D and that it is an even number. If rotary_dim > D, the kernel will perform out-of-bounds memory accesses on the query and key tensors.

Suggested change
     T, Hq, D = query.shape
     _, Hk, _ = key.shape
     assert Hq % Hk == 0
+    assert rotary_dim <= D
+    assert rotary_dim % 2 == 0

t_id = tl.program_id(0)
hk_id = tl.program_id(1)

d = tl.arange(0, D_ROPE // 2)


medium

Triton's tl.arange requires the size to be a power of 2. If D_ROPE // 2 is not a power of 2 (e.g., for a rotary dimension of 96), this kernel will fail to compile. Consider using a block size that is the next power of 2 and applying a mask during loads and stores to ensure compatibility with non-power-of-2 rotary dimensions.
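
A minimal sketch of that approach, assuming the same NeoX-style half-split layout as the quoted kernel; the kernel name, pointer arguments, and stride names below are illustrative, not the PR's actual signature:

    import triton
    import triton.language as tl

    @triton.jit
    def _rope_rotate_masked(x_ptr, cos_ptr, sin_ptr, stride_t, stride_h,
                            D_ROPE: tl.constexpr, BLOCK: tl.constexpr):
        # BLOCK = triton.next_power_of_2(D_ROPE // 2), computed on the host side.
        t_id = tl.program_id(0)
        h_id = tl.program_id(1)
        d = tl.arange(0, BLOCK)        # power-of-two range, may overshoot D_ROPE // 2
        mask = d < D_ROPE // 2         # keep only the valid rotary lanes
        base = x_ptr + t_id * stride_t + h_id * stride_h
        x1 = tl.load(base + d, mask=mask, other=0.0)
        x2 = tl.load(base + D_ROPE // 2 + d, mask=mask, other=0.0)
        cos = tl.load(cos_ptr + t_id * (D_ROPE // 2) + d, mask=mask, other=0.0)
        sin = tl.load(sin_ptr + t_id * (D_ROPE // 2) + d, mask=mask, other=0.0)
        tl.store(base + d, x1 * cos - x2 * sin, mask=mask)
        tl.store(base + D_ROPE // 2 + d, x2 * cos + x1 * sin, mask=mask)

For rotary_dim = 96 this gives D_ROPE // 2 = 48, so BLOCK = 64 and the last 16 lanes are masked out of every load and store.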

Comment on lines +40 to +43
shapes = [
[32, 4, 1, 256, 64], # partial
[32, 4, 1, 64, 64] # no partial
]


medium

The current test cases only cover Hk = 1, which does not exercise the GQA logic and hides the bug in the hq_id calculation. Adding a test case with Hk > 1 (e.g., Hq=8, Hk=2) is necessary to ensure correctness for Grouped Query Attention.

    shapes = [
        [32, 4, 1, 256, 64],  # partial
        [32, 8, 2, 128, 64],  # GQA
        [32, 4, 1, 64, 64]    # no partial
    ]
