Skip to content

Commit 100a788

Browse files
authored
[Feat] add layerwise KV load patch for vllm-ascend v0.18.0 on DSA model (#911)
## Purpose Fix a bug in vllm-ascend v0.18.0 where `wait_for_kv_layer_from_connector` is skipped when the `mlapo` (fused operator) branch is active in `AscendSFAImpl.forward`. ## Modifications - Created `v0180` patch directory structure for version-specific patches. - Added `vllm_ascend/pc/attention/sfa_v1.py` containing the patched `forward` method: - Moved `wait_for_kv_layer_from_connector(layer_name)` from the `else` block to the function entry point. - Ensures KV cache synchronization happens before any computation, covering both `mlapo` and `native` paths. - Added `vllm_ascend/sparse_ascend_patch.py` to register the patch via `@when_imported`. - Updated `apply_patch.py` to apply this patch when vllm-ascend version is detected as `0.18.0`. ## Test - **Environment**: vllm-ascend v0.18.0, UCM enabled (layerwise). - **Verification**: - **GLM5-W8A8 (mlapo)**: Verify `wait_for_kv_layer_from_connector` is called and inference output is correct. - **GLM5-W4A8 (native)**: Verify existing behavior is unchanged and output is correct. - **Logs**: Check for "UCM patch (v0.18): AscendSFAImpl.forward replaced..." during startup.
1 parent e9bb421 commit 100a788

7 files changed

Lines changed: 306 additions & 2 deletions

File tree

ucm/integration/vllm/patch/apply_patch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ def _norm(v: Optional[str]) -> Optional[str]:
5353
if not v:
5454
return None
5555
v = str(v).strip()
56-
# common suffixes: 0.11.0+xxx / 0.11.0.post1
56+
# common suffixes: 0.11.0+xxx / 0.11.0.post1 / 0.11.0rc1
5757
v = v.split("+", 1)[0]
5858
v = v.split(".post", 1)[0]
59+
v = v.split("rc", 1)[0]
5960
return v
6061

6162
try:
@@ -102,7 +103,7 @@ def get_vllm_version() -> Optional[str]:
102103

103104
def get_supported_versions() -> list[str]:
104105
"""Get patch-required vLLM versions."""
105-
return ["0.11.0"]
106+
return ["0.11.0", "0.18.0"]
106107

107108

108109
def apply_all_patches() -> None:
@@ -148,6 +149,9 @@ def apply_all_patches() -> None:
148149
if ENABLE_SPARSE:
149150
logger.info("UCM patching vllm-ascend for sparse...")
150151
import ucm.integration.vllm.patch.v0110.vllm_ascend.sparse_ascend_patch
152+
case "0.18.0":
153+
logger.info("UCM patching vllm-ascend for pc...")
154+
import ucm.integration.vllm.patch.v0180.vllm_ascend.pc_ascend_patch
151155
case _:
152156
pass
153157

ucm/integration/vllm/patch/v0180/__init__.py

Whitespace-only changes.

ucm/integration/vllm/patch/v0180/vllm_ascend/__init__.py

Whitespace-only changes.

ucm/integration/vllm/patch/v0180/vllm_ascend/pc/__init__.py

Whitespace-only changes.

ucm/integration/vllm/patch/v0180/vllm_ascend/pc/attention/__init__.py

Whitespace-only changes.
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
import torch
2+
import torch_npu
3+
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
4+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
5+
from vllm_ascend.attention.mla_v1 import (
6+
MAX_O_PROJ_PREFETCH_SIZE,
7+
MLAPO_MAX_SUPPORTED_TOKENS,
8+
)
9+
from vllm_ascend.attention.utils import (
10+
maybe_save_kv_layer_to_connector,
11+
wait_for_kv_layer_from_connector,
12+
)
13+
from vllm_ascend.device.device_op import DeviceOperator
14+
from vllm_ascend.distributed.utils import all_gather_async
15+
from vllm_ascend.ops.layer_shard_linear import (
16+
is_hidden_layer,
17+
reach_layer_for_shard_weight_series,
18+
)
19+
from vllm_ascend.utils import get_weight_prefetch_method
20+
21+
22+
class AscendSFAImpl:
23+
def forward(
24+
self,
25+
layer_name,
26+
hidden_states: torch.Tensor, # query in unified attn
27+
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
28+
attn_metadata,
29+
need_gather_q_kv: bool = False,
30+
output: torch.Tensor | None = None,
31+
) -> torch.Tensor:
32+
assert output is not None, "Output tensor must be provided."
33+
if attn_metadata is None:
34+
# Profiling run.
35+
if self.enable_dsa_cp_with_layer_shard and not _EXTRA_CTX.in_profile_run:
36+
for layer in self.layer_sharding_kwargs or []:
37+
if is_hidden_layer(layer):
38+
reach_layer_for_shard_weight_series(layer)
39+
return output.fill_(0)
40+
41+
cos = attn_metadata.cos
42+
sin = attn_metadata.sin
43+
slot_mapping = attn_metadata.slot_mapping
44+
slot_mapping_cp = None
45+
if self.enable_dsa_cp:
46+
assert attn_metadata.dsa_cp_context is not None
47+
slot_mapping_cp = attn_metadata.dsa_cp_context.slot_mapping_cp
48+
actual_seq_lengths_query = (
49+
attn_metadata.dsa_cp_context.actual_seq_lengths_query
50+
)
51+
actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
52+
else:
53+
actual_seq_lengths_query = attn_metadata.cum_query_lens
54+
actual_seq_lengths_key = attn_metadata.seq_lens
55+
56+
# Inputs and outputs may be padded for CUDA graphs
57+
num_input_tokens = attn_metadata.num_input_tokens
58+
output_padded = output
59+
60+
# all-gather o_proj weight for prefill stage of PD mix node
61+
o_proj_full_handle = None
62+
# if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj
63+
# weight for prefill stage.
64+
full_gather_o_proj_enabled = (
65+
self.enable_dsa_cp_with_o_proj_tp
66+
and attn_metadata.attn_state
67+
not in {
68+
AscendAttentionState.DecodeOnly,
69+
AscendAttentionState.SpecDecoding,
70+
}
71+
)
72+
73+
# run mlapo ops when dsa-cp is disabled, and ensure that num_tokens satisfies the count limitation
74+
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
75+
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_with_mlapo(
76+
hidden_states=hidden_states,
77+
kv_cache=kv_cache,
78+
cos=cos,
79+
sin=sin,
80+
slot_mapping=slot_mapping,
81+
num_input_tokens=num_input_tokens,
82+
)
83+
k_li, k_li_scale = self.indexer_select_pre_process(
84+
x=hidden_states, cos=cos, sin=sin
85+
)
86+
# [patch] Add 'wait_for_kv_layer_from_connector' call for mlapo path
87+
wait_for_kv_layer_from_connector(layer_name)
88+
# native
89+
else:
90+
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
91+
weight_prefetch_method = get_weight_prefetch_method()
92+
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
93+
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
94+
)
95+
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
96+
q_c, kv_no_split = qkv_lora.split(
97+
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
98+
dim=-1,
99+
)
100+
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
101+
q_c = self.q_a_layernorm(q_c)
102+
103+
k_li, k_li_scale = self.indexer_select_pre_process(
104+
x=hidden_states, cos=cos, sin=sin
105+
)
106+
107+
wait_for_kv_layer_from_connector(layer_name)
108+
109+
if self.enable_dsa_cp:
110+
assert slot_mapping_cp is not None
111+
k_pe, k_nope = self.exec_kv(
112+
kv_no_split, cos, sin, kv_cache, slot_mapping_cp, attn_metadata
113+
)
114+
else:
115+
k_pe, k_nope = self.exec_kv(
116+
kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata
117+
)
118+
119+
if self.enable_dsa_cp:
120+
assert k_pe is not None
121+
assert k_nope is not None
122+
assert k_li is not None
123+
async_op = (
124+
self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
125+
)
126+
# support all_gather kv async for communication calculation overlap
127+
if not self.use_sparse_c8_indexer:
128+
fused_kv_no_split, kv_ag_handle = all_gather_async(
129+
torch.cat(
130+
[
131+
k_pe.view(-1, k_pe.shape[-1]),
132+
k_nope.view(-1, k_nope.shape[-1]),
133+
k_li.view(-1, k_li.shape[-1]),
134+
],
135+
dim=1,
136+
),
137+
get_tp_group(),
138+
async_op=async_op,
139+
)
140+
else:
141+
# due to different dtypes, we have to split commu pass
142+
assert k_li_scale is not None
143+
fused_kv_no_split, _ = all_gather_async(
144+
torch.cat(
145+
[
146+
k_pe.view(-1, k_pe.shape[-1]),
147+
k_nope.view(-1, k_nope.shape[-1]),
148+
],
149+
dim=1,
150+
),
151+
get_tp_group(),
152+
async_op=async_op,
153+
)
154+
k_li, _ = all_gather_async(
155+
k_li,
156+
get_tp_group(),
157+
async_op=async_op,
158+
)
159+
k_li_scale, kv_ag_handle = all_gather_async(
160+
k_li_scale,
161+
get_tp_group(),
162+
async_op=async_op,
163+
)
164+
165+
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
166+
q_pe = self.rope_single(q_pe, cos, sin)
167+
168+
if self.enable_dsa_cp:
169+
if kv_ag_handle is not None:
170+
kv_ag_handle.wait()
171+
172+
if self.enable_dsa_cp_with_layer_shard:
173+
for layer in self.layer_sharding_kwargs or []:
174+
if is_hidden_layer(layer):
175+
reach_layer_for_shard_weight_series(layer)
176+
elif full_gather_o_proj_enabled:
177+
_, o_proj_full_handle = all_gather_async(
178+
self.o_proj_tp_weight,
179+
get_tp_group(),
180+
output=AscendSFAImpl.o_proj_full_pool,
181+
)
182+
183+
if kv_cache is not None:
184+
assert fused_kv_no_split is not None
185+
if not self.use_sparse_c8_indexer:
186+
k_pe, k_nope, k_li = fused_kv_no_split.split(
187+
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim],
188+
dim=-1,
189+
)
190+
else:
191+
k_pe, k_nope = fused_kv_no_split.split(
192+
[self.qk_rope_head_dim, self.kv_lora_rank], dim=-1
193+
)
194+
k_nope = k_nope.view(k_nope.shape[0], 1, -1)
195+
k_pe = k_pe.view(k_pe.shape[0], 1, -1)
196+
DeviceOperator.reshape_and_cache(
197+
key=k_nope[: attn_metadata.num_actual_tokens],
198+
value=k_pe[: attn_metadata.num_actual_tokens],
199+
key_cache=kv_cache[0],
200+
value_cache=kv_cache[1],
201+
slot_mapping=slot_mapping[: attn_metadata.num_actual_tokens],
202+
)
203+
204+
k_li = self._get_full_kv(k_li, attn_metadata)
205+
206+
if kv_cache is not None:
207+
if self.is_kv_producer:
208+
attn_metadata.reshape_cache_event = torch.npu.Event()
209+
torch_npu.npu_scatter_nd_update_(
210+
kv_cache[2].view(-1, k_li.shape[-1]),
211+
slot_mapping.view(-1, 1),
212+
k_li.view(-1, k_li.shape[-1]),
213+
) # b, s, n, d
214+
if self.use_sparse_c8_indexer:
215+
assert len(kv_cache) == 4
216+
assert k_li_scale is not None
217+
torch_npu.npu_scatter_nd_update_(
218+
kv_cache[3].view(-1, k_li_scale.shape[-1]),
219+
slot_mapping.view(-1, 1),
220+
k_li_scale.view(-1, k_li_scale.shape[-1]),
221+
)
222+
if self.is_kv_producer:
223+
attn_metadata.reshape_cache_event.record()
224+
225+
topk_indices = self.indexer_select_post_process(
226+
x=hidden_states,
227+
q_c=q_c,
228+
kv_cache=kv_cache,
229+
attn_metadata=attn_metadata,
230+
cos=cos,
231+
sin=sin,
232+
actual_seq_lengths_query=actual_seq_lengths_query,
233+
actual_seq_lengths_key=actual_seq_lengths_key,
234+
)
235+
236+
attn_output = self._execute_sparse_flash_attention_process(
237+
ql_nope,
238+
q_pe,
239+
kv_cache,
240+
topk_indices,
241+
attn_metadata,
242+
actual_seq_lengths_query,
243+
actual_seq_lengths_key,
244+
)
245+
246+
attn_output = self._v_up_proj(attn_output)
247+
weight_prefetch_method = get_weight_prefetch_method()
248+
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
249+
inputs=self.o_proj.weight,
250+
dependency=attn_output,
251+
max_size=MAX_O_PROJ_PREFETCH_SIZE,
252+
linear_layer=self.o_proj,
253+
)
254+
255+
if self.enable_dsa_cp_with_o_proj_tp:
256+
# When using SFA-CP with pd mixed, o_proj has two cases:
257+
# 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1.
258+
# 2. decode: all-to-all the hidden_state before the o_proj forward.
259+
result, require_o_proj_forward = (
260+
self._handle_o_proj_weight_switch_and_forward(
261+
attn_output=attn_output,
262+
output=output,
263+
o_proj_full_handle=o_proj_full_handle,
264+
should_shard_weight=full_gather_o_proj_enabled,
265+
)
266+
)
267+
if not require_o_proj_forward:
268+
return result
269+
attn_output = result
270+
271+
if self.enable_dsa_cp_strict_accuracy:
272+
send = (
273+
attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim)
274+
.permute(1, 0, 2)
275+
.reshape(-1, self.num_heads * self.v_head_dim)
276+
)
277+
278+
attn_output = torch.empty_like(send)
279+
torch.distributed.all_to_all_single(
280+
attn_output, send, group=get_tp_group().device_group
281+
)
282+
283+
output[...] = self.o_proj(attn_output)[0]
284+
285+
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
286+
287+
return output_padded
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from ucm.integration.vllm.patch.utils import patch_or_inject, when_imported
2+
from ucm.logger import init_logger
3+
4+
logger = init_logger(__name__)
5+
6+
7+
@when_imported("vllm_ascend.attention.sfa_v1")
8+
def patch_sfa_v1(mod):
9+
logger.debug(f"Patched {mod} called")
10+
11+
from ucm.integration.vllm.patch.v0180.vllm_ascend.pc.attention import sfa_v1
12+
13+
patch_or_inject(mod.AscendSFAImpl, "forward", sfa_v1.AscendSFAImpl.forward)

0 commit comments

Comments
 (0)