Skip to content

Commit 7fd7f4a

Browse files
authored
[spec] fix DeepSeek v3.2 MTP metadata and cuda graph (#1591)
* [spec] fix DeepSeek v3.2 MTP metadata and cuda graph

  Signed-off-by: AlpinDale <alpindale@gmail.com>

* some oversights from previous PR

  Signed-off-by: AlpinDale <alpindale@gmail.com>

---------

Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent e5538f1 commit 7fd7f4a

3 files changed

Lines changed: 8 additions & 6 deletions

File tree

aphrodite/platforms/cuda.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -380,6 +380,7 @@ def get_attn_backend_cls(
380380
logger.info_once(
381381
"Using FlexAttention backend for %s.",
382382
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
383+
scope="global",
383384
)
384385
return FLEX_ATTENTION_V1
385386

aphrodite/platforms/xpu.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,9 +71,7 @@ def get_attn_backend_cls(
7171
logger.info_once("Using Flash Attention backend.", scope="global")
7272
return FLASH_ATTN
7373
elif selected_backend:
74-
raise ValueError(
75-
f"Invalid attention backend for {cls.device_name}, with use_v1: {use_v1} use_mla: {use_mla}"
76-
)
74+
raise ValueError(f"Invalid attention backend for {cls.device_name}, with use_mla: {use_mla}")
7775

7876
logger.info_once("Using Flash Attention backend.", scope="global")
7977
return "aphrodite.v1.attention.backends.flash_attn.FlashAttentionBackend"

aphrodite/v1/spec_decode/eagle.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,7 @@ def __init__(
8989
(sorted(self.aphrodite_config.compilation_config.cudagraph_capture_sizes)) if self.use_cuda_graph else []
9090
)
9191

92+
self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes)
9293
# persistent buffers for cuda graph
9394
self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)
9495
self.uses_mrope = self.aphrodite_config.model_config.uses_mrope
@@ -824,7 +825,7 @@ def load_model(self, target_model: nn.Module) -> None:
824825
)
825826
indexer_layers = get_layers_from_aphrodite_config(self.aphrodite_config, DeepseekV32IndexerCache)
826827
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
827-
self.attn_layer_names = list(draft_attn_layer_names)
828+
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
828829
self.indexer_layer_names = list(draft_indexer_layer_names)
829830

830831
if self.indexer_layer_names:
@@ -907,14 +908,16 @@ def dummy_run(
907908
num_tokens: int,
908909
use_cudagraphs=True,
909910
) -> None:
910-
if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
911+
# Determine if CUDA graphs should be used for this run.
912+
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
913+
if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
911914
num_tokens = self.aphrodite_config.pad_for_cudagraph(num_tokens)
912915

913916
with set_forward_context(
914917
None,
915918
self.aphrodite_config,
916919
num_tokens=num_tokens,
917-
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE if use_cudagraphs else CUDAGraphMode.NONE,
920+
cudagraph_runtime_mode=(CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE),
918921
):
919922
if self.supports_mm_inputs:
920923
input_ids = None

0 commit comments

Comments
 (0)