Fa4 support #1327

Open
blueswhen wants to merge 2 commits into main from fa4

Conversation

@blueswhen
Collaborator

No description provided.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces FlashAttention-4 (FA4) support and implements paged attention backends for FlashAttention-3 (FA3) and FlashInfer (FP and MLA variants) to support page sizes greater than 1. It also updates memory management and request queue components to handle paged memory allocations.

The code review identified several critical issues and improvement opportunities. In-place modifications of kv_starts in FlashInfer states can corrupt the original sequence length tensors and should be avoided by cloning. Calling .view() on non-contiguous tensor slices can raise runtime errors and should be replaced with .reshape(). Hardcoded dimensions such as qk_rope_head_dim should be retrieved dynamically from configuration. Allocating tensors on the CPU before moving them to the GPU is inefficient and should be replaced with direct device allocation.

batch_size = self.infer_state.batch_size
device = self.infer_state.input_ids.device
q_starts = self.infer_state.b1_cu_q_seq_len.int()
kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
Contributor

high

Modifying kv_starts in place with kv_starts[1:] = ... will corrupt the original self.infer_state.b1_cu_kv_seq_len tensor whenever that tensor is already int32, because .int() then returns the same tensor rather than a copy. Use .clone() to avoid side effects.

Suggested change
- kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
+ kv_starts = self.infer_state.b1_cu_kv_seq_len.int().clone()
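
The aliasing behind this comment can be reproduced in isolation. The following is a minimal sketch that assumes only PyTorch. The name `cu_kv_seq_len` is a hypothetical stand-in for self.infer_state.b1_cu_kv_seq_len, and the values are made up for illustration:

```python
import torch

# Hypothetical stand-in for infer_state.b1_cu_kv_seq_len (already int32).
cu_kv_seq_len = torch.tensor([0, 3, 7, 12], dtype=torch.int32)

# On an int32 tensor, .int() returns the very same object, not a copy.
aliased = cu_kv_seq_len.int()
same_object = aliased is cu_kv_seq_len

# On a tensor of another dtype, .int() has to convert and so returns a copy.
as_long = torch.tensor([0, 3, 7, 12], dtype=torch.int64)
converted_is_new = as_long.int() is not as_long

# .clone() always allocates new storage, so in-place writes stay local.
safe = cu_kv_seq_len.int().clone()
safe[1:] = safe[1:] + 1
unchanged_after_clone = cu_kv_seq_len.tolist()

# Writing through the alias mutates the original tensor as well.
aliased[1:] = aliased[1:] + 1
changed_after_alias = cu_kv_seq_len.tolist()
```

Because the outcome depends on the dtype of the source tensor, cloning explicitly keeps the behavior the same in both cases.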

else:
self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device)

self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
Contributor

high

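Modifying self.kv_starts in place will corrupt the original self.infer_state.b1_cu_kv_seq_len tensor whenever that tensor is already int32, because .int() then returns the same tensor rather than a copy. Use .clone() to avoid side effects.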

Suggested change
- self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
+ self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int().clone()

model = self.backend.model
device = self.infer_state.input_ids.device
batch_size = self.infer_state.batch_size
self.kv_starts = self.infer_state.b1_cu_kv_seq_len
Contributor

high

Modifying self.kv_starts in-place will corrupt the original self.infer_state.b1_cu_kv_seq_len tensor. Use .clone() to avoid side effects.

Suggested change
- self.kv_starts = self.infer_state.b1_cu_kv_seq_len
+ self.kv_starts = self.infer_state.b1_cu_kv_seq_len.clone()

Comment on lines +159 to +160
k_cache=k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim),
v_cache=k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, kv_lora_rank),
Contributor

high

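Slicing k along the last dimension makes it non-contiguous. Calling .view() on a non-contiguous tensor raises a RuntimeError when the requested shape is incompatible with the tensor's strides. Use .reshape() instead to safely handle non-contiguous layouts; note that .reshape() returns a copy when a view is not possible.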

Suggested change
- k_cache=k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim),
- v_cache=k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, kv_lora_rank),
+ k_cache=k[:, :, -qk_rope_head_dim:].reshape(-1, self.backend.page_size, 1, qk_rope_head_dim),
+ v_cache=k[:, :, :-qk_rope_head_dim].reshape(-1, self.backend.page_size, 1, kv_lora_rank),
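
Whether .view() fails on a sliced tensor depends on the requested shape relative to the strides. The following is a minimal sketch with made-up shapes, since the real shape of k is not shown in this thread. The layout `(tokens, 1, lora_rank + rope_dim)` and all sizes are assumptions for illustration only:

```python
import torch

page_size, rope_dim = 2, 2  # illustrative values, not the model's real sizes

# Hypothetical cache layout: (tokens, 1, lora_rank + rope_dim).
k = torch.arange(4 * 1 * 6).reshape(4, 1, 6)
k_rope = k[:, :, -rope_dim:]  # last-dim slice
is_contiguous = k_rope.is_contiguous()

# Splitting only the leading dimension is stride-compatible, so .view()
# succeeds here and still shares memory with k.
paged = k_rope.view(-1, page_size, 1, rope_dim)
shares_memory = paged.data_ptr() == k_rope.data_ptr()

# Flattening merges the sliced dimension into its neighbours, which the
# strides cannot express, so .view() raises a RuntimeError.
try:
    k_rope.view(-1)
    view_failed = False
except RuntimeError:
    view_failed = True

# .reshape() handles the same request by falling back to a copy.
flat = k_rope.reshape(-1)
flat_is_copy = flat.data_ptr() != k_rope.data_ptr()
```

The copy made by .reshape() is the trade-off to keep in mind: writes to the reshaped tensor no longer reach the underlying cache buffer.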

Comment on lines +179 to +180
k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim),
k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim),
Contributor

high

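Slicing k along the last dimension makes it non-contiguous. Calling .view() on a non-contiguous tensor raises a RuntimeError when the requested shape is incompatible with the tensor's strides. Use .reshape() instead to safely handle non-contiguous layouts; note that .reshape() returns a copy when a view is not possible.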

Suggested change
- k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim),
- k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim),
+ k[:, :, :-qk_rope_head_dim].reshape(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim),
+ k[:, :, -qk_rope_head_dim:].reshape(-1, self.backend.page_size, 1, qk_rope_head_dim),

self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl, alloc_func=torch.empty
):
q_nope, q_rope = q
qk_rope_head_dim = 64
Contributor

medium

Avoid hardcoding qk_rope_head_dim = 64. Use self.backend.model.qk_rope_head_dim to dynamically retrieve the dimension from the model configuration.

Suggested change
- qk_rope_head_dim = 64
+ qk_rope_head_dim = self.backend.model.qk_rope_head_dim

)
assert v is None
q_nope, q_rope = q
qk_rope_head_dim = 64
Contributor

medium

Avoid hardcoding qk_rope_head_dim = 64. Use self.backend.qk_rope_head_dim to dynamically retrieve the dimension from the backend configuration.

Suggested change
- qk_rope_head_dim = 64
+ qk_rope_head_dim = self.backend.qk_rope_head_dim

Comment on lines +23 to +24
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
Contributor

medium

Allocating a tensor on CPU and then moving it to GPU using .to() is inefficient. Allocate the tensor directly on the target device using the device argument in torch.empty.

Suggested change
- torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
- torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
+ torch.empty(shared_len, dtype=torch.int32, device=get_current_device_id()),
+ torch.empty(shared_len, dtype=torch.int32, device=get_current_device_id()),
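
The two allocation styles can be compared side by side. The following is a minimal sketch. `pick_device` is a hypothetical stand-in for the project's get_current_device_id() and falls back to the CPU so the snippet runs without a GPU, and the buffer size is illustrative:

```python
import torch


def pick_device() -> torch.device:
    # Hypothetical helper, not part of the PR: use the current CUDA device
    # when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


shared_len = 8  # illustrative size
device = pick_device()

# Two steps: allocate on the CPU, then move the buffer to the target device.
staged = torch.empty(shared_len, dtype=torch.int32).to(device)

# One step: allocate directly on the target device.
direct = torch.empty(shared_len, dtype=torch.int32, device=device)
```

Both tensors end up on the same device with the same dtype and shape; the direct form skips the intermediate host allocation and the transfer.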

Comment on lines +24 to +25
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
Contributor

medium

Allocating a tensor on CPU and then moving it to GPU using .to() is inefficient. Allocate the tensor directly on the target device using the device argument in torch.empty.

Suggested change
- torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
- torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
+ torch.empty(shared_len, dtype=torch.int32, device=get_current_device_id()),
+ torch.empty(shared_len, dtype=torch.int32, device=get_current_device_id()),
