Code Review
This pull request introduces FlashAttention-4 (FA4) support and implements paged attention backends for FlashAttention-3 (FA3) and FlashInfer (FP and MLA variants) so that page sizes greater than 1 are supported. It also updates the memory management and request queue components to handle paged memory allocations. The code review identified several critical issues and improvement opportunities: in-place modification of kv_starts in the FlashInfer states can corrupt the original sequence length tensors and should be avoided by cloning; calling .view() on the non-contiguous tensors produced by slicing can raise runtime errors and should be replaced with .reshape(); hardcoded dimensions such as qk_rope_head_dim should be read from the configuration instead; and allocating tensors on the CPU before moving them to the GPU is inefficient and should be replaced with direct device allocation.
| batch_size = self.infer_state.batch_size | ||
| device = self.infer_state.input_ids.device | ||
| q_starts = self.infer_state.b1_cu_q_seq_len.int() | ||
| kv_starts = self.infer_state.b1_cu_kv_seq_len.int() |
Assigning to kv_starts[1:] in place will also modify the original self.infer_state.b1_cu_kv_seq_len tensor whenever that tensor is already int32, because .int() then returns the same tensor rather than a copy. Call .clone() before mutating to avoid the side effect.
| kv_starts = self.infer_state.b1_cu_kv_seq_len.int() | |
| kv_starts = self.infer_state.b1_cu_kv_seq_len.int().clone() |
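The aliasing described in this comment can be reproduced in isolation. The following is a minimal sketch on CPU tensors; `cu_kv` is a hypothetical stand-in for `b1_cu_kv_seq_len`, and the in-place update is illustrative rather than the one used in the PR.

```python
import torch

# Hypothetical stand-in for infer_state.b1_cu_kv_seq_len, already int32.
cu_kv = torch.tensor([0, 3, 7, 12], dtype=torch.int32)

# .int() is a no-op conversion here, so it hands back the very same tensor.
same = cu_kv.int()
is_alias = same is cu_kv

# Cloning first gives a private buffer that is safe to mutate in place.
kv_starts = cu_kv.int().clone()
kv_starts[1:] = kv_starts[1:] * 2  # illustrative in-place update

# With a different source dtype, .int() already produces a fresh tensor.
cu_kv64 = torch.tensor([0, 3, 7, 12], dtype=torch.int64)
converted = cu_kv64.int()
is_fresh = converted.data_ptr() != cu_kv64.data_ptr()

print(is_alias, is_fresh, cu_kv.tolist(), kv_starts.tolist())
```

Because the aliasing depends on the source dtype, the bug may only surface on code paths where the cumulative lengths are already stored as int32, which is one reason an unconditional `.clone()` is the safer choice.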
| else: | ||
| self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device) | ||
| self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() |
| model = self.backend.model | ||
| device = self.infer_state.input_ids.device | ||
| batch_size = self.infer_state.batch_size | ||
| self.kv_starts = self.infer_state.b1_cu_kv_seq_len |
| k_cache=k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), | ||
| v_cache=k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, kv_lora_rank), |
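Slicing k along the last dimension produces a non-contiguous tensor. Calling .view() on a non-contiguous tensor raises a RuntimeError whenever the requested shape is incompatible with the tensor's strides. Use .reshape() instead, which returns a view when one is possible and falls back to a copy otherwise.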
| k_cache=k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), | |
| v_cache=k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, kv_lora_rank), | |
| k_cache=k[:, :, -qk_rope_head_dim:].reshape(-1, self.backend.page_size, 1, qk_rope_head_dim), | |
| v_cache=k[:, :, :-qk_rope_head_dim].reshape(-1, self.backend.page_size, 1, kv_lora_rank), |
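Whether `.view()` fails on a sliced tensor depends on the target shape, which the following sketch illustrates. The `(tokens, 1, head_dim)` layout, `page_size`, and `rope_dim` values are assumptions chosen for illustration and may not match the layout of `k` in the PR.

```python
import torch

page_size, rope_dim = 2, 2  # illustrative values, not taken from the PR

# Assumed layout: (tokens, 1, head_dim) with 4 tokens and head_dim 6.
k = torch.arange(4 * 1 * 6, dtype=torch.float32).reshape(4, 1, 6)
k_rope = k[:, :, -rope_dim:]  # slicing the last dim leaves gaps between rows
sliced_is_contiguous = k_rope.is_contiguous()

# Splitting only the leading dimension is stride-compatible, so view() works.
paged_view = k_rope.view(-1, page_size, 1, rope_dim)

# Merging the token dimension with the sliced dimension is not, so view() raises.
try:
    k_rope.view(-1)
    flat_view_failed = False
except RuntimeError:
    flat_view_failed = True

# reshape() covers both cases: a view when possible, a copy otherwise.
paged = k_rope.reshape(-1, page_size, 1, rope_dim)
flat = k_rope.reshape(-1)
paged_shares_memory = paged.data_ptr() == k_rope.data_ptr()
flat_is_copy = flat.data_ptr() != k_rope.data_ptr()
```

One trade-off to keep in mind: when `.reshape()` falls back to a copy, writes to the result no longer reach the original cache buffer, so it is worth confirming which case applies before relying on in-place updates.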
| k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim), | ||
| k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), |
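Slicing k along the last dimension produces a non-contiguous tensor. Calling .view() on a non-contiguous tensor raises a RuntimeError whenever the requested shape is incompatible with the tensor's strides. Use .reshape() instead, which returns a view when one is possible and falls back to a copy otherwise.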
| k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim), | |
| k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), | |
| k[:, :, :-qk_rope_head_dim].reshape(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim), | |
| k[:, :, -qk_rope_head_dim:].reshape(-1, self.backend.page_size, 1, qk_rope_head_dim), |
| self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl, alloc_func=torch.empty | ||
| ): | ||
| q_nope, q_rope = q | ||
| qk_rope_head_dim = 64 |
| ) | ||
| assert v is None | ||
| q_nope, q_rope = q | ||
| qk_rope_head_dim = 64 |
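The review summary flags the hardcoded `qk_rope_head_dim = 64` in the two snippets above. One possible shape for the fix is sketched below; the `config` object, its field names, and the `resolve_rope_dim` helper are all hypothetical, since the source does not show where the model configuration lives.

```python
from types import SimpleNamespace

import torch

# Hypothetical config object; the real project may expose these differently.
config = SimpleNamespace(kv_lora_rank=8, qk_rope_head_dim=4)


def resolve_rope_dim(config, q_rope: torch.Tensor) -> int:
    """Read the rope dimension from config and cross-check it against q_rope."""
    rope_dim = config.qk_rope_head_dim
    if q_rope.shape[-1] != rope_dim:
        raise ValueError(
            f"q_rope last dim {q_rope.shape[-1]} != configured {rope_dim}"
        )
    return rope_dim


q_rope = torch.zeros(3, 2, config.qk_rope_head_dim)
k = torch.zeros(6, 1, config.kv_lora_rank + config.qk_rope_head_dim)

rope_dim = resolve_rope_dim(config, q_rope)
# Split the packed cache into its non-rope and rope parts.
k_nope, k_rope = k[..., :-rope_dim], k[..., -rope_dim:]
```

The cross-check against `q_rope.shape[-1]` is optional, but it would turn a silent mis-slice into an immediate error if a model with a different rope dimension were loaded.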
| torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), | ||
| torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), |
Allocating a tensor on the CPU and then moving it to the GPU with .to() costs an extra host allocation and a host-to-device copy. Allocate the tensor directly on the target device by passing the device argument to torch.empty.
| torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), | |
| torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), | |
| torch.empty(shared_len, dtype=torch.int32, device=get_current_device_id()), | |
| torch.empty(shared_len, dtype=torch.int32, device=get_current_device_id()), |
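The two allocation styles can be compared side by side. This is a minimal sketch that falls back to the CPU when no GPU is present; `device` stands in for whatever the project's `get_current_device_id()` helper returns, and `shared_len` is an illustrative size.

```python
import torch

# Stand-in for get_current_device_id(); falls back to CPU so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
shared_len = 8  # illustrative buffer length

# Two-step form: allocate on the host, then copy to the target device.
staged = torch.empty(shared_len, dtype=torch.int32).to(device)

# One-step form: allocate uninitialized memory directly on the target device.
direct = torch.empty(shared_len, dtype=torch.int32, device=device)
```

Both tensors end up on the same device with the same dtype and shape; the direct form simply skips the intermediate host buffer and the transfer.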
| torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), | ||
| torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), |
Allocating a tensor on the CPU and then moving it to the GPU with .to() costs an extra host allocation and a host-to-device copy. Allocate the tensor directly on the target device by passing the device argument to torch.empty.
| torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), | |
| torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), | |
| torch.empty(shared_len, dtype=torch.int32, device=get_current_device_id()), | |
| torch.empty(shared_len, dtype=torch.int32, device=get_current_device_id()), |