Skip to content

support ssd chunk scan triton & ssd chunk state triton on npu#448

Open
sigama-w wants to merge 2 commits into
sgl-project:mainfrom
sigama-w:main
Open

support ssd chunk scan triton & ssd chunk state triton on npu#448
sigama-w wants to merge 2 commits into
sgl-project:mainfrom
sigama-w:main

Conversation

@sigama-w
Copy link
Copy Markdown

No description provided.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements NPU-specific Mamba kernels for causal convolution and SSD operations. Feedback highlights a bug where the chunk scan returns the wrong tensor, a missing null check for the output parameter, and the need to replace CUDA device contexts with NPU-specific ones for proper kernel execution.

IS_TRITON_22=TRITON_22,
HAS_INITSTATES=initial_states is not None,
)
return out_x No newline at end of file
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The function currently returns out_x, which is None if z is None. This is likely a bug as the user would expect the result of the scan (which is in out). It is recommended to return out as the primary result.

Suggested change
return out_x
return out

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the GPU, this is how it is implemented. Keep it consistent with that.

else:
chunk_indices, chunk_offsets = None, None

assert out.shape == x.shape
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The out parameter is optional in the function signature but its shape is asserted immediately without checking for None. This will cause an AttributeError if the user does not provide an output tensor. It is better to initialize it if it's missing.

    if out is None:
        out = torch.empty_like(x)
    else:
        assert out.shape == x.shape

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the GPU, this is how it is implemented. Keep it consistent with that.

Comment thread python/sgl_kernel_npu/sgl_kernel_npu/mamba/ssd_chunk_state.py Outdated
batch * nchunks,
nheads,
)
with torch.cuda.device(x.device.index):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Since this kernel is specifically for NPU, you should use torch.npu.device instead of torch.cuda.device.

Suggested change
with torch.cuda.device(x.device.index):
with torch.npu.device(x.device.index):

Comment thread python/sgl_kernel_npu/sgl_kernel_npu/mamba/ssd_chunk_state.py Outdated
Comment thread python/sgl_kernel_npu/sgl_kernel_npu/mamba/ssd_chunk_scan.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant