support ssd chunk scan triton & ssd chunk state triton on npu #448
sigama-w wants to merge 2 commits into
Code Review
This pull request implements NPU-specific Mamba kernels for causal convolution and SSD operations. Feedback highlights a bug where the chunk scan returns the wrong tensor, a missing null check for the output parameter, and the need to replace CUDA device contexts with NPU-specific ones for proper kernel execution.
        IS_TRITON_22=TRITON_22,
        HAS_INITSTATES=initial_states is not None,
    )
    return out_x
On the GPU, this is how it is implemented. Keep it consistent with that.
    else:
        chunk_indices, chunk_offsets = None, None

    assert out.shape == x.shape
The out parameter is optional in the function signature but its shape is asserted immediately without checking for None. This will cause an AttributeError if the user does not provide an output tensor. It is better to initialize it if it's missing.
    if out is None:
        out = torch.empty_like(x)
    else:
        assert out.shape == x.shape
On the GPU, this is how it is implemented. Keep it consistent with that.
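The failure mode raised in this thread can be reproduced without any tensor library. The sketch below is illustrative only: `FakeTensor`, `check_out_unguarded`, and `check_out_guarded` are hypothetical names that do not appear in the PR, and `FakeTensor(x.shape)` merely stands in for `torch.empty_like(x)`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only carries a shape."""
    shape: Tuple[int, ...]


def check_out_unguarded(x: FakeTensor, out: Optional[FakeTensor] = None) -> FakeTensor:
    # Same pattern as the diff: the shape is asserted with no None check,
    # so omitting `out` fails on the attribute lookup `None.shape`.
    assert out.shape == x.shape
    return out


def check_out_guarded(x: FakeTensor, out: Optional[FakeTensor] = None) -> FakeTensor:
    # Same pattern as the suggested change: allocate when `out` is missing,
    # otherwise validate the caller's buffer.
    if out is None:
        out = FakeTensor(x.shape)  # stands in for torch.empty_like(x)
    else:
        assert out.shape == x.shape
    return out
```

If callers of the GPU version always pass `out`, the unguarded form never hits this path in practice, which may be why the author prefers to keep both versions identical.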
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
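The review summary asks for the CUDA device context to be replaced with an NPU-specific one. One way to avoid hard-coding either backend is to look the context manager up by device type. This is a minimal sketch under stated assumptions: `device_guard` is a hypothetical helper, not part of the PR, and whether the NPU backend exposes a `device` context manager under `torch.npu` should be checked against its documentation.

```python
import contextlib


def device_guard(torch_module, device_type, index):
    """Return a device context manager for `device_type`, or a no-op.

    Hypothetical helper: looks up `torch_module.<device_type>.device`
    (for example torch.cuda.device on GPU). If the backend or its
    `device` attribute is missing, a no-op context is returned instead.
    """
    backend = getattr(torch_module, device_type, None)
    ctx_factory = getattr(backend, "device", None)
    if ctx_factory is None:
        return contextlib.nullcontext()
    return ctx_factory(index)
```

A call site would then read `with device_guard(torch, x.device.type, x.device.index):` in place of the hard-coded `torch.cuda.device(...)` line.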
No description provided.