fix linear att cpu cache offload load speed #1317

Open
hiworldwzj wants to merge 3 commits into main from wzj_fix

Conversation

@hiworldwzj
Collaborator

No description provided.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request optimizes the Triton kernel for linear attention CPU cache copying by implementing 128-bit alignment using torch.uint64 and refining the parallelization strategy for better performance. It also updates stream synchronization to use non-blocking wait calls and refactors the multi-turn benchmark tool from an asynchronous architecture to a threaded implementation using requests and ThreadPoolExecutor. The reviewer feedback identifies several improvement opportunities, including the removal of dead commented-out code, the correction of an inaccurate comment in the new speed test script, and the restoration of robust error handling in the benchmark tool to prevent the entire process from crashing due to individual request failures.

  # TODO: fa3 currently requires synchronous mode; remove this in the future
- g_infer_context.get_overlap_stream().synchronize()
+ torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
+ # g_infer_context.get_overlap_stream().synchronize()

medium

The commented-out code should be removed to maintain a clean codebase.

Comment on lines +152 to +154
+ if response.status_code != 200:
+     err = response.text
+     raise RuntimeError(f"stream_one_turn failed: status={response.status_code}, body={err[:200]}")

medium

Changing the error handling to raise a RuntimeError makes the benchmark tool fragile. A single request failure will now crash the entire benchmark process. It is better to log the error and return None to allow other concurrent sessions to continue, maintaining the previous robust behavior.

Suggested change
- if response.status_code != 200:
-     err = response.text
-     raise RuntimeError(f"stream_one_turn failed: status={response.status_code}, body={err[:200]}")
+ if response.status_code != 200:
+     err = response.text
+     print(f"\n[turn failed] status={response.status_code} body={err[:200]}")
+     return None
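The behaviour the reviewer asks for can be sketched as follows. This is a minimal illustration, not the benchmark tool's code: `run_turn` and `run_benchmark` are hypothetical stand-ins, the status codes are passed in directly instead of coming from real HTTP requests, and the `0.1` latency is a placeholder.

```python
from concurrent.futures import ThreadPoolExecutor


def run_turn(status_code):
    """Hypothetical stand-in for stream_one_turn: a latency on success, None on failure."""
    if status_code != 200:
        # Log and signal failure with None instead of raising.
        print(f"[turn failed] status={status_code}")
        return None
    return 0.1  # placeholder latency in seconds


def run_benchmark(status_codes):
    """Run every turn in a thread pool and separate successes from failures."""
    with ThreadPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(run_turn, status_codes))
    succeeded = [r for r in results if r is not None]
    return succeeded, len(results) - len(succeeded)


latencies, failed = run_benchmark([200, 500, 200])
```

With this shape, one failed turn is counted and reported while the remaining turns still contribute to the statistics.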

Comment on lines 196 to +197
  if first_token_time is None:
-     return None
+     raise RuntimeError("stream_one_turn failed: no token received from stream")

medium

Raising a RuntimeError here will terminate the benchmark session abruptly. Consider returning None to allow the benchmark to continue with other sessions.

Suggested change
if first_token_time is None:
    return None
    raise RuntimeError("stream_one_turn failed: no token received from stream")
if first_token_time is None:
    print("\n[turn failed] no token received from stream")
    return None
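If raising is preferred inside the worker, a possible middle ground is to catch the exception where results are collected. With `concurrent.futures`, an exception raised in a worker is stored on its future and re-raised only when `result()` is called, so the collector can handle each session individually. The sketch below assumes a hypothetical `run_session` worker and `collect` helper; neither name comes from the PR.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_session(session_id):
    """Hypothetical session worker that raises on failure, as the PR's code does."""
    if session_id == 1:
        raise RuntimeError("stream_one_turn failed: no token received from stream")
    return session_id


def collect(session_ids):
    """Gather results per future so one failed session does not stop the others."""
    done, errors = [], []
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = {pool.submit(run_session, s): s for s in session_ids}
        for fut in as_completed(futures):
            try:
                done.append(fut.result())  # the worker's exception is re-raised here
            except RuntimeError as exc:
                errors.append((futures[fut], str(exc)))
    return sorted(done), errors


done, errors = collect([0, 1, 2])
```

This keeps the failure reason attached to the session that produced it, at the cost of slightly more code in the collection loop.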

# ---------------------------------------------------------------------------
# Step 2 – derive sizes from the config
# ---------------------------------------------------------------------------
big_page_token_num = _env_args["cpu_cache_token_page_size"] # 512

medium

The comment # 512 is incorrect. Based on the _env_args definition, cpu_cache_token_page_size is 2048 * 8, which equals 16384.

Suggested change
- big_page_token_num = _env_args["cpu_cache_token_page_size"] # 512
+ big_page_token_num = _env_args["cpu_cache_token_page_size"] # 16384

Comment on lines +123 to +133
+ # conv_shape = linear_config.get_conv_state_shape()
+ # cpu_kv_conv_state = torch.empty(
+ #     (buffer_count, linear_config.linear_layer_num, *conv_shape),
+ #     dtype=linear_config.conv_state_dtype, device="cuda",
+ # )
+
+ # ssm_shape = linear_config.get_ssm_state_shape()  # (num_linear_v_heads, head_linear_k_dim, head_linear_v_dim)
+ # cpu_kv_ssm_state = torch.empty(
+ #     (buffer_count, linear_config.linear_layer_num, *ssm_shape),
+ #     dtype=linear_config.ssm_state_dtype, device="cuda",
+ # )

medium

Remove commented-out code blocks to maintain cleanliness.
