From 36c4ea0d3851ed8c3a72284d6eb6683082efe65c Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Sat, 13 Jun 2026 20:16:13 +0800 Subject: [PATCH 01/19] feat: qwen3.5 perf opt --- .../basemodel/attention/create_utils.py | 4 +- lightllm/common/basemodel/attention/fa3/fp.py | 35 ++- .../basemodel/attention/flashinfer/fp.py | 106 ++++++- .../basemodel/attention/flashinfer/mla.py | 8 +- lightllm/common/basemodel/basemodel.py | 17 +- lightllm/common/basemodel/batch_objs.py | 3 + lightllm/common/basemodel/infer_struct.py | 2 + .../transformer_layer_infer_template.py | 15 +- .../fused_moe/fused_moe_weight.py | 4 + .../meta_weights/fused_moe/impl/base_impl.py | 2 + .../fused_moe/impl/deepgemm_impl.py | 2 + .../fused_moe/impl/marlin_impl.py | 2 + .../fused_moe/impl/triton_impl.py | 8 + .../layer_weights/meta_weights/norm_weight.py | 17 +- .../fused_moe/grouped_fused_moe.py | 48 ++- .../triton_kernel/fused_moe/moe_sum_reduce.py | 56 +++- .../triton_kernel/norm/gated_rmsnorm.py | 7 - .../basemodel/triton_kernel/norm/rmsnorm.py | 119 +++++++- .../triton_kernel/repack_kv_index.py | 7 +- lightllm/distributed/communication_op.py | 39 +++ lightllm/distributed/flashinfer_all_reduce.py | 35 +++ .../layer_infer/transformer_layer_infer.py | 18 +- .../layer_infer/transformer_layer_infer.py | 77 +++-- .../layer_weights/transformer_layer_weight.py | 127 ++++++-- lightllm/models/qwen3next/model.py | 102 +++++++ .../triton_kernel/gdn_decode_pack.py | 284 ++++++++++++++++++ .../triton_kernel/shared_expert_gate.py | 108 +++++++ lightllm/server/api_cli.py | 2 +- lightllm/server/api_openai.py | 2 + lightllm/server/core/objs/sampling_params.py | 3 + .../model_infer/mode_backend/base_backend.py | 63 +++- .../mode_backend/chunked_prefill/impl.py | 17 +- .../mode_backend/generic_post_process.py | 278 +++++++++++++++-- lightllm/utils/sgl_utils.py | 4 +- 34 files changed, 1510 insertions(+), 111 deletions(-) create mode 100644 lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py create mode 100644 lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 594e81a9b4..08d15294bd 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -100,7 +100,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list) -def get_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend: +def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type backend_str = args.llm_decode_att_backend[index] @@ -120,7 +120,7 @@ def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list) -def get_mla_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend: +def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type backend_str = args.llm_decode_att_backend[index] diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..48deabaf4b 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -3,12 +3,16 @@ from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.sgl_utils import flash_attn_with_kvcache, get_scheduler_metadata from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor +_DECODE_MAX_NUM_SPLITS = 32 +_DECODE_PACK_GQA = True + + class Fa3AttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) @@ -119,6 +123,7 @@ class Fa3DecodeAttState(BaseDecodeAttState): cu_seqlens_k: torch.Tensor = None page_table: torch.Tensor = None b_att_seq_len: torch.Tensor = None + scheduler_metadata: torch.Tensor = None # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 decode_max_q_seq_len: int = None @@ -179,8 +184,33 @@ def init_state(self): ) self.b_att_seq_len = self.infer_state.b_seq_len self.decode_max_q_seq_len = 1 + self._init_scheduler_metadata() return + def _init_scheduler_metadata(self): + if get_scheduler_metadata is None: + self.scheduler_metadata = None + return + + model = self.backend.model + self.scheduler_metadata = get_scheduler_metadata( + batch_size=self.b_att_seq_len.shape[0], + max_seqlen_q=self.decode_max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + num_heads=model.config["num_attention_heads"] // model.tp_world_size_, + num_heads_k=model.tp_k_head_num_, + headdim=model.head_dim_, + cache_seqlens=self.b_att_seq_len, + qkv_dtype=model.data_type, + headdim_v=model.head_dim_, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + page_size=1, + causal=True, + num_splits=_DECODE_MAX_NUM_SPLITS, + pack_gqa=_DECODE_PACK_GQA, + ) + def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): super().copy_for_decode_cuda_graph(new_state) @@ -235,6 +265,9 @@ def _normal_decode_att( causal=True, window_size=window_size, softcap=0.0, + scheduler_metadata=self.scheduler_metadata, + num_splits=_DECODE_MAX_NUM_SPLITS, + pack_gqa=_DECODE_PACK_GQA, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False, diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2e..b5145ac932 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -6,6 +6,82 @@ from .env_utils import set_flashinfer_envs +def _fast_plan_tensor_core_decode( + decode_wrapper, + indptr, + indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + indptr_host, + kv_lens_arr_host, + max_kv_len, +): + batch_size = len(last_page_len) + if batch_size != decode_wrapper._fixed_batch_size: + raise ValueError( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} " + "mismatches the batch size set during initialization {}".format( + batch_size, decode_wrapper._fixed_batch_size + ) + ) + if len(indices) > len(decode_wrapper._paged_kv_indices_buf): + raise ValueError("The size of indices should be less than or equal to the allocated buffer") + + qo_indptr_host = getattr(decode_wrapper, "_lightllm_qo_indptr_host", None) + if qo_indptr_host is None or len(qo_indptr_host) != batch_size + 1: + from flashinfer.decode import _get_range_buf + + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + decode_wrapper._lightllm_qo_indptr_host = qo_indptr_host + + if indptr_host is None: + indptr_host = indptr.to("cpu") + if kv_lens_arr_host is None: + from flashinfer.decode import get_seq_lens + + last_page_len_host = last_page_len.to("cpu") + kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) + if max_kv_len is None: + max_kv_len = max(kv_lens_arr_host).item() + + decode_wrapper._batch_size = batch_size + decode_wrapper._num_qo_heads = num_qo_heads + decode_wrapper._num_kv_heads = num_kv_heads + decode_wrapper._block_tables = None + decode_wrapper._max_kv_len = max_kv_len + + args = [ + decode_wrapper._float_workspace_buffer, + decode_wrapper._int_workspace_buffer, + decode_wrapper._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_host, + kv_lens_arr_host, + batch_size, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + decode_wrapper.is_cuda_graph_enabled, + head_dim, + head_dim, + False, + -1, + ] + if decode_wrapper._backend == "fa2": + args.extend([-1, False, 0]) + decode_wrapper._plan_info = decode_wrapper._cached_module.plan(*args) + decode_wrapper._pos_encoding_mode = "NONE" + decode_wrapper._window_left = -1 + decode_wrapper._logits_soft_cap = 0.0 + decode_wrapper._sm_scale = None + decode_wrapper._rope_scale = None + decode_wrapper._rope_theta = None + + class FlashInferAttBackend(BaseAttBackend): def __init__(self, model): set_flashinfer_envs() @@ -25,6 +101,10 @@ def __init__(self, model): model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() ), ] + self.kv_starts_host_buffer = [ + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + ] self.q_data_type = model.data_type self.kv_data_type = model.data_type @@ -124,11 +204,11 @@ class FlashInferDecodeAttState(BaseDecodeAttState): kv_last_page_len_buffer: torch.Tensor = None kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None + kv_starts_host: torch.Tensor = None + kv_seq_lens_host: torch.Tensor = None decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: FlashInferAttBackend = self.backend device = self.infer_state.input_ids.device model = self.backend.model @@ -154,8 +234,21 @@ def init_state(self): self.infer_state.b_kv_start_loc, self.infer_state.max_kv_seq_len, self.kv_indices, + zero_output=False, ) self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.infer_state.b_seq_len_cpu is not None: + self.kv_seq_lens_host = self.infer_state.b_seq_len_cpu + self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][ + : self.infer_state.batch_size + 1 + ] + self.kv_starts_host[0] = 0 + torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:]) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.backend.workspace_buffer, @@ -182,7 +275,8 @@ def init_state(self): def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( + _fast_plan_tensor_core_decode( + self.decode_wrapper, new_state.kv_starts, new_state.kv_indices, new_state.kv_last_page_len_buffer, @@ -190,9 +284,9 @@ def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): new_state.backend.tp_kv_head_num, new_state.backend.head_dim, 1, - q_data_type=new_state.backend.q_data_type, - kv_data_type=new_state.backend.kv_data_type, - non_blocking=True, + new_state.kv_starts_host, + new_state.kv_seq_lens_host, + new_state.infer_state.max_kv_seq_len, ) def decode_att( diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 84b44dc45a..8689839db4 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -116,8 +116,6 @@ class MlaFlashInferDecodeAttState(BaseDecodeAttState): decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: MlaFlashInferAttBackend = self.backend model = self.backend.model device = self.infer_state.input_ids.device @@ -144,7 +142,13 @@ def init_state(self): self.infer_state.b_kv_start_loc, self.infer_state.max_kv_seq_len, self.kv_indices, + zero_output=False, ) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a2..e6d5735385 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -314,6 +314,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len + infer_state.b_seq_len_cpu = model_input.b_seq_len_cpu infer_state.b_mtp_index = model_input.b_mtp_index if model_input.is_prefill: if model_input.b_ready_cache_len is not None: @@ -371,6 +372,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 ) new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) + if new_model_input.b_seq_len_cpu is not None: + new_model_input.b_seq_len_cpu = F.pad( + new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2 + ) new_model_input.mem_indexes = F.pad( new_model_input.mem_indexes, (0, padded_batch_size), @@ -562,6 +567,8 @@ def _decode( model_input=model_input, new_batch_size=infer_batch_size ) infer_state = self._create_inferstate(model_input) + need_capture = self.graph.need_capture(infer_batch_size) + infer_state.skip_decode_att_wrapper_init = not need_capture copy_kv_index_to_req( self.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -571,7 +578,7 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(infer_batch_size): + if need_capture: infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: @@ -1037,6 +1044,9 @@ def autotune_layers(self): # 控制autotune的层数,用于适配不同模型 return self.config.get("first_k_dense_replace", 0) + 1 + def _autotune_extra_warmup(self): + return + @final @torch.no_grad() @post_empty_cache @@ -1106,6 +1116,11 @@ def _autotune_warmup(self): self.mem_manager.free_all() gc.collect() torch.cuda.empty_cache() + try: + self._autotune_extra_warmup() + except Exception as e: + logger.warning(f"extra autotune warmup failed: {str(e)}") + logger.exception(str(e)) self.layers_num = layer_num_bak torch.distributed.barrier() Autotuner.end_autotune_warmup() diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 1795ff9a82..81bf3cfd5c 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -42,6 +42,7 @@ class ModelInput: multimodal_params: list = None # cpu 变量 mem_indexes_cpu: torch.Tensor = None + b_seq_len_cpu: torch.Tensor = None # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理 # 的一些变量 b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出 @@ -64,6 +65,8 @@ def to_cuda(self): assert self.is_prefill self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) + if not self.b_seq_len.is_cuda: + self.b_seq_len_cpu = self.b_seq_len self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) if self.b_ready_cache_len is not None: diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c835..575f1ee25f 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -40,6 +40,7 @@ def __init__(self): self.b_mtp_index: torch.Tensor = None self.b_seq_len: torch.Tensor = None + self.b_seq_len_cpu: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None # prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度 @@ -56,6 +57,7 @@ def __init__(self): self.return_all_prompt_logics: bool = False self.multimodal_params: dict = None self.is_cuda_graph: bool = False # 标记是否是cuda graph的捕获推理 + self.skip_decode_att_wrapper_init: bool = False self.dist_group: CustomProcessGroup = None # 在microbatch overlap的运行模式下,用于标记当前 microbatch 的 index 序号 diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index f0cc129c09..061658bc07 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -53,6 +53,18 @@ def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tens def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") + def _add_residual_ffn_norm(self, input_embdings, residual, infer_state: InferStateInfo, layer_weight): + add_rmsnorm = getattr(layer_weight.ffn_norm_weight_, "add_rmsnorm", None) + if add_rmsnorm is None: + input_embdings.add_(residual.view(-1, self.embed_dim_)) + return self._ffn_norm(input_embdings, infer_state, layer_weight) + return add_rmsnorm( + input=input_embdings, + residual=residual.view(-1, self.embed_dim_), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) @@ -89,10 +101,9 @@ def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, l def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): input1 = self._att_norm(input_embdings, infer_state, layer_weight) o = self.token_attention_forward(input1, infer_state, layer_weight) - input_embdings.add_(o.view(-1, self.embed_dim_)) + input1 = self._add_residual_ffn_norm(input_embdings, o, infer_state, layer_weight) o = None - input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..d9a77b39a5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -134,6 +134,8 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -150,6 +152,8 @@ def experts( num_expert_group=num_expert_group, is_prefill=is_prefill, per_expert_scale=self.per_expert_scale, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index dd6f9a6880..b54b03ee05 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -63,5 +63,7 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 4d4614c007..bc0e86d7eb 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -76,6 +76,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): output = fused_experts( hidden_states=input_tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 0094b09b1c..417d001c72 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -30,6 +30,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..fdda2b2139 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -94,6 +94,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -111,6 +113,8 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return input_tensor @@ -129,6 +133,8 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -150,5 +156,7 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return output diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index ee9d1923c3..33b59ac5a9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -2,7 +2,7 @@ from typing import Optional, Dict from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size -from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import add_rmsnorm_forward, rmsnorm_forward from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward @@ -71,6 +71,21 @@ def __call__( ) -> torch.Tensor: return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def add_rmsnorm( + self, + input: torch.Tensor, + residual: torch.Tensor, + eps: float, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + input.ndim in [2, 3] and residual.ndim in [2, 3] and self.weight.ndim == 1 + ), f"input.ndim: {input.ndim}, residual.ndim: {residual.ndim}, weight.ndim: {self.weight.ndim}" + if out is None: + out = alloc_func(input.shape, dtype=input.dtype, device=input.device) + return add_rmsnorm_forward(x=input, residual=residual, weight=self.weight, eps=eps, out=out) + class GatedRMSNormWeight(RMSNormWeight): def _triton_forward( diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 76acea25a7..95dcba9836 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -221,10 +221,17 @@ def moe_align_fused_kernel( expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] token_num, + expert_num: tl.constexpr, topk_num: tl.constexpr, BLOCK_SIZE: tl.constexpr, + ZERO_EXPERT_TOKEN_NUM: tl.constexpr, + BLOCK_EXPERT: tl.constexpr, ): token_block = tl.program_id(0) + if ZERO_EXPERT_TOKEN_NUM: + expert_offs = tl.arange(0, BLOCK_EXPERT) + tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) + offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offs < token_num * topk_num @@ -282,6 +289,8 @@ def moe_align_fused( run_config = {} BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) num_warps = run_config.get("num_warps", 4) + expert_num = expert_token_num.shape[0] + zero_expert_token_num = token_num * topk_num <= BLOCK_SIZE grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) moe_align_fused_kernel[grid]( @@ -291,8 +300,11 @@ def moe_align_fused( expert_to_weight, expert_token_num, token_num, + expert_num, topk_num, BLOCK_SIZE=BLOCK_SIZE, + ZERO_EXPERT_TOKEN_NUM=zero_expert_token_num, + BLOCK_EXPERT=triton.next_power_of_2(expert_num), num_warps=num_warps, ) return expert_to_token_index, expert_to_weight, expert_token_num @@ -911,6 +923,8 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -957,7 +971,12 @@ def fused_experts_impl( expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") - expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda") + expert_token_count_in_align_kernel = topk_num * tokens_in_chunk <= 128 + expert_to_token_num = ( + torch.empty((E,), dtype=torch.int32, device="cuda") + if expert_token_count_in_align_kernel + else torch.zeros((E,), dtype=torch.int32, device="cuda") + ) moe_align_fused( expert_to_token_index=expert_to_tokens, expert_to_weight=expert_to_weights, @@ -1011,8 +1030,15 @@ def fused_experts_impl( bias=w2_bias, ) + has_shared_gate = shared_expert_out is not None moe_sum_reduce( - intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx] + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + shared=None if not has_shared_gate else shared_expert_out[begin_chunk_idx:end_chunk_idx], + gate=None if not has_shared_gate else shared_expert_gate[begin_chunk_idx:end_chunk_idx], + run_config=( + None if not has_shared_gate else {"BLOCK_M": 1, "BLOCK_DIM": 128, "NUM_STAGE": 1, "num_warps": 2} + ), ) return out_hidden_states @@ -1035,6 +1061,8 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: fused_experts_impl( hidden_states, @@ -1054,6 +1082,8 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) @@ -1075,6 +1105,8 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: pass @@ -1105,6 +1137,8 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: return fused_experts_impl( hidden_states, @@ -1124,6 +1158,8 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) @@ -1145,6 +1181,8 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: return torch.empty_like(hidden_states) @@ -1176,6 +1214,8 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): if inplace: torch.ops.lightllm.inplace_fused_experts_impl( @@ -1195,6 +1235,8 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return hidden_states else: @@ -1215,4 +1257,6 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index e16351eec8..4f95cca7c6 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -14,12 +14,20 @@ def _moe_sum_reduce_kernel( output_ptr, output_stride_0, output_stride_1, + shared_ptr, + shared_stride_0, + shared_stride_1, + gate_ptr, + gate_stride_0, + gate_stride_1, token_num: int, topk_num: int, hidden_dim: int, BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr, NUM_STAGE: tl.constexpr, + HAS_SHARED_GATE: tl.constexpr, + GATE_DIM: tl.constexpr, ): input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) @@ -42,12 +50,38 @@ def _moe_sum_reduce_kernel( for i in tl.range(0, topk_num, num_stages=NUM_STAGE): tmp = tl.load(input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0) accumulator += tmp + if HAS_SHARED_GATE: + shared = tl.load( + shared_ptr + token_index * shared_stride_0 + offs_dim * shared_stride_1, + mask=offs_dim < dim_end, + other=0.0, + ).to(tl.float32) + if GATE_DIM == 1: + gate = tl.load(gate_ptr + token_index * gate_stride_0).to(tl.float32) + tl.zeros( + (BLOCK_DIM,), dtype=tl.float32 + ) + else: + gate = tl.load( + gate_ptr + token_index * gate_stride_0 + offs_dim * gate_stride_1, + mask=offs_dim < dim_end, + other=0.0, + ).to(tl.float32) + gate = 1.0 / (1.0 + tl.exp(-gate)) + accumulator += shared * gate store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim tl.store(store_t_ptr, accumulator.to(input_ptr.dtype.element_ty), mask=offs_dim < dim_end) -def _get_moe_sum_reduce_static_key(input: torch.Tensor, output: torch.Tensor): - return {"topk_num": input.shape[1], "hidden_dim": input.shape[2], "out_dtype": str(output.dtype)} +def _get_moe_sum_reduce_static_key( + input: torch.Tensor, output: torch.Tensor, shared: torch.Tensor = None, gate: torch.Tensor = None +): + return { + "topk_num": input.shape[1], + "hidden_dim": input.shape[2], + "out_dtype": str(output.dtype), + "has_shared_gate": shared is not None, + "gate_dim": 0 if gate is None else gate.shape[-1], + } def _get_moe_sum_reduce_configs(): @@ -67,12 +101,20 @@ def _get_moe_sum_reduce_configs(): run_key_func=lambda input: input.shape[0], mutates_args=["output"], ) -def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = None): +def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, shared=None, gate=None, run_config: Dict = None): assert input.is_contiguous() assert output.is_contiguous() token_num, topk_num, hidden_dim = input.shape assert output.shape[0] == token_num and output.shape[1] == hidden_dim + has_shared_gate = shared is not None + if has_shared_gate: + assert gate is not None + shared = shared.view(token_num, hidden_dim) + gate = gate.view(token_num, gate.shape[-1]) + assert shared.is_contiguous() + assert gate.is_contiguous() + assert gate.shape[1] in (1, hidden_dim) if not run_config: run_config = { @@ -97,12 +139,20 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = *input.stride(), output, *output.stride(), + shared if has_shared_gate else output, + shared.stride(0) if has_shared_gate else 0, + shared.stride(1) if has_shared_gate else 0, + gate if has_shared_gate else output, + gate.stride(0) if has_shared_gate else 0, + gate.stride(1) if has_shared_gate else 0, token_num=token_num, topk_num=topk_num, hidden_dim=hidden_dim, BLOCK_M=BLOCK_M, BLOCK_DIM=BLOCK_DIM, NUM_STAGE=NUM_STAGE, + HAS_SHARED_GATE=has_shared_gate, + GATE_DIM=gate.shape[1] if has_shared_gate else 0, num_warps=num_warps, ) return diff --git a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py index 89db5e00cb..c62c5eb5d2 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py @@ -16,7 +16,6 @@ def gated_rmsnorm_forward_kernel( W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch (required, not optional) - Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_z_row, @@ -33,7 +32,6 @@ def gated_rmsnorm_forward_kernel( X += row * stride_x_row + group * N Y += row * stride_y_row + group * N Z += row * stride_z_row + group * N - Rstd += group * M W += group * N if HAS_BIAS: B += group * N @@ -47,7 +45,6 @@ def gated_rmsnorm_forward_kernel( xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) @@ -128,9 +125,6 @@ def gated_rmsnorm_forward( else: out = torch.empty_like(x) assert out.stride(-1) == 1 - # For RMS norm, we still need rstd for the kernel - rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) - # Default heuristic when autotune is disabled or no config provided if not run_config: # Less than 64KB per feature: enqueue fused kernel @@ -160,7 +154,6 @@ def gated_rmsnorm_forward( weight, bias, z, - rstd, x.stride(0), out.stride(0), z.stride(0), diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index 8dc8558922..79f57ba051 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -3,6 +3,7 @@ import triton import triton.language as tl import os +from lightllm.common.triton_utils.autotuner import autotune rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8")) @@ -48,6 +49,51 @@ def _rms_norm_fwd_fused( tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) +@triton.jit +def _add_rms_norm_fwd_fused( + X, + R, + Y, + W, + x_stride0, + x_stride1, + r_stride0, + r_stride1, + y_stride0, + y_stride1, + N, + eps, + HAS_WEIGHT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + X += row * x_stride0 + R += row * r_stride0 + Y += row * y_stride0 + + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32) + x = x + r + tl.store(X + cols * x_stride1, x.to(X.dtype.element_ty), mask=mask) + _var += x * x + + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + y = x * rstd + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + y *= w + tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) + + def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None): # allocate output y = torch.empty_like(x) if out is None else out @@ -60,7 +106,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() + MAX_FUSED_SIZE = 65536 // x_arg.element_size() BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) # print("BLOCK_SIZE:", BLOCK_SIZE) if N > BLOCK_SIZE: @@ -86,6 +132,77 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) return y +def _get_add_rmsnorm_configs(): + return [{"num_warps": nw} for nw in [4, 8, 16]] + + +def _get_add_rmsnorm_static_key(x_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor): + return { + "x_dtype": str(x_arg.dtype), + "out_dtype": str(y_arg.dtype), + "weight_dtype": "none" if weight is None else str(weight.dtype), + "N": x_arg.shape[1], + "has_weight": weight is not None, + } + + +@autotune( + kernel_name="add_rmsnorm_forward:v1", + configs_gen_func=_get_add_rmsnorm_configs, + static_key_func=_get_add_rmsnorm_static_key, + run_key_func=lambda x_arg: x_arg.shape[0], + mutates_args=["x_arg", "y_arg"], +) +def _add_rmsnorm_forward( + x_arg: torch.Tensor, + residual_arg: torch.Tensor, + y_arg: torch.Tensor, + weight: torch.Tensor, + eps: float, + run_config: dict = None, +): + M, N = x_arg.shape + MAX_FUSED_SIZE = 65536 // x_arg.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + if BLOCK_SIZE > 16384: + BLOCK_SIZE = 16384 + if not run_config: + run_config = {"num_warps": rmsnorm_num_warps} + _add_rms_norm_fwd_fused[(M,)]( + x_arg, + residual_arg, + y_arg, + weight, + x_arg.stride(0), + x_arg.stride(1), + residual_arg.stride(0), + residual_arg.stride(1), + y_arg.stride(0), + y_arg.stride(1), + N, + eps, + HAS_WEIGHT=weight is not None, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=run_config["num_warps"], + ) + return y_arg + + +def add_rmsnorm_forward(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float, out=None): + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, x.shape[-1]) + residual_arg = residual.view(-1, x.shape[-1]) + y_arg = y.view(-1, x.shape[-1]) + assert x_arg.shape == residual_arg.shape == y_arg.shape + if weight is not None: + assert x_arg.shape[-1] == weight.shape[0] + assert y.data_ptr() == y_arg.data_ptr() + _add_rmsnorm_forward(x_arg, residual_arg, y_arg, weight, eps) + return y + + def torch_rms_norm(x, weight, eps): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819e..57a1c4d0a3 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -34,10 +34,11 @@ def _fwd_kernel_repack_kv_index( @torch.no_grad() -def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): +def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index, zero_output: bool = True): batch_size = req_index.shape[0] - # flashinfer requires out_kv_index to be zeroed before use - out_kv_index.zero_() + # Some flashinfer callers need zero-filled padding outside the valid indptr range. + if zero_output: + out_kv_index.zero_() BLOCK = 64 grid = ( batch_size, diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f15badde25..31ff5b8b65 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -100,6 +100,24 @@ def all_reduce(self, input_: torch.Tensor) -> None: return return dist.all_reduce(input_, group=self.device_group) + def all_reduce_residual_rmsnorm( + self, + input_: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + alloc_func=torch.empty, + ): + if self.flashinfer_reduce is None: + return None + return self.flashinfer_reduce.all_reduce_residual_rmsnorm( + input_, + residual=residual, + norm_weight=norm_weight, + eps=eps, + alloc_func=alloc_func, + ) + def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, async_op: bool = False) -> None: return dist.all_gather_into_tensor(output_, input_, group=self.device_group, async_op=async_op) @@ -235,6 +253,27 @@ def all_reduce( return dist.all_reduce(input_, op, group, async_op) +def all_reduce_residual_rmsnorm( + input_: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None, + alloc_func=torch.empty, +): + if _is_single_group(group=group): + return None + if isinstance(group, CustomProcessGroup): + return group.all_reduce_residual_rmsnorm( + input_, + residual=residual, + norm_weight=norm_weight, + eps=eps, + alloc_func=alloc_func, + ) + return None + + def all_gather_into_tensor( output_: torch.Tensor, input_: torch.Tensor, diff --git a/lightllm/distributed/flashinfer_all_reduce.py b/lightllm/distributed/flashinfer_all_reduce.py index 27856d9ac7..d469c3bc66 100644 --- a/lightllm/distributed/flashinfer_all_reduce.py +++ b/lightllm/distributed/flashinfer_all_reduce.py @@ -132,4 +132,39 @@ def all_reduce(self, inp: torch.Tensor) -> torch.Tensor: input=inp, workspace=self._workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce, + launch_with_pdl=True, ) + + def all_reduce_residual_rmsnorm( + self, + inp: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + alloc_func=torch.empty, + ): + if ( + residual.shape != inp.shape + or residual.dtype != inp.dtype + or not residual.is_cuda + or norm_weight.dtype != inp.dtype + or norm_weight.shape[0] != inp.shape[-1] + ): + return None + if not self.should_use(inp): + return None + + residual_out = alloc_func(inp.shape, dtype=inp.dtype, device=inp.device) + norm_out = alloc_func(inp.shape, dtype=inp.dtype, device=inp.device) + flashinfer_comm.allreduce_fusion( + input=inp, + workspace=self._workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=True, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=norm_weight, + rms_eps=eps, + ) + return residual_out, norm_out diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index afbd02a482..d9ac369960 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -28,14 +28,24 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - o_gate = layer_weight._o_gate_proj.mm(input) - # In-place sigmoid for gate - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index bb48bfe49c..3492041813 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -10,8 +10,10 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager from typing import Tuple -from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs +from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import add_shared_expert_gate_, sigmoid_mul_ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -114,15 +116,14 @@ def _compute_shared_expert( ): input = input.view(-1, self.embed_dim_) shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input).sigmoid_() - shared_expert_out.mul_(gate) - return shared_expert_out + gate = layer_weight.ffn_gate.mm(input) + return shared_expert_out, gate def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape @@ -135,15 +136,16 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + shared_expert_out=shared_expert_out, + shared_expert_gate=gate, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) - hidden_states.add_(shared_expert_out) return hidden_states def _moe_ffn_edp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input token_num, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -158,7 +160,7 @@ def _moe_ffn_edp( is_prefill=infer_state.is_prefill, ) ep_output = ep_output.view(token_num, hidden_dim) - ep_output.add_(shared_expert_out) + add_shared_expert_gate_(ep_output, shared_expert_out, gate) return ep_output def _get_qkv( @@ -169,13 +171,25 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ * 2 + + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1, ) - o_gate = layer_weight._o_gate_proj.mm(input) - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -199,15 +213,24 @@ def _get_o( input, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, + ) -> torch.Tensor: + o_tensor = self._get_o_local(input=input, infer_state=infer_state, layer_weight=layer_weight) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) + return o_tensor + + def _get_o_local( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, ) -> torch.Tensor: """Output projection with gating (in-place multiply to save one allocation).""" if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - input.mul_(infer_state.gate_value) + sigmoid_mul_(input, infer_state.gate_value) infer_state.gate_value = None o_tensor = layer_weight.o_proj.mm(input) - o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor # ==================== GDN Helper Methods ==================== @@ -257,8 +280,9 @@ def gdn_forward( else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) - core_attn_out = self._gdn_decode_kernel( + core_attn_out, z = self._gdn_decode_kernel( mixed_qkv, + z, conv_states, ssm_states, a, @@ -406,6 +430,7 @@ def _gdn_prefill_kernel( def _gdn_decode_kernel( self, mixed_qkv: torch.Tensor, + z: torch.Tensor, conv_states: torch.Tensor, ssm_states: torch.Tensor, a: torch.Tensor, @@ -413,18 +438,24 @@ def _gdn_decode_kernel( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - mixed_qkv = causal_conv1d_update( + # Recurrent processing with fused gating. Decode uses a specialized + # conv+pack kernel to avoid materializing the post-conv qkv tensor + # before immediately splitting it into q/k/v. + query, key, value, z, a, b = conv_pack_gdn_decode_inputs( mixed_qkv, + z, + a, + b, conv_states, layer_weight.linear_conv1d.mm_param.weight, - bias=layer_weight.linear_conv1d.bias, - activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + layer_weight.linear_conv1d.bias, + infer_state.b_buffer_idx, + self.activation, + self.tp_num_k_heads, + self.head_k_dim, + self.tp_num_v_heads, + self.head_v_dim, ) - - # Recurrent processing with fused gating - # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally - query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -438,4 +469,4 @@ def _gdn_decode_kernel( a_raw=a, b_raw=b, ) - return core_attn_out + return core_attn_out, z diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 0d415ca0e8..51b702039b 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -11,6 +11,83 @@ QKVROWNMMWeight, QKGEMMANormWeight, ) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import get_row_slice_mixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class QKVGatedROWNMMWeight(MMWeightTpl): + def __init__( + self, + in_dim, + q_head_num, + kv_head_num, + head_dim, + weight_names, + data_type, + bias_names=None, + quant_method=None, + tp_rank=None, + tp_world_size=None, + ): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.q_repeat_times = 1 + self.kv_repeat_times = 1 + assert ( + q_head_num % self.tp_world_size_ == 0 + ), f"q_head_num must be divisible by tp_world_size_, found {q_head_num} % {self.tp_world_size_}" + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or vice versa, " + f"found {kv_head_num} % {self.tp_world_size_}" + ) + q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + super().__init__( + in_dim=in_dim, + out_dims=[q_hidden_size, kv_hidden_size, kv_hidden_size, q_hidden_size], + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.q_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.q_repeat_times, + ) + self.kv_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.kv_repeat_times, + ) + + def _get_param_slicer(self, sub_child_index): + if sub_child_index == 0 or sub_child_index == 3: + return self.q_param_slicer + return self.kv_param_slicer + + def load_hf_weights(self, weights): + super().load_hf_weights(weights) + if self.bias_names is not None: + for sub_child_index, bias_name in enumerate(self.bias_names): + if bias_name is None: + self.bias_list[sub_child_index].zero_() + self.bias_list[sub_child_index].load_ok = True + + def _get_tp_padded_head_num(self, head_num): + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + if self.tp_world_size_ % head_num == 0: + self.kv_repeat_times = self.tp_world_size_ // head_num + return self.kv_repeat_times * head_num // self.tp_world_size_ + raise ValueError( + f"head_num must be divisible by tp_world_size_ or vice versa, found {head_num} % {self.tp_world_size_}" + ) class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -23,25 +100,39 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def _init_qkv(self): in_dim = self.n_embed q_out_dim = self.q_head_num_ * self.head_dim - self.qkv_proj = QKVROWNMMWeight( - in_dim=in_dim, - q_head_num=self.q_head_num_, - kv_head_num=self.k_head_num_, - head_dim=self.head_dim, - weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], - data_type=self.data_type_, - bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("qkv_proj"), - ) self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" - self._o_gate_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=[self._o_gate_weight_name], - data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("o_gate_proj"), - ) + qkv_quant = self.get_quant_method("qkv_proj") + gate_quant = self.get_quant_method("o_gate_proj") + if qkv_quant.method_name == "none" and gate_quant.method_name == "none": + self.qkvo_gate_proj = QKVGatedROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name, self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name, None], + quant_method=qkv_quant, + ) + else: + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=qkv_quant, + ) + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=gate_quant, + ) def _init_weight(self): if self.is_linear_attention_layer: diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 9b5e9b7a50..a95196abaf 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -17,6 +17,8 @@ from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig +from lightllm.common.basemodel.batch_objs import ModelOutput +from lightllm.distributed import all_reduce, all_reduce_residual_rmsnorm logger = init_logger(__name__) @@ -51,6 +53,29 @@ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch def autotune_layers(self): return self.config["full_attention_interval"] + def _autotune_extra_warmup(self): + if not self.trans_layers_weight: + return + + norm_weight = self.trans_layers_weight[0].ffn_norm_weight_ + add_rmsnorm = getattr(norm_weight, "add_rmsnorm", None) + if add_rmsnorm is None: + return + + hidden_dim = norm_weight.weight.shape[0] + max_batch_size = min(self.graph_max_batch_size, self.batch_max_tokens) + warmup_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + warmup_batch_sizes = [bs for bs in warmup_batch_sizes if bs <= max_batch_size] + if max_batch_size not in warmup_batch_sizes: + warmup_batch_sizes.append(max_batch_size) + + for batch_size in sorted(set(warmup_batch_sizes)): + x = torch.zeros((batch_size, hidden_dim), dtype=self.data_type, device="cuda") + residual = torch.zeros_like(x) + out = torch.empty_like(x) + add_rmsnorm(input=x, residual=residual, eps=self.layers_infer[0].eps_, out=out) + return + def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) @@ -102,3 +127,80 @@ def _init_req_manager(self): self.max_req_num, create_max_seq_len, None, linear_config=LinearAttCacheConfig.load_from_args() ) return + + def _token_forward(self, infer_state: Qwen3NextInferStateInfo): + input_ids = infer_state.input_ids + input_embs = self.pre_infer.token_forward(input_ids, infer_state, self.pre_post_weight) + input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) + + next_att_normed = None + for i in range(self.layers_num): + layer: Qwen3NextTransformerLayerInfer = self.layers_infer[i] + layer_weight: Qwen3NextTransformerLayerWeight = self.trans_layers_weight[i] + + if next_att_normed is None: + input1 = layer._att_norm(input_embs, infer_state, layer_weight) + else: + input1 = next_att_normed + next_att_normed = None + + if layer.is_linear_attention_layer: + o = layer.token_attention_forward(input1, infer_state, layer_weight) + input1 = layer._add_residual_ffn_norm(input_embs, o, infer_state, layer_weight) + o = None + else: + q, cache_kv = layer._get_qkv(input1, infer_state, layer_weight) + layer._post_cache_kv(cache_kv, infer_state, layer_weight) + o = layer._token_attention_kernel(q, infer_state, layer_weight) + q = None + o = layer._get_o_local(o, infer_state, layer_weight) + fused = None + if layer.tp_world_size_ > 1: + fused = all_reduce_residual_rmsnorm( + o, + residual=input_embs.view(-1, layer.embed_dim_), + norm_weight=layer_weight.ffn_norm_weight_.weight, + eps=layer.eps_, + group=infer_state.dist_group, + alloc_func=layer.alloc_tensor, + ) + if fused is None: + if layer.tp_world_size_ > 1: + all_reduce(o, group=infer_state.dist_group) + input1 = layer._add_residual_ffn_norm(input_embs, o, infer_state, layer_weight) + else: + input_embs, input1 = fused + o = None + + ffn_out = layer._ffn(input1, infer_state, layer_weight) + ffn_out = ffn_out.view(-1, layer.embed_dim_) + + if i + 1 < self.layers_num: + next_layer: Qwen3NextTransformerLayerInfer = self.layers_infer[i + 1] + next_layer_weight: Qwen3NextTransformerLayerWeight = self.trans_layers_weight[i + 1] + add_rmsnorm = getattr(next_layer_weight.att_norm_weight_, "add_rmsnorm", None) + if add_rmsnorm is not None: + next_att_normed = add_rmsnorm( + input=input_embs, + residual=ffn_out, + eps=next_layer.eps_, + alloc_func=next_layer.alloc_tensor, + ) + continue + + input_embs.add_(ffn_out) + + last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + predict_logits: torch.Tensor = self.post_infer.token_forward( + last_input_embs, infer_state=infer_state, layer_weight=self.pre_post_weight + ) + + model_output = ModelOutput(logits=predict_logits.contiguous()) + if self.is_mtp_mode: + input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + model_output.mtp_main_output_hiddens = input_embs.contiguous() + + if infer_state.is_cuda_graph: + model_output.to_no_ref_tensor() + + return model_output diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py new file mode 100644 index 0000000000..a025e35c64 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -0,0 +1,284 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + BLOCK_QKV: tl.constexpr, + BLOCK_GATE: tl.constexpr, +): + row = tl.program_id(0) + qkv_offsets = tl.arange(0, BLOCK_QKV) + + q_mask = qkv_offsets < q_dim + q_vals = tl.load(mixed_qkv + row * stride_m_b + qkv_offsets * stride_m_d, mask=q_mask, other=0.0) + tl.store(q_out + row * q_dim + qkv_offsets, q_vals, mask=q_mask) + + k_mask = qkv_offsets < k_dim + k_vals = tl.load( + mixed_qkv + row * stride_m_b + (q_dim + qkv_offsets) * stride_m_d, + mask=k_mask, + other=0.0, + ) + tl.store(k_out + row * k_dim + qkv_offsets, k_vals, mask=k_mask) + + v_mask = qkv_offsets < v_dim + v_vals = tl.load( + mixed_qkv + row * stride_m_b + (q_dim + k_dim + qkv_offsets) * stride_m_d, + mask=v_mask, + other=0.0, + ) + tl.store(v_out + row * v_dim + qkv_offsets, v_vals, mask=v_mask) + + z_vals = tl.load(z_raw + row * stride_z_b + qkv_offsets, mask=v_mask, other=0.0) + tl.store(z_out + row * v_dim + qkv_offsets, z_vals, mask=v_mask) + + gate_offsets = tl.arange(0, BLOCK_GATE) + gate_mask = gate_offsets < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + gate_offsets * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + gate_offsets * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + gate_offsets, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + gate_offsets, b_vals, mask=gate_mask) + + +@torch.no_grad() +def pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_qkv = triton.next_power_of_2(max(q_dim, k_dim, v_dim)) + block_gate = triton.next_power_of_2(gate_dim) + _pack_gdn_decode_kernel[(batch,)]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + BLOCK_QKV=block_qkv, + BLOCK_GATE=block_gate, + num_warps=4, + ) + return q, k, v, z, a, b + + +@triton.jit +def _conv_pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + stride_s_b: tl.constexpr, + stride_s_d: tl.constexpr, + stride_s_w: tl.constexpr, + stride_w_d: tl.constexpr, + stride_w_w: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + conv_dim: tl.constexpr, + HAS_BIAS: tl.constexpr, + APPLY_SILU: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offs = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < conv_dim + state_idx = tl.load(conv_state_indices + row) + + x = tl.load(mixed_qkv + row * stride_m_b + offs * stride_m_d, mask=mask, other=0.0).to(tl.float32) + s0 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s1 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s2 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + w0 = tl.load(conv_weight + offs * stride_w_d + 0 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w1 = tl.load(conv_weight + offs * stride_w_d + 1 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w2 = tl.load(conv_weight + offs * stride_w_d + 2 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w3 = tl.load(conv_weight + offs * stride_w_d + 3 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y = s0 * w0 + s1 * w1 + s2 * w2 + x * w3 + if HAS_BIAS: + bias = tl.load(conv_bias + offs, mask=mask, other=0.0).to(tl.float32) + y += bias + if APPLY_SILU: + y = y * tl.sigmoid(y) + + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, s1, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, s2, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, x, mask=mask) + + q_mask = offs < q_dim + k_mask = (offs >= q_dim) & (offs < q_dim + k_dim) + v_mask = (offs >= q_dim + k_dim) & (offs < conv_dim) + tl.store(q_out + row * q_dim + offs, y, mask=q_mask) + tl.store(k_out + row * k_dim + (offs - q_dim), y, mask=k_mask) + tl.store(v_out + row * v_dim + (offs - q_dim - k_dim), y, mask=v_mask) + + z_mask = offs < v_dim + z_vals = tl.load(z_raw + row * stride_z_b + offs, mask=z_mask, other=0.0) + tl.store(z_out + row * v_dim + offs, z_vals, mask=z_mask) + + gate_mask = offs < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + offs * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + offs * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + offs, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + offs, b_vals, mask=gate_mask) + + +@torch.no_grad() +def conv_pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + conv_state: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + conv_state_indices: torch.Tensor, + activation: str, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + conv_dim = q_dim + k_dim + v_dim + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_size = 256 + grid = (batch, triton.cdiv(conv_dim, block_size)) + _conv_pack_gdn_decode_kernel[grid]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + conv_dim, + HAS_BIAS=conv_bias is not None, + APPLY_SILU=activation in ["silu", "swish"], + BLOCK_SIZE=block_size, + num_warps=8, + ) + return q, k, v, z, a, b diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py new file mode 100644 index 0000000000..c2b110def6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -0,0 +1,108 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _add_shared_expert_gate_kernel( + hidden, + shared, + gate, + stride_h_m: tl.constexpr, + stride_h_n: tl.constexpr, + stride_s_m: tl.constexpr, + stride_s_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + + hidden_ptrs = hidden + row * stride_h_m + offs * stride_h_n + shared_vals = tl.load(shared + row * stride_s_m + offs * stride_s_n, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32) + gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + out = hidden_vals + shared_vals * gate_vals + tl.store(hidden_ptrs, out.to(hidden.dtype.element_ty), mask=mask) + + +@triton.jit +def _sigmoid_mul_kernel( + x, + gate, + stride_x_m: tl.constexpr, + stride_x_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + x_ptrs = x + row * stride_x_m + offs * stride_x_n + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) + + +@torch.no_grad() +def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + hidden_arg = hidden.view(-1, hidden.shape[-1]) + shared_arg = shared.view(-1, hidden.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert hidden_arg.shape == shared_arg.shape + assert gate_arg.shape[0] == hidden_arg.shape[0] and gate_arg.shape[1] in (1, hidden_arg.shape[1]) + _, n = hidden_arg.shape + block_n = triton.next_power_of_2(n) + _add_shared_expert_gate_kernel[(hidden_arg.shape[0],)]( + hidden_arg, + shared_arg, + gate_arg, + hidden_arg.stride(0), + hidden_arg.stride(1), + shared_arg.stride(0), + shared_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return hidden + + +@torch.no_grad() +def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + x_arg = x.view(-1, x.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) + _, n = x_arg.shape + block_n = triton.next_power_of_2(n) + _sigmoid_mul_kernel[(x_arg.shape[0],)]( + x_arg, + gate_arg, + x_arg.stride(0), + x_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return x diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 7e40421140..8e9866a1d7 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -401,7 +401,7 @@ def make_argument_parser() -> argparse.ArgumentParser: default=["auto"], help="""decode attention kernel used in llm. auto: automatically select best backend based on GPU and available packages - (priority: flashinfer > fa3 > triton)""", + (priority: fa3 > flashinfer > triton)""", ) parser.add_argument( "--vit_att_backend", diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 0d934c44c9..df79913e23 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -314,6 +314,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "n": request.n, "best_of": request.n, "add_special_tokens": False, + "return_logprobs": request.logprobs is not None, "seed": request.seed, } @@ -822,6 +823,7 @@ async def completions_impl(request: CompletionRequest, raw_request: Request) -> "n": request.n, "best_of": request.best_of, "add_special_tokens": False, + "return_logprobs": request.logprobs is not None, "seed": request.seed, } if request.max_completion_tokens is not None: diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index c39559f5f6..3515fbf1a1 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -304,6 +304,7 @@ class SamplingParams(ctypes.Structure): ), # whether to add spaces between special tokens when decoding ("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True ("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache + ("return_logprobs", ctypes.c_bool), # whether generated token logprobs are required by the caller ("seed", ctypes.c_int64), # random seed ] @@ -340,6 +341,7 @@ def init(self, tokenizer, **kwargs): self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) self.print_eos_token = kwargs.get("print_eos_token", False) + self.return_logprobs = kwargs.get("return_logprobs", True) self.seed = kwargs.get("seed", -1) self.exponential_decay_length_penalty = ExponentialDecayLengthPenalty() @@ -486,6 +488,7 @@ def to_dict(self): "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, "print_eos_token": self.print_eos_token, "disable_prompt_cache": self.disable_prompt_cache, + "return_logprobs": self.return_logprobs, "seed": self.seed, } diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4323a62d1c..512b6ad2c2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -343,7 +343,11 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return - def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor): + def _async_copy_next_token_infos_to_pin_mem( + self, + next_token_ids: torch.Tensor, + next_token_logprobs: Optional[torch.Tensor], + ): """ 这个函数会把next token id和logprobs保存到pinned memory中 这样可以保障post_handle 函数可以读取到正常的输出结果。 @@ -352,9 +356,13 @@ def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, key="next_token_ids", gpu_tensor=next_token_ids, ) - next_token_logprobs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="next_token_logprobs", - gpu_tensor=next_token_logprobs, + next_token_logprobs_cpu = ( + None + if next_token_logprobs is None + else g_pin_mem_manager.async_copy_from_gpu_tensor( + key="next_token_logprobs", + gpu_tensor=next_token_logprobs, + ) ) return next_token_ids_cpu, next_token_logprobs_cpu @@ -700,7 +708,7 @@ def _post_handle( self, run_reqs: List[InferReq], next_token_ids: List[int], - next_token_logprobs: List[float], + next_token_logprobs: Optional[List[float]], run_reqs_update_packs: List[InferReqUpdatePack], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, @@ -709,9 +717,18 @@ def _post_handle( extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 约束输出等模式,设置自己请求内部的状态机的状态,并添加额外的停止判定条件等。 """ - for req_obj, next_token_id, next_token_logprob, pack in zip( - run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs - ): + if next_token_logprobs is None: + iter_items = zip(run_reqs, next_token_ids, run_reqs_update_packs) + else: + iter_items = zip(run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs) + + for item in iter_items: + if next_token_logprobs is None: + req_obj, next_token_id, pack = item + next_token_logprob = 0.0 + else: + req_obj, next_token_id, next_token_logprob, pack = item + req_obj: InferReq = req_obj pack: InferReqUpdatePack = pack pack.handle( @@ -800,10 +817,38 @@ def _sample_and_scatter_token( mask=b_has_out, ) next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( - next_token_ids, next_token_logprobs + next_token_ids, + next_token_logprobs, ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _can_decode_pre_post_before_prev_post_handle( + self, + run_reqs: List[InferReq], + extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + ) -> bool: + if not self.support_overlap: + return False + if extra_post_req_handle_func is not None or self.decode_mask_func is not None: + return False + if self.args.mtp_mode: + return False + + for req_obj in run_reqs: + if req_obj.mtp_step != 0: + return False + if req_obj.infer_aborted or req_obj.finish_status.is_finished(): + return False + + shm_param = req_obj.sampling_param.shm_param + if not shm_param.ignore_eos: + return False + if len(req_obj.stop_sequences) != 0: + return False + if req_obj.cur_output_len + 1 >= shm_param.max_new_tokens: + return False + return True + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 792a10a788..231a98f853 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -163,11 +163,22 @@ def decode_normal( sync_event.record() # 第二阶段 - event_pack.notify_post_handle_and_wait_pre_post_handle() - update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) + can_pre_post_early = self._can_decode_pre_post_before_prev_post_handle( + run_reqs=run_reqs, + extra_post_req_handle_func=self.extra_post_req_handle_func, + ) + if can_pre_post_early: + event_pack.notify_post_handle_event.set() + update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) + event_pack.notify_forward_event.set() + event_pack.wait_pre_post_handle_event.wait() + else: + event_pack.notify_post_handle_and_wait_pre_post_handle() + update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) # 第三阶段 - event_pack.notify_forward_and_wait_post_handle() + if not can_pre_post_early: + event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() self._post_handle( run_reqs=run_reqs, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index 5b29ea0510..14fd0c21c6 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,5 +1,5 @@ import torch -from typing import List, Tuple +from typing import List, Optional, Tuple, Union from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids @@ -7,8 +7,24 @@ from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.utils.envs_utils import get_env_start_args +_flashinfer_top_k_top_p_sampling_from_logits = None +_flashinfer_top_k_top_p_sampling_from_logits_checked = False +_flashinfer_top_k_top_p_sampling_from_probs = None +_flashinfer_top_k_top_p_sampling_from_probs_checked = False +_flashinfer_top_p_sampling_from_probs = None +_flashinfer_top_p_sampling_from_probs_checked = False +_flashinfer_top_k_sampling_from_probs = None +_flashinfer_top_k_sampling_from_probs_checked = False +_uniform_tensor_cache = {} +_softmax_out_cache = {} +_is_flashinfer_sampling_backend = None + def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): + fast_next_token_ids = _try_flashinfer_sample_without_penalty(logits, reqs) + if fast_next_token_ids is not None: + return fast_next_token_ids.view(-1), None + ( b_req_idx, b_temperatures, @@ -23,6 +39,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): skip_top_k, skip_top_p, exist_req_use_random_seed, + need_logprobs, ) = _get_post_sample_tensors(reqs) eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) @@ -75,7 +92,18 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): cu_invalid_token_num=cu_invalid_token_num, ) - logits.div_(b_temperatures.view((-1, 1))) + if b_temperatures is not None: + logits.div_(b_temperatures.view((-1, 1))) + + if is_all_greedy and not need_logprobs: + batch_next_token_ids = torch.argmax(logits, -1) + if get_env_start_args().mtp_mode: + batch_next_token_logprobs = torch.zeros( + batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device + ) + return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) + return batch_next_token_ids.view(-1), None + probs = torch.softmax(logits, dim=-1) if is_all_greedy: @@ -86,16 +114,82 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): elif skip_top_k and skip_top_p: # topk 等于整个词表,topp 等于1.0,等价于不进行topk topp过滤,直接进行随机采样,可以提升采样速度 batch_next_token_ids = _random_sample(probs, reqs, exist_req_use_random_seed) + if not need_logprobs: + return batch_next_token_ids.view(-1), None batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) else: batch_next_token_ids, batch_next_token_logprobs = _top_p_top_k_sample( - reqs, probs, b_top_ps, b_top_ks, exist_req_use_random_seed + reqs, + probs, + b_top_ps, + b_top_ks, + skip_top_k, + skip_top_p, + exist_req_use_random_seed, + need_logprobs, ) + if batch_next_token_logprobs is None: + return batch_next_token_ids.view(-1), None return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) +def _try_flashinfer_sample_without_penalty(logits: torch.Tensor, reqs: List[InferReq]) -> Optional[torch.Tensor]: + if not _is_flashinfer_sampling() or not reqs: + return None + + first_param = reqs[0].sampling_param.shm_param + top_p = first_param.top_p + top_k = first_param.top_k + temperature = first_param.temperature + vocab_size = reqs[0].vocab_size + + if top_k <= 1 or (top_k == vocab_size and top_p == 1.0): + return None + + for req in reqs: + shm_param = req.sampling_param.shm_param + if shm_param.return_logprobs: + return None + if req.generator is not None: + return None + if len(req.sampling_param.invalid_token_ids) != 0: + return None + if not shm_param.ignore_eos: + return None + if shm_param.presence_penalty != 0.0: + return None + if shm_param.frequency_penalty != 0.0: + return None + if shm_param.repetition_penalty != 1.0: + return None + if shm_param.temperature != temperature: + return None + if shm_param.top_p != top_p: + return None + if shm_param.top_k != top_k: + return None + + if temperature != 1.0: + logits.div_(temperature) + + if top_k == vocab_size and top_p != 1.0: + top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) + return _flashinfer_top_p_sample_from_logits(logits, top_p_tensor) + + top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) + top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device) + return _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor) + + +def _is_flashinfer_sampling() -> bool: + global _is_flashinfer_sampling_backend + if _is_flashinfer_sampling_backend is None: + _is_flashinfer_sampling_backend = get_env_start_args().sampling_backend == "flashinfer" + return _is_flashinfer_sampling_backend + + def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): probs_sort, probs_idx = probs.sort(dim=-1, descending=True) @@ -107,13 +201,123 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor return probs_sort, probs_idx +def _flashinfer_top_p_top_k_sample_from_logits( + logits: torch.Tensor, + b_top_ps: Union[torch.Tensor, float], + b_top_ks: Union[torch.Tensor, int], +) -> Optional[torch.Tensor]: + global _flashinfer_top_k_top_p_sampling_from_logits + global _flashinfer_top_k_top_p_sampling_from_logits_checked + + if not _flashinfer_top_k_top_p_sampling_from_logits_checked: + try: + from flashinfer.sampling import top_k_top_p_sampling_from_logits + except ImportError: + top_k_top_p_sampling_from_logits = None + _flashinfer_top_k_top_p_sampling_from_logits = top_k_top_p_sampling_from_logits + _flashinfer_top_k_top_p_sampling_from_logits_checked = True + + if _flashinfer_top_k_top_p_sampling_from_logits is None: + return None + + return _flashinfer_top_k_top_p_sampling_from_logits( + logits, + b_top_ks, + b_top_ps, + filter_apply_order="joint", + deterministic=True, + check_nan=False, + ) + + +def _flashinfer_top_p_sample_from_logits( + logits: torch.Tensor, top_p: Union[torch.Tensor, float] +) -> Optional[torch.Tensor]: + probs = _softmax_out(logits) + return _flashinfer_top_p_sample_from_probs(probs, top_p) + + +def _get_uniform_tensor(value: Union[float, int], size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + key = (str(device), dtype, size, value) + tensor = _uniform_tensor_cache.get(key) + if tensor is None: + tensor = torch.full((size,), value, dtype=dtype, device=device) + _uniform_tensor_cache[key] = tensor + return tensor + + +def _softmax_out(logits: torch.Tensor) -> torch.Tensor: + key = (str(logits.device), logits.dtype, tuple(logits.shape)) + probs = _softmax_out_cache.get(key) + if probs is None: + probs = torch.empty_like(logits) + _softmax_out_cache[key] = probs + torch.ops.aten._softmax.out(logits, -1, False, out=probs) + return probs + + +def _get_flashinfer_top_k_top_p_sampling_from_probs(): + global _flashinfer_top_k_top_p_sampling_from_probs + global _flashinfer_top_k_top_p_sampling_from_probs_checked + + if not _flashinfer_top_k_top_p_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_k_top_p_sampling_from_probs + except ImportError: + top_k_top_p_sampling_from_probs = None + _flashinfer_top_k_top_p_sampling_from_probs = top_k_top_p_sampling_from_probs + _flashinfer_top_k_top_p_sampling_from_probs_checked = True + return _flashinfer_top_k_top_p_sampling_from_probs + + +def _flashinfer_top_p_sample_from_probs( + probs: torch.Tensor, top_p: Union[torch.Tensor, float] +) -> Optional[torch.Tensor]: + global _flashinfer_top_p_sampling_from_probs + global _flashinfer_top_p_sampling_from_probs_checked + + if not _flashinfer_top_p_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_p_sampling_from_probs + except ImportError: + top_p_sampling_from_probs = None + _flashinfer_top_p_sampling_from_probs = top_p_sampling_from_probs + _flashinfer_top_p_sampling_from_probs_checked = True + + if _flashinfer_top_p_sampling_from_probs is None: + return None + + return _flashinfer_top_p_sampling_from_probs(probs, top_p, deterministic=True, check_nan=False) + + +def _flashinfer_top_k_sample_from_probs(probs: torch.Tensor, top_k: Union[torch.Tensor, int]) -> Optional[torch.Tensor]: + global _flashinfer_top_k_sampling_from_probs + global _flashinfer_top_k_sampling_from_probs_checked + + if not _flashinfer_top_k_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_k_sampling_from_probs + except ImportError: + top_k_sampling_from_probs = None + _flashinfer_top_k_sampling_from_probs = top_k_sampling_from_probs + _flashinfer_top_k_sampling_from_probs_checked = True + + if _flashinfer_top_k_sampling_from_probs is None: + return None + + return _flashinfer_top_k_sampling_from_probs(probs, top_k, deterministic=True, check_nan=False) + + def _top_p_top_k_sample( reqs: List[InferReq], probs: torch.Tensor, - b_top_ps: torch.Tensor, - b_top_ks: torch.Tensor, + b_top_ps: Union[torch.Tensor, float], + b_top_ks: Union[torch.Tensor, int], + skip_top_k: bool, + skip_top_p: bool, exist_req_use_random_seed: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: + need_logprobs: bool, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: sampling_backend = get_env_start_args().sampling_backend if sampling_backend == "triton": @@ -123,19 +327,32 @@ def _top_p_top_k_sample( else: sampled_index = _random_sample(probs_sort, reqs, exist_req_use_random_seed).view(-1, 1) next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index) + if not need_logprobs: + return next_token_ids.view(-1), None next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index)) return next_token_ids.view(-1), next_token_logprobs.view(-1) elif sampling_backend == "flashinfer": - from flashinfer.sampling import top_k_top_p_sampling_from_probs - - batch_next_token_ids = top_k_top_p_sampling_from_probs( - probs, - b_top_ks, - b_top_ps, - filter_apply_order="joint", - check_nan=False, - ) + if skip_top_k: + batch_next_token_ids = _flashinfer_top_p_sample_from_probs(probs, b_top_ps) + elif skip_top_p: + batch_next_token_ids = _flashinfer_top_k_sample_from_probs(probs, b_top_ks) + else: + top_k_top_p_sampling_from_probs = _get_flashinfer_top_k_top_p_sampling_from_probs() + if top_k_top_p_sampling_from_probs is None: + raise ImportError("flashinfer.sampling.top_k_top_p_sampling_from_probs is not available") + batch_next_token_ids = top_k_top_p_sampling_from_probs( + probs, + b_top_ks, + b_top_ps, + filter_apply_order="joint", + deterministic=True, + check_nan=False, + ) + if batch_next_token_ids is None: + raise ImportError("flashinfer sampling op is not available") + if not need_logprobs: + return batch_next_token_ids.view(-1), None int64_batch_next_token_ids = torch.empty_like(batch_next_token_ids, dtype=torch.int64) int64_batch_next_token_ids[:] = batch_next_token_ids batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1)) @@ -165,6 +382,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_k = True skip_top_p = True exist_req_use_random_seed = False + need_logprobs = False + all_temperature_one = True # invalid token ids invalid_token_ids: List[int] = [] @@ -192,6 +411,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_p = False if req_obj.generator is not None: exist_req_use_random_seed = True + if shm_param.return_logprobs: + need_logprobs = True + if shm_param.temperature != 1.0: + all_temperature_one = False req_idxes.append(req_obj.req_idx) invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) cu_invalid_token_num.append(invalid_token_num_start) @@ -200,13 +423,25 @@ def _get_post_sample_tensors(reqs: List[InferReq]): invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) - temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) - top_ps_cpu = g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32) - top_ks_cpu = g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32) length_penalty_param_cpu = g_pin_mem_manager.gen_from_list( key="length_penalty_param", data=length_penalty_param, dtype=torch.int32 ) mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool) + temperatures_cpu = ( + None + if all_temperature_one + else g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) + ) + sampling_backend = get_env_start_args().sampling_backend + need_top_k_top_p_tensors = (not is_all_greedy) and (not (skip_top_k and skip_top_p)) + need_top_ps_tensor = need_top_k_top_p_tensors and (sampling_backend != "flashinfer" or not skip_top_p) + need_top_ks_tensor = need_top_k_top_p_tensors and (sampling_backend != "flashinfer" or not skip_top_k) + top_ps_cpu = ( + g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32) if need_top_ps_tensor else None + ) + top_ks_cpu = ( + g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32) if need_top_ks_tensor else None + ) if has_invalid_token_ids: invalid_token_ids_cpu = g_pin_mem_manager.gen_from_list( @@ -218,9 +453,9 @@ def _get_post_sample_tensors(reqs: List[InferReq]): return ( req_idxes_cpu.cuda(non_blocking=True), - temperatures_cpu.cuda(non_blocking=True), - top_ps_cpu.cuda(non_blocking=True), - top_ks_cpu.cuda(non_blocking=True), + temperatures_cpu.cuda(non_blocking=True) if temperatures_cpu is not None else None, + top_ps_cpu.cuda(non_blocking=True) if top_ps_cpu is not None else None, + top_ks_cpu.cuda(non_blocking=True) if top_ks_cpu is not None else None, length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, @@ -230,4 +465,5 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_k, skip_top_p, exist_req_use_random_seed, + need_logprobs, ) diff --git a/lightllm/utils/sgl_utils.py b/lightllm/utils/sgl_utils.py index b48a62506d..b79a554f48 100644 --- a/lightllm/utils/sgl_utils.py +++ b/lightllm/utils/sgl_utils.py @@ -17,14 +17,16 @@ ) try: - from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, get_scheduler_metadata flash_attn_varlen_func = flash_attn_varlen_func flash_attn_with_kvcache = flash_attn_with_kvcache + get_scheduler_metadata = get_scheduler_metadata merge_state_v2 = sgl_ops.merge_state_v2 except: flash_attn_varlen_func = None flash_attn_with_kvcache = None + get_scheduler_metadata = None merge_state_v2 = None logger.warning( "sgl_kernel is not installed, or the installed version did not support fa3. \ From 17a3a7b2cb31ec2826d45f961986be0b7799f26f Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Jun 2026 09:30:07 +0800 Subject: [PATCH 02/19] feat(mtp): MTP verify-decode infrastructure Model-agnostic verify-decode machinery: MTP-verify dispatch in TpPartBaseModel, dedicated decode CUDA-graph capture/replay for the (mtp_step+1)-expanded verify layout, a shared mtp_verify_extra_state block on infer_struct/batch_objs, fa3 decode attention narrowed to the verify layout (b_att_seq_len + causal) for fp/fp8/mla, and env/kv-cache helpers for MTP added-layer accounting. --- lightllm/common/basemodel/attention/fa3/fp.py | 11 +- .../common/basemodel/attention/fa3/fp8.py | 29 +- .../common/basemodel/attention/fa3/mla.py | 11 +- lightllm/common/basemodel/basemodel.py | 261 +++++++++++++---- lightllm/common/basemodel/batch_objs.py | 11 + lightllm/common/basemodel/cuda_graph.py | 271 ++++++++++-------- lightllm/common/basemodel/infer_struct.py | 2 + .../basemodel/mtp_verify_extra_state.py | 47 +++ lightllm/utils/envs_utils.py | 18 +- lightllm/utils/kv_cache_utils.py | 8 +- 10 files changed, 477 insertions(+), 192 deletions(-) create mode 100644 lightllm/common/basemodel/mtp_verify_extra_state.py diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..a7395faebf 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -7,6 +7,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor +from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn class Fa3AttBackend(BaseAttBackend): @@ -125,8 +126,9 @@ class Fa3DecodeAttState(BaseDecodeAttState): def init_state(self): self.backend: Fa3AttBackend = self.backend args_mtp_step = get_env_start_args().mtp_step + is_mtp_verify_decode = is_mtp_verify_decode_fn(args_mtp_step, self.infer_state.b_num_accepted_tokens) - if args_mtp_step > 0: + if is_mtp_verify_decode: # 修正 mtp 在 fa3 下的输入。 mtp_size = args_mtp_step + 1 b_q_seq_len = torch.full( @@ -143,8 +145,9 @@ def init_state(self): self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) - assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1 + att_batch_size = self.infer_state.batch_size // mtp_size + assert self.infer_state.batch_size % mtp_size == 0 model = self.backend.model # 可以使用 cuda graph的时候从 buffer中申请 @@ -163,7 +166,7 @@ def init_state(self): device=self.infer_state.input_ids.device, ) - if args_mtp_step > 0: + if is_mtp_verify_decode: page_table_copy( page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], req_to_token_indexs=model.req_manager.req_to_token_indexs, diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index acbb1315fe..d85a1caf33 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -3,7 +3,6 @@ from ..base_att import AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops from typing import Union @@ -45,9 +44,12 @@ def init_state(self): torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len ) # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) def prefill_att( self, @@ -116,20 +118,19 @@ def init_state(self): super().init_state() self.backend: Fp8Fa3AttBackend = self.backend - args_mtp_step = get_env_start_args().mtp_step - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) - assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 - - device = self.infer_state.input_ids.device - batch_size = att_batch_size + batch_size = self.b_att_seq_len.shape[0] mem_manager = self.backend.model.mem_manager offline_scales: torch.Tensor = mem_manager.scales head_num = mem_manager.head_num # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) return @@ -180,11 +181,11 @@ def _fp8_decode_att( k_cache=cache_k, v_cache=cache_v, page_table=self.page_table, - cache_seqlens=self.infer_state.b_seq_len, + cache_seqlens=self.b_att_seq_len, cu_seqlens_q=self.cu_seqlens_q, cu_seqlens_k_new=self.cu_seqlens_k, max_seqlen_q=self.decode_max_q_seq_len, - causal=False, + causal=True, window_size=(-1, -1), softcap=0.0, q_descale=q_scale.view(self.infer_state.batch_size, k_head_num), diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index 9a10457b12..982bd117c3 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -8,6 +8,7 @@ from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor from lightllm.utils.sgl_utils import flash_attn_varlen_func +from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn class MlaFa3AttBackend(BaseAttBackend): @@ -108,8 +109,9 @@ class MlaFa3DecodeAttState(BaseDecodeAttState): def init_state(self): self.backend: MlaFa3AttBackend = self.backend args_mtp_step = get_env_start_args().mtp_step + is_mtp_verify_decode = is_mtp_verify_decode_fn(args_mtp_step, self.infer_state.b_num_accepted_tokens) - if args_mtp_step > 0: + if is_mtp_verify_decode: # 修正 mtp 在 fa3 下的输入。 mtp_size = args_mtp_step + 1 b_q_seq_len = torch.full( @@ -126,8 +128,9 @@ def init_state(self): self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) - assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1 + att_batch_size = self.infer_state.batch_size // mtp_size + assert self.infer_state.batch_size % mtp_size == 0 model = self.backend.model # 可以使用 cuda graph的时候从 buffer中申请 @@ -146,7 +149,7 @@ def init_state(self): device=self.infer_state.input_ids.device, ) - if args_mtp_step > 0: + if is_mtp_verify_decode: page_table_copy( page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], req_to_token_indexs=model.req_manager.req_to_token_indexs, diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a2..54e3be1512 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -4,6 +4,7 @@ import gc import copy import json +import math import torch import torch.nn.functional as F import triton @@ -17,20 +18,33 @@ from lightllm.common.req_manager import ReqManager from lightllm.common.infer_utils import init_req_to_token_indexes from lightllm.common.build_utils import repair_config -from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req +from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import ( + copy_kv_index_to_req, +) from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg -from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed +from lightllm.common.basemodel.triton_kernel.gather_token_id import ( + gather_token, + gather_token_prefill_decode_mixed, +) from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size -from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num +from lightllm.utils.envs_utils import ( + get_env_start_args, + get_llm_data_type, + get_added_mtp_kv_layer_num, +) from lightllm.distributed.communication_op import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn from lightllm.common.triton_utils.autotuner import AutotuneLevel from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch -from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel +from lightllm.utils.envs_utils import ( + set_model_init_status, + enable_diverse_mode_gqa_decode_fast_kernel, +) from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache from .attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -315,6 +329,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len infer_state.b_mtp_index = model_input.b_mtp_index + infer_state.b_num_accepted_tokens = model_input.b_num_accepted_tokens if model_input.is_prefill: if model_input.b_ready_cache_len is not None: infer_state.b_ready_cache_len = model_input.b_ready_cache_len @@ -352,6 +367,18 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) return infer_state + def _get_decode_padding_unit(self, model_input: ModelInput) -> int: + padding_unit = self.tp_world_size_ if self.args.enable_tpsp_mix_mode else 1 + if (not model_input.is_prefill) and is_mtp_verify_decode_fn( + self.args.mtp_step, model_input.b_num_accepted_tokens + ): + padding_unit = math.lcm(padding_unit, self.args.mtp_step + 1) + return padding_unit + + def _get_decode_infer_batch_size(self, model_input: ModelInput) -> int: + padding_unit = self._get_decode_padding_unit(model_input) + return triton.cdiv(model_input.batch_size, padding_unit) * padding_unit + def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): if model_input.batch_size == new_batch_size: return model_input @@ -361,22 +388,111 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s padded_batch_size = new_batch_size - model_input.batch_size new_model_input = copy.copy(model_input) new_model_input.batch_size = new_batch_size - new_model_input.total_token_num += padded_batch_size * 2 - new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) - new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) - new_model_input.b_req_idx = F.pad( - new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID - ) - new_model_input.b_mtp_index = F.pad( - new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 - ) - new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) - new_model_input.mem_indexes = F.pad( - new_model_input.mem_indexes, - (0, padded_batch_size), - mode="constant", - value=self.mem_manager.HOLD_TOKEN_MEMINDEX, + + is_mtp_verify_decode = (not model_input.is_prefill) and is_mtp_verify_decode_fn( + self.args.mtp_step, model_input.b_num_accepted_tokens ) + if is_mtp_verify_decode: + mtp_size = self.args.mtp_step + 1 + assert model_input.batch_size % mtp_size == 0 + assert new_batch_size % mtp_size == 0 + assert padded_batch_size % mtp_size == 0 + padded_req_num = padded_batch_size // mtp_size + + pad_mtp_index = torch.arange( + mtp_size, + dtype=new_model_input.b_mtp_index.dtype, + device=new_model_input.b_mtp_index.device, + ).repeat(padded_req_num) + pad_seq_len = torch.arange( + 2, + mtp_size + 2, + dtype=new_model_input.b_seq_len.dtype, + device=new_model_input.b_seq_len.device, + ).repeat(padded_req_num) + new_model_input.total_token_num += padded_req_num * (mtp_size * (mtp_size + 3) // 2) + new_model_input.max_kv_seq_len = max(mtp_size + 1, model_input.max_kv_seq_len) + new_model_input.input_ids = torch.cat( + ( + new_model_input.input_ids, + torch.ones( + padded_batch_size, + dtype=new_model_input.input_ids.dtype, + device=new_model_input.input_ids.device, + ), + ), + dim=0, + ) + new_model_input.b_req_idx = torch.cat( + ( + new_model_input.b_req_idx, + torch.full( + (padded_batch_size,), + self.req_manager.HOLD_REQUEST_ID, + dtype=new_model_input.b_req_idx.dtype, + device=new_model_input.b_req_idx.device, + ), + ), + dim=0, + ) + new_model_input.b_mtp_index = torch.cat((new_model_input.b_mtp_index, pad_mtp_index), dim=0) + new_model_input.b_seq_len = torch.cat((new_model_input.b_seq_len, pad_seq_len), dim=0) + new_model_input.mem_indexes = torch.cat( + ( + new_model_input.mem_indexes, + torch.full( + (padded_batch_size,), + self.mem_manager.HOLD_TOKEN_MEMINDEX, + dtype=new_model_input.mem_indexes.dtype, + device=new_model_input.mem_indexes.device, + ), + ), + dim=0, + ) + new_model_input.b_num_accepted_tokens = torch.cat( + ( + new_model_input.b_num_accepted_tokens, + torch.ones( + padded_req_num, + dtype=new_model_input.b_num_accepted_tokens.dtype, + device=new_model_input.b_num_accepted_tokens.device, + ), + ), + dim=0, + ) + else: + new_model_input.total_token_num += padded_batch_size * 2 + new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) + new_model_input.input_ids = F.pad( + new_model_input.input_ids, + (0, padded_batch_size), + mode="constant", + value=1, + ) + new_model_input.b_req_idx = F.pad( + new_model_input.b_req_idx, + (0, padded_batch_size), + mode="constant", + value=self.req_manager.HOLD_REQUEST_ID, + ) + new_model_input.b_mtp_index = F.pad( + new_model_input.b_mtp_index, + (0, padded_batch_size), + mode="constant", + value=0, + ) + new_model_input.b_seq_len = F.pad( + new_model_input.b_seq_len, + (0, padded_batch_size), + mode="constant", + value=2, + ) + new_model_input.mem_indexes = F.pad( + new_model_input.mem_indexes, + (0, padded_batch_size), + mode="constant", + value=self.mem_manager.HOLD_TOKEN_MEMINDEX, + ) new_model_input.multimodal_params = new_model_input.multimodal_params + [ {"images": [], "audios": []} for _ in range(padded_batch_size) ] @@ -384,11 +500,17 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s if enable_diverse_mode_gqa_decode_fast_kernel(): if new_model_input.b_shared_seq_len is not None: new_model_input.b_shared_seq_len = F.pad( - new_model_input.b_shared_seq_len, (0, padded_batch_size), mode="constant", value=0 + new_model_input.b_shared_seq_len, + (0, padded_batch_size), + mode="constant", + value=0, ) if new_model_input.b_mark_shared_group is not None: new_model_input.b_mark_shared_group = F.pad( - new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=1 + new_model_input.b_mark_shared_group, + (0, padded_batch_size), + mode="constant", + value=1, ) # 特殊模型,特殊模式的特殊变量的特殊 padding @@ -423,7 +545,10 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle value=self.mem_manager.HOLD_TOKEN_MEMINDEX, ) new_model_input.b_req_idx = F.pad( - new_model_input.b_req_idx, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID + new_model_input.b_req_idx, + (0, 1), + mode="constant", + value=self.req_manager.HOLD_REQUEST_ID, ) new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0) new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num) @@ -463,7 +588,10 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba return new_model_output def _create_unpad_prefill_model_output( - self, padded_model_output: ModelOutput, origin_handle_token_num: int, origin_batch_size: int + self, + padded_model_output: ModelOutput, + origin_handle_token_num: int, + origin_batch_size: int, ): if self.return_all_prompt_logics: new_model_output = copy.copy(padded_model_output) @@ -549,15 +677,16 @@ def _decode( ) origin_batch_size = model_input.batch_size - if self.args.enable_tpsp_mix_mode: - infer_batch_size = triton.cdiv(model_input.batch_size, self.tp_world_size_) * self.tp_world_size_ - else: - infer_batch_size = model_input.batch_size + infer_batch_size = self._get_decode_infer_batch_size(model_input) + is_mtp_verify_decode = is_mtp_verify_decode_fn(self.args.mtp_step, model_input.b_num_accepted_tokens) if self.graph is not None and self.graph.can_run( batch_size=infer_batch_size, max_len_in_batch=model_input.max_kv_seq_len ): - infer_batch_size = self.graph.find_closest_graph_batch_size(batch_size=infer_batch_size) + infer_batch_size = self.graph.find_closest_graph_batch_size( + batch_size=infer_batch_size, + is_mtp_verify_decode=is_mtp_verify_decode, + ) model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) @@ -571,7 +700,7 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(infer_batch_size): + if self.graph.need_capture(infer_batch_size, is_mtp_verify_decode=is_mtp_verify_decode): infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: @@ -598,7 +727,6 @@ def _decode( @final def _context_forward(self, infer_state: InferStateInfo): - input_embs = self.pre_infer.context_forward(infer_state.input_ids, infer_state, self.pre_post_weight) if self.args.enable_dp_prefill_balance: assert not self.args.enable_prefill_cudagraph, "not support now" @@ -804,10 +932,14 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode origin_batch_size = model_input0.batch_size max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) - infer_batch_size = triton.cdiv(origin_batch_size, self.tp_world_size_) * self.tp_world_size_ + infer_batch_size = self._get_decode_infer_batch_size(model_input0) + is_mtp_verify_decode = is_mtp_verify_decode_fn(self.args.mtp_step, model_input0.b_num_accepted_tokens) if self.graph is not None and self.graph.can_run(infer_batch_size, max_len_in_batch): - infer_batch_size = self.graph.find_closest_graph_batch_size(infer_batch_size) + infer_batch_size = self.graph.find_closest_graph_batch_size( + infer_batch_size, + is_mtp_verify_decode=is_mtp_verify_decode, + ) # TODO 如果支持动态步数的 mtp,在不同的mtp步上,model_input0 和 model_input1 的内部batch size可能不 # 一致,需要按照较高 batch size 进行graph的寻找,同时,进行有效的恢复。 padded_model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) @@ -832,7 +964,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.init_some_extra_state(self) infer_state1.init_att_state() - if self.graph.need_capture(infer_batch_size): + if self.graph.need_capture(infer_batch_size, is_mtp_verify_decode=is_mtp_verify_decode): infer_state0.is_cuda_graph = True infer_state1.is_cuda_graph = True @@ -884,7 +1016,11 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state g_cache_manager.cache_env_in() input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( - infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight + infer_state.input_ids, + infer_state1.input_ids, + infer_state, + infer_state1, + self.pre_post_weight, ) # 决定是否进行 dp balance 优化,可以提升dp > 1 时的 prefill 效率。 @@ -900,7 +1036,11 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( - input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] + input_embs, + input_embs1, + infer_state, + infer_state1, + self.trans_layers_weight[i], ) # 折叠模式调用完infer_state 和 infer_state1 上的hook函数后,input_embs 和 input_embs1 才具备正确的运算数据。 @@ -914,7 +1054,11 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state last_input_embs1 = infer_state1._all_to_all_unbalance_get(data=last_input_embs1) predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( - last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight + last_input_embs, + last_input_embs1, + infer_state, + infer_state1, + self.pre_post_weight, ) g_cache_manager.cache_env_out() @@ -935,14 +1079,22 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state @final def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: InferStateInfo): input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward( - infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight + infer_state.input_ids, + infer_state1.input_ids, + infer_state, + infer_state1, + self.pre_post_weight, ) input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) input_embs1 = self.pre_infer._tpsp_sp_split(input=input_embs1, infer_state=infer_state1) for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward( - input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] + input_embs, + input_embs1, + infer_state, + infer_state1, + self.trans_layers_weight[i], ) # 折叠模式调用完infer_state 上的hook函数后,input_embs 和 input_embs 才具备正确的运算数据。 @@ -953,7 +1105,11 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: last_input_embs1 = self.post_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( - last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight + last_input_embs, + last_input_embs1, + infer_state, + infer_state1, + self.pre_post_weight, ) model_output = ModelOutput(logits=predict_logits.contiguous()) @@ -1060,7 +1216,12 @@ def _autotune_warmup(self): rand_gen = torch.Generator(device="cuda") rand_gen.manual_seed(input_len) dummy_input_ids = torch.randint( - 0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen + 0, + 10000, + (input_len,), + dtype=torch.int32, + device="cuda", + generator=rand_gen, ) b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() @@ -1124,10 +1285,14 @@ def _init_padded_req(self): batch_size = 1 dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") b_req_idx = torch.tensor( - [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], + dtype=torch.int32, + device="cuda", ) mem_indexes = torch.tensor( - [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], + dtype=torch.int32, + device="cuda", ) b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") @@ -1171,15 +1336,13 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_mtp_draft_model = ( - "Deepseek3MTPModel" in str(self.__class__) - or "Qwen3MOEMTPModel" in str(self.__class__) - or "MistralMTPModel" in str(self.__class__) - or "Glm4MoeLiteMTPModel" in str(self.__class__) - ) + is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( - token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" + token_num, + self.config["hidden_size"], + dtype=self.data_type, + device="cuda", ) else: special_model_input["mtp_draft_input_hiddens"] = None diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 1795ff9a82..6104022733 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -6,6 +6,13 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor +def is_mtp_verify_decode(mtp_step: int, b_num_accepted_tokens) -> bool: + """Single source of truth for the MTP verify-decode predicate (#21). + A decode forward is a verify pass iff MTP is enabled and the per-real-request accept tensor is + present — decode_mtp sets it on the main verify and clears it (None) on every draft forward.""" + return mtp_step > 0 and b_num_accepted_tokens is not None + + @dataclass class ModelInput: # 通用变量 @@ -53,6 +60,8 @@ class ModelInput: # 的 draft 模型的输入 mtp_draft_input_hiddens: Optional[torch.Tensor] = None + b_num_accepted_tokens: Optional[torch.Tensor] = None + def to_cuda(self): if self.input_ids is not None: self.input_ids = self.input_ids.cuda(non_blocking=True) @@ -66,6 +75,8 @@ def to_cuda(self): self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) + if self.b_num_accepted_tokens is not None: + self.b_num_accepted_tokens = self.b_num_accepted_tokens.cuda(non_blocking=True) if self.b_ready_cache_len is not None: self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True) if self.b_prefill_start_loc is not None: diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..d20de7afb8 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -2,12 +2,14 @@ import torch import copy import bisect +import math import triton from typing import Optional from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn from .infer_struct import InferStateInfo @@ -27,48 +29,135 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.normal_cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes(batch_size_multiple=1) + if self.mtp_step > 0: + self.mtp_verify_cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes( + batch_size_multiple=self.mtp_step + 1 + ) + logger.info(f"normal cuda graph batch_sizes: {self.normal_cuda_graph_batch_sizes}") + logger.info(f"mtp verify cuda graph batch_sizes: {self.mtp_verify_cuda_graph_batch_sizes}") + else: + self.mtp_verify_cuda_graph_batch_sizes = self.normal_cuda_graph_batch_sizes + logger.info(f"cuda graph batch_sizes: {self.normal_cuda_graph_batch_sizes}") + + def _build_cuda_graph_batch_sizes(self, batch_size_multiple: int): # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] - # and [graph_split_batch_size + graph_grow_step_size, - # if the mtp_step is not 0, then the batch_sizes will be multiply of (mtp_step + 1) - - graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1) - graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1) - - batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)] - for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): + # and [graph_split_batch_size + graph_grow_step_size, ...] + graph_split_batch_size = self.args.graph_split_batch_size * batch_size_multiple + graph_grow_step_size = self.args.graph_grow_step_size * batch_size_multiple + + batch_sizes = [i * batch_size_multiple for i in range(1, self.args.graph_split_batch_size + 1)] + for _batch_size in range( + graph_split_batch_size + graph_grow_step_size, + self.max_batch_size, + graph_grow_step_size, + ): batch_sizes.append(_batch_size) - batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size])) - batch_sizes.append(max_batch_size) + batch_sizes = list(set([e for e in batch_sizes if e < self.max_batch_size])) + batch_sizes.append(self.max_batch_size) batch_sizes.sort() if self.args.enable_tpsp_mix_mode: - batch_sizes = [triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in batch_sizes] + padding_unit = math.lcm(self.tp_world_size, batch_size_multiple) + batch_sizes = [triton.cdiv(e, padding_unit) * padding_unit for e in batch_sizes] batch_sizes = list(set(batch_sizes)) batch_sizes.sort() - self.cuda_graph_batch_sizes = batch_sizes assert batch_sizes[-1] == self.max_batch_size - logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") + return batch_sizes def can_run(self, batch_size, max_len_in_batch): return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch - def need_capture(self, batch_size): - find_batch_size = self.find_closest_graph_batch_size(batch_size) + def _decode_graph_key(self, infer_state: InferStateInfo): + is_mtp_verify_decode = is_mtp_verify_decode_fn(self.mtp_step, infer_state.b_num_accepted_tokens) + return (infer_state.input_ids.shape[0], is_mtp_verify_decode) + + def need_capture(self, batch_size, is_mtp_verify_decode=False): + find_batch_size = self.find_closest_graph_batch_size(batch_size, is_mtp_verify_decode=is_mtp_verify_decode) if find_batch_size is not None: - return find_batch_size not in self.graph + return (find_batch_size, is_mtp_verify_decode) not in self.graph else: assert False, "dead code" - def find_closest_graph_batch_size(self, batch_size): - index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size) - if index < len(self.cuda_graph_batch_sizes): - find_batch_size = self.cuda_graph_batch_sizes[index] + def _get_graph_batch_sizes(self, is_mtp_verify_decode=False): + if is_mtp_verify_decode: + return self.mtp_verify_cuda_graph_batch_sizes + return self.normal_cuda_graph_batch_sizes + + def find_closest_graph_batch_size(self, batch_size, is_mtp_verify_decode=False): + graph_batch_sizes = self._get_graph_batch_sizes(is_mtp_verify_decode=is_mtp_verify_decode) + index = bisect.bisect_left(graph_batch_sizes, batch_size) + if index < len(graph_batch_sizes): + find_batch_size = graph_batch_sizes[index] return find_batch_size else: return None + def _build_warmup_decode_model_input( + self, + model, + batch_size: int, + device: str = "cuda", + is_mtp_verify_decode: Optional[bool] = None, + ) -> ModelInput: + if is_mtp_verify_decode is None: + is_mtp_verify_decode = self.mtp_step > 0 + + mtp_size = self.mtp_step + 1 + input_ids = torch.ones(batch_size, dtype=torch.int32, device=device) + mem_indexes = model.mem_manager.alloc(batch_size).to(device) + b_req_idx = torch.full( + (batch_size,), + fill_value=model.req_manager.HOLD_REQUEST_ID, + dtype=torch.int32, + device=device, + ) + + b_num_accepted_tokens = None + if self.mtp_step > 0 and is_mtp_verify_decode: + assert batch_size % mtp_size == 0, "MTP decode CUDA graph batch size must be a multiple of mtp_step + 1" + real_batch_size = batch_size // mtp_size + b_mtp_index = torch.arange(mtp_size, dtype=torch.int32, device=device).repeat(real_batch_size) + b_seq_len = torch.arange(2, mtp_size + 2, dtype=torch.int32, device=device).repeat(real_batch_size) + b_num_accepted_tokens = torch.ones(real_batch_size, dtype=torch.int32, device=device) + total_token_num = real_batch_size * (mtp_size * (mtp_size + 3) // 2) + else: + seq_len = 2 + total_token_num = batch_size * seq_len + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device=device) + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device=device) + b_seq_len.fill_(seq_len) + + return ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_q_seq_len=1, + max_kv_seq_len=self.graph_max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_mtp_index=b_mtp_index, + b_num_accepted_tokens=b_num_accepted_tokens, + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], + **model._gen_special_model_input(batch_size), + ) + + def _is_mtp_draft_model(self, model): + return getattr(model, "is_mtp_draft_model", False) + + def _iter_warmup_graph_layouts(self, model): + if self.mtp_step > 0: + if self._is_mtp_draft_model(model): + yield False, self.normal_cuda_graph_batch_sizes + else: + yield True, self.mtp_verify_cuda_graph_batch_sizes + else: + yield False, self.normal_cuda_graph_batch_sizes + def _capture_decode(self, decode_func, infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids @@ -96,7 +185,11 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): with torch.cuda.graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) - self.graph[batch_size] = (graph_obj, infer_state, model_output) + self.graph[self._decode_graph_key(infer_state)] = ( + graph_obj, + infer_state, + model_output, + ) graph_obj.replay() return model_output @@ -130,7 +223,7 @@ def _capture_decode_overlap( with torch.cuda.graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) - self.graph[batch_size] = ( + self.graph[self._decode_graph_key(infer_state)] = ( graph_obj, infer_state, infer_state1, @@ -157,8 +250,7 @@ def capture_decode( return self._capture_decode(decode_func, infer_state) def _replay(self, infer_state: InferStateInfo): - batch_size = infer_state.input_ids.shape[0] - graph_obj, graph_infer_state, graph_output = self.graph[batch_size] + graph_obj, graph_infer_state, graph_output = self.graph[self._decode_graph_key(infer_state)] graph_infer_state.copy_for_cuda_graph(infer_state) graph_obj.replay() return graph_output @@ -168,14 +260,13 @@ def _replay_overlap( infer_state: InferStateInfo, infer_state1: InferStateInfo, ): - batch_size = infer_state.input_ids.shape[0] ( graph_obj, graph_infer_state, graph_infer_state1, graph_model_output, graph_model_output1, - ) = self.graph[batch_size] + ) = self.graph[self._decode_graph_key(infer_state)] graph_infer_state.copy_for_cuda_graph(infer_state) graph_infer_state1.copy_for_cuda_graph(infer_state1) graph_obj.replay() @@ -197,47 +288,23 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: - seq_len = 2 - total_token_num = batch_size * seq_len - max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") - b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - - model_input = ModelInput( - batch_size=batch_size, - total_token_num=total_token_num, - max_q_seq_len=1, - max_kv_seq_len=max_len_in_batch, - input_ids=input_ids, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - b_mtp_index=b_mtp_index, - is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], - **model._gen_special_model_input(batch_size), - ) - model_output: ModelOutput = model.forward(model_input) - del model_output - del input_ids - del mem_indexes - del b_req_idx - del b_seq_len - - model.mem_manager.free_all() - model.req_manager.free_all() - # release local tensors - for var_name, var_value in list(locals().items()): - if isinstance(var_value, torch.Tensor): - del locals()[var_name] - torch.cuda.empty_cache() + for is_mtp_verify_decode, batch_sizes in self._iter_warmup_graph_layouts(model): + for batch_size in batch_sizes[::-1]: + model_input = self._build_warmup_decode_model_input( + model, + batch_size, + is_mtp_verify_decode=is_mtp_verify_decode, + ) + model_output: ModelOutput = model.forward(model_input) + del model_output + + model.mem_manager.free_all() + model.req_manager.free_all() + # release local tensors + for var_name, var_value in list(locals().items()): + if isinstance(var_value, torch.Tensor): + del locals()[var_name] + torch.cuda.empty_cache() logger.info( f"Capture cudagraph success, batch_size <={self.max_batch_size} " @@ -252,56 +319,36 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: - decode_batches = [] - for micro_batch_index in [0, 1]: - # dummy decoding, capture the cudagraph - seq_len = 2 - total_token_num = batch_size * seq_len - max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") - b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - - micro_batch = ModelInput( - is_prefill=False, - batch_size=batch_size, - total_token_num=total_token_num, - max_q_seq_len=1, - max_kv_seq_len=max_len_in_batch, - input_ids=input_ids, - b_mtp_index=b_mtp_index, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], - **model._gen_special_model_input(batch_size), - ) - decode_batches.append(micro_batch) - del micro_batch + for is_mtp_verify_decode, batch_sizes in self._iter_warmup_graph_layouts(model): + for batch_size in batch_sizes[::-1]: + decode_batches = [] + for micro_batch_index in [0, 1]: + # dummy decoding, capture the cudagraph + micro_batch = self._build_warmup_decode_model_input( + model, + batch_size, + is_mtp_verify_decode=is_mtp_verify_decode, + ) + decode_batches.append(micro_batch) + del micro_batch - for var_name, var_value in list(locals().items()): - if isinstance(var_value, torch.Tensor): - del locals()[var_name] - torch.cuda.empty_cache() + for var_name, var_value in list(locals().items()): + if isinstance(var_value, torch.Tensor): + del locals()[var_name] + torch.cuda.empty_cache() - _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) + _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) - model.mem_manager.free_all() - model.req_manager.free_all() + model.mem_manager.free_all() + model.req_manager.free_all() - del decode_batches + del decode_batches - # release local tensors - for var_name, var_value in list(locals().items()): - if isinstance(var_value, torch.Tensor): - del locals()[var_name] - torch.cuda.empty_cache() + # release local tensors + for var_name, var_value in list(locals().items()): + if isinstance(var_value, torch.Tensor): + del locals()[var_name] + torch.cuda.empty_cache() logger.info( f"Capture overlap cudagraph success, batch_size <={self.max_batch_size} " diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c835..6de15f8910 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -39,6 +39,8 @@ def __init__(self): self.b_mtp_index: torch.Tensor = None + self.b_num_accepted_tokens: torch.Tensor = None + self.b_seq_len: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None diff --git a/lightllm/common/basemodel/mtp_verify_extra_state.py b/lightllm/common/basemodel/mtp_verify_extra_state.py new file mode 100644 index 0000000000..e4bc5f4f7a --- /dev/null +++ b/lightllm/common/basemodel/mtp_verify_extra_state.py @@ -0,0 +1,47 @@ +import torch + +from lightllm.utils.envs_utils import get_env_start_args + + +def init_mtp_verify_extra_state(self): + """Shared MTP-verify decode metadata, used by qwen3_5 and qwen3next infer-struct classes (#12). + Call AFTER super().init_some_extra_state(model). `self` is the InferStateInfo instance.""" + self.b_att_seq_len = self.b_seq_len + mtp_step = get_env_start_args().mtp_step + self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + # conv buffer is now ONE widened slot per request (indexed by req_idx), + # dropping the *(S+1) + mtp_index addressing used by the SSM block. + self.b_conv_buffer_idx = self.b_req_idx + # MTP verify batch: decode-mode, S+1 expanded, and gated on the + # per-real-request accept tensor that decode_mtp threads in. Gating on + # b_num_accepted_tokens (vs only b_mtp_index, which is set for any decode) + # distinguishes the main-model verify forward from draft/plain decode. + self.is_mtp_verify = ( + (mtp_step > 0) + and (not self.is_prefill) + and (self.b_mtp_index is not None) + and (self.b_num_accepted_tokens is not None) + ) + self.b_gdn_verify_cu_seqlens = None + self.b_ssm_index_rows = None + # b_num_accepted_tokens is threaded onto the infer_state from ModelInput by + # _create_inferstate (mirrors b_mtp_index) BEFORE this runs; nothing to do here. + if self.is_mtp_verify: + step = mtp_step + 1 + n_real = self.b_req_idx.shape[0] // step + self.b_gdn_verify_cu_seqlens = torch.arange( + 0, (n_real + 1) * step, step, dtype=torch.int32, device=self.b_req_idx.device + ) + req_first = self.b_req_idx.view(n_real, step)[:, 0] + base = (req_first * step).view(n_real, 1) + self.b_ssm_index_rows = base + torch.arange(step, device=base.device, dtype=base.dtype).view(1, step) + assert self.b_ssm_index_rows.shape == (n_real, step) + # The spec conv kernel is per-SEQUENCE (one program per real request), + # indexed by conv_state_indices[idx_seq] with idx_seq in [0, n_real), + # aligned 1:1 with b_gdn_verify_cu_seqlens / b_num_accepted_tokens. The + # default b_conv_buffer_idx = b_req_idx has the expanded length n_real*step, + # which launches n_real*step conv programs and reads num_accepted/ + # query_start_loc out of bounds for idx_seq >= n_real, corrupting the + # committed conv slot. Narrow it to one widened conv slot per request. + self.b_conv_buffer_idx = req_first + return diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 773320273c..f6508994b4 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -226,16 +226,20 @@ def enable_huge_page(): return enable_env_vars("LIGHTLLM_HUGE_PAGE_ENABLE") +def _mtp_added_layer_num(mtp_mode, mtp_step: int) -> int: + # Single source of truth for the mtp_mode -> added KV/full-att layer count (#9). + if mtp_mode == "eagle_with_att": + return 1 + if mtp_mode == "vanilla_with_att": + return mtp_step + return 0 + + @lru_cache(maxsize=None) def get_added_mtp_kv_layer_num() -> int: # mtp 模式下需要在mem manger上扩展draft model使用的layer - added_mtp_layer_num = 0 - if get_env_start_args().mtp_mode == "eagle_with_att": - added_mtp_layer_num += 1 - elif get_env_start_args().mtp_mode == "vanilla_with_att": - added_mtp_layer_num += get_env_start_args().mtp_step - - return added_mtp_layer_num + args = get_env_start_args() + return _mtp_added_layer_num(args.mtp_mode, args.mtp_step) @lru_cache(maxsize=None) diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 494908cb10..2a089a9bf2 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -120,8 +120,12 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 - assert is_linear_att_mixed_model(args.model_dir) is False, "linear att mixed model does not support mtp mode" - cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() + if is_linear_att_mixed_model(args.model_dir): + # Linear mixed models use one packed byte page; MTP draft full-attn + # slots are accounted in LinearAttCacheConfig.get_cpu_cache_big_page_bytes(). + pass + else: + cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() cpu_cache_page_num = int( (args.cpu_cache_storage_size * 1024 * 1024 * 1024) / (cpu_cache_meta.calcu_one_page_size()) From 32d901b2fc9d73590c559f2424a848f75ba51951 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Jun 2026 09:30:07 +0800 Subject: [PATCH 03/19] feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP draft models Self-contained dense (qwen3_5_mtp) and MoE (qwen3_5_moe_mtp) MTP draft packages: each carries its own draft wiring (reuse the main model's req/mem managers + rope caches, is_mtp_draft_model marker) and shares a weight-retarget mixin (mtp.* head, embeddings shared with the main model) plus the MTP pre-layer fuse. No shared model base class. --- lightllm/models/qwen3_5/infer_struct.py | 9 +- lightllm/models/qwen3_5_moe_mtp/__init__.py | 3 + .../qwen3_5_moe_mtp/layer_weights/__init__.py | 5 + .../layer_weights/transformer_layer_weight.py | 96 +++++++++++++++ lightllm/models/qwen3_5_moe_mtp/model.py | 8 ++ lightllm/models/qwen3_5_mtp/__init__.py | 0 .../qwen3_5_mtp/layer_infer/__init__.py | 0 .../layer_infer/pre_layer_infer.py | 40 +++++++ .../qwen3_5_mtp/layer_weights/__init__.py | 0 .../layer_weights/mtp_retarget_mixin.py | 61 ++++++++++ .../pre_and_post_layer_weight.py | 45 ++++++++ .../layer_weights/transformer_layer_weight.py | 23 ++++ lightllm/models/qwen3_5_mtp/model.py | 109 ++++++++++++++++++ 13 files changed, 392 insertions(+), 7 deletions(-) create mode 100644 lightllm/models/qwen3_5_moe_mtp/__init__.py create mode 100644 lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3_5_moe_mtp/model.py create mode 100644 lightllm/models/qwen3_5_mtp/__init__.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_infer/__init__.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_weights/mtp_retarget_mixin.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3_5_mtp/model.py diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py index d23475c1cf..2687a4aca7 100644 --- a/lightllm/models/qwen3_5/infer_struct.py +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -1,8 +1,4 @@ -import torch -from typing import List - from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args class Qwen35InferStateInfo(Qwen2VLInferStateInfo): @@ -12,8 +8,7 @@ def __init__(self): def init_some_extra_state(self, model): super().init_some_extra_state(model) - self.b_att_seq_len = self.b_seq_len - mtp_step = get_env_start_args().mtp_step + from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state - self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + init_mtp_verify_extra_state(self) return diff --git a/lightllm/models/qwen3_5_moe_mtp/__init__.py b/lightllm/models/qwen3_5_moe_mtp/__init__.py new file mode 100644 index 0000000000..c8885f8869 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + +__all__ = ["Qwen3_5MoeMTPModel"] diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py new file mode 100644 index 0000000000..dcad1087d4 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py @@ -0,0 +1,5 @@ +from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import ( + Qwen3_5MoeMTPTransformerLayerWeight, +) + +__all__ = ["Qwen3_5MoeMTPTransformerLayerWeight"] diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..554db359f6 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,96 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ( + COLMMWeight, + FusedMoeWeight, + ROWMMWeight, +) +from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import ( + Qwen35MOETransformerLayerWeight, +) +from lightllm.models.qwen3_5_mtp.layer_weights.mtp_retarget_mixin import MTPRetargetMixin +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3_5MoeMTPTransformerLayerWeight(MTPRetargetMixin, Qwen35MOETransformerLayerWeight): + def _init_weight_names(self): + super()._init_weight_names() + self._retarget_attn_norm_names() + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) + self._init_gated_ffn() + + def _init_gated_ffn(self): + hidden_size = self.network_config_["hidden_size"] + if "shared_expert_intermediate_size" not in self.network_config_: + return + + prefix = f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert" + inter_size = self.network_config_["shared_expert_intermediate_size"] + if get_env_start_args().enable_ep_moe: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + tp_rank=0, + tp_world_size=1, + ) + else: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + ) + + self.ffn_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/qwen3_5_moe_mtp/model.py b/lightllm/models/qwen3_5_moe_mtp/model.py new file mode 100644 index 0000000000..022864f6b3 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/model.py @@ -0,0 +1,8 @@ +from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel +from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import ( + Qwen3_5MoeMTPTransformerLayerWeight, +) + + +class Qwen3_5MoeMTPModel(Qwen3_5MTPModel): + transformer_weight_class = Qwen3_5MoeMTPTransformerLayerWeight diff --git a/lightllm/models/qwen3_5_mtp/__init__.py b/lightllm/models/qwen3_5_mtp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_5_mtp/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..906a0ab62c --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,40 @@ +import torch + +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + + +class Qwen3_5MTPPreLayerInfer(Qwen3VLMultimodalPreLayerInfer): + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_fuse( + self, + input_embdings: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3_5MTPPreAndPostLayerWeight, + ) -> torch.Tensor: + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == tgt_embdings.shape[0] + ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" + + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + return layer_weight.eh_proj_weight_.mm(cat_embdings) + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_5_mtp/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/mtp_retarget_mixin.py b/lightllm/models/qwen3_5_mtp/layer_weights/mtp_retarget_mixin.py new file mode 100644 index 0000000000..cf9da94887 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/mtp_retarget_mixin.py @@ -0,0 +1,61 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, QKVROWNMMWeight + + +class MTPRetargetMixin: + """Shared MTP weight-name retargeting (model.layers.* -> mtp.layers.*) and qkv/o_gate wiring, + used by both the dense and MoE Qwen3.5 MTP layer-weight classes (#11). The dense subclass adds + its dense-MLP retargets on top; the MoE subclass must not (it uses fused experts).""" + + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + _ATTN_NORM_NAME_ATTRS = ( + "_q_weight_name", + "_q_norm_name", + "_q_bias_name", + "_k_weight_name", + "_k_norm_name", + "_k_bias_name", + "_v_weight_name", + "_v_bias_name", + "_kv_weight_name", + "_kv_bias_name", + "_o_weight_name", + "_o_bias_name", + "_att_norm_weight_name", + "_att_norm_bias_name", + "_ffn_norm_weight_name", + "_ffn_norm_bias_name", + ) + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _retarget_attn_norm_names(self): + for attr in self._ATTN_NORM_NAME_ATTRS: + setattr(self, attr, self._retarget(getattr(self, attr))) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..25c56a0d7e --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,45 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + NoTpGEMMANormWeight, + ROWMMWeight, +) +from lightllm.common.quantization import Quantcfg + + +class Qwen3_5MTPPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): + super().__init__(data_type, network_config) + self.quant_cfg: Quantcfg = quant_cfg + hidden_size = network_config["hidden_size"] + + self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], + weight_names="mtp.fc.weight", + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_embedding.weight", + data_type=self.data_type_, + ) + self.hnorm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_hidden.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.norm.weight", + data_type=self.data_type_, + ) + + # Shared with the main Qwen3.5 model, injected by the model class (not loaded here). + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None + return diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..5aa0724580 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,23 @@ +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35TransformerLayerWeight, +) +from lightllm.models.qwen3_5_mtp.layer_weights.mtp_retarget_mixin import MTPRetargetMixin +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3_5MTPTransformerLayerWeight(MTPRetargetMixin, Qwen35TransformerLayerWeight): + def _init_weight_names(self): + super()._init_weight_names() + # Retarget all main-model layer key names to the mtp.* namespace. + self._retarget_attn_norm_names() + # MLP (dense) projection names retargeted by Qwen35TransformerLayerWeight. + self._gate_weight_name = self._retarget(self._gate_weight_name) + self._gate_bias_name = self._retarget(self._gate_bias_name) + self._up_weight_name = self._retarget(self._up_weight_name) + self._up_bias_name = self._retarget(self._up_bias_name) + self._gate_up_weight_name = self._retarget(self._gate_up_weight_name) + self._gate_up_bias_name = self._retarget(self._gate_up_bias_name) + self._down_weight_name = self._retarget(self._down_weight_name) + self._down_bias_name = self._retarget(self._down_bias_name) diff --git a/lightllm/models/qwen3_5_mtp/model.py b/lightllm/models/qwen3_5_mtp/model.py new file mode 100644 index 0000000000..b98524a997 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/model.py @@ -0,0 +1,109 @@ +from typing import List + +from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import Qwen35TransformerLayerInfer +from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight +from lightllm.models.qwen3_5_mtp.layer_weights.transformer_layer_weight import Qwen3_5MTPTransformerLayerWeight +from lightllm.models.qwen3_5_mtp.layer_infer.pre_layer_infer import Qwen3_5MTPPreLayerInfer +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3_5MTPModel(Qwen3_5TpPartModel): + + pre_and_post_weight_class = Qwen3_5MTPPreAndPostLayerWeight + pre_layer_infer_class = Qwen3_5MTPPreLayerInfer + transformer_weight_class = Qwen3_5MTPTransformerLayerWeight + transformer_layer_infer_class = Qwen35TransformerLayerInfer + + # MTP draft model: reuses the main model's req/mem managers and rope caches, and is + # marked so the decode CUDA-graph / padding paths detect it (is_mtp_draft_model). + is_mtp_draft_model = True + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + + def _init_config(self): + super()._init_config() + self.config["full_attention_interval"] = 1 + self.config["num_hidden_layers"] = 1 + self.config["n_layer"] = 1 + return + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = 1 + return + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(0, self.config["n_layer"]) + ] + # Shared with the main Qwen3.5 model (mtp_use_dedicated_embeddings: false). + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + return + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + # Build the single draft layer with layer_num == 0 so that, with + # full_attention_interval == 1, it takes the full-attention (mrope) path. + super()._init_infer_layer(start_layer_index=0) + self._assign_draft_kv_slot() + return + + def _assign_draft_kv_slot(self): + mem_manager = self.main_model.mem_manager + main_full_att = getattr(mem_manager, "main_full_att_layer_num", None) + interval = self.main_model.config["full_attention_interval"] + if main_full_att is None: + # Non-hybrid / unexpected mem_manager: nothing to remap. + return + + draft_idx = len(self.mtp_previous_draft_models) + draft_full_att_layers = getattr(mem_manager, "draft_full_att_layers", None) + if draft_full_att_layers is not None: + assert draft_idx < draft_full_att_layers, ( + f"draft_idx {draft_idx} out of range for draft_full_att_layers " + f"{draft_full_att_layers}; mem_manager not sized for this many MTP draft blocks" + ) + draft_kv_slot = main_full_att + draft_idx + layer_infer = self.layers_infer[0] + layer_infer.layer_num_ = draft_kv_slot * interval + logger.info( + f"Qwen3.5 MTP draft layer assigned dedicated full-attn KV slot {draft_kv_slot} " + f"(layer_num_={layer_infer.layer_num_}, interval={interval}, main_full_att={main_full_att})" + ) + return From 284def857a41e7e316e5c548e5800055a2d75d75 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Jun 2026 09:30:07 +0800 Subject: [PATCH 04/19] feat(qwen3next): GDN spec-decode verify path + linear-att cache split Gated-delta-net (linear attention) speculative-decode verify path for qwen3next: a per-sequence spec causal_conv1d kernel; a widened conv working slot split from the committed (narrow) persisted slot; MTP draft full-attn KV-slot accounting across the linear-att cache config, mem operator and req manager; and removal of the dead gen_b_req_mtp_start_loc kernel. --- .../triton_kernel/linear_att_copy.py | 97 ++-- .../linear_att_cpu_cache_copy.py | 7 +- .../basemodel/triton_kernel/mtp_utils.py | 40 -- .../operator/linear_att.py | 20 +- .../linear_att_cache_manager/config_objs.py | 29 +- .../linear_att_buffer_manager.py | 2 +- lightllm/common/req_manager.py | 54 +- lightllm/models/qwen3next/infer_struct.py | 8 +- .../layer_infer/transformer_layer_infer.py | 66 ++- lightllm/models/qwen3next/model.py | 16 +- .../triton_kernel/causal_conv1d_spec.py | 468 ++++++++++++++++++ 11 files changed, 688 insertions(+), 119 deletions(-) create mode 100644 lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py index d9f631cbd0..fd4c16043d 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py @@ -5,12 +5,13 @@ @triton.jit def _copy_linear_att_state_to_kv_buffer( - gpu_conv_ptr, # [linear_layer_num, size_num, xdim] + gpu_conv_ptr, # [linear_layer_num, size_num, conv_dim * gpu_widened_width] (uint8 tail) gpu_ssm_ptr, # [linear_layer_num, size_num, xxdim] - cpu_kv_conv_ptr, # [size, linear_layer_num, xdim] + cpu_kv_conv_ptr, # [size, linear_layer_num, conv_dim * width_narrow] (uint8 tail) cpu_kv_ssm_ptr, # [size, linear_layer_num, xxdim] b_req_idx, # [batch_size,] big_page_buffer_ids, # [batch_size,] + num_accepted_tokens_ptr, # [batch_size,] gpu_conv_stride_l, gpu_conv_stride_s, gpu_conv_stride_d, @@ -24,7 +25,9 @@ def _copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_stride_l, cpu_kv_ssm_stride_d, mtp_step, - gpu_conv_tail_dim, + conv_dim, # number of conv rows (the d dimension) + gpu_conv_row_bytes, # widened per-row byte length: gpu_widened_width * itemsize + conv_narrow_row_bytes, # narrow per-row byte length: width_narrow * itemsize gpu_ssm_tail_dim, BLOCK: tl.constexpr, ): @@ -40,28 +43,26 @@ def _copy_linear_att_state_to_kv_buffer( return cur_req_idx = tl.load(b_req_idx + cur_batch).to(tl.int64) - cur_state_req_idx = (cur_req_idx * (mtp_step + 1)).to(tl.int64) + accept_len = tl.load(num_accepted_tokens_ptr + cur_batch).to(tl.int64) + canonical_off = accept_len - 1 - for i in range(tl.cdiv(gpu_conv_tail_dim, BLOCK)): - gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) - mask = gpu_start_off < gpu_conv_tail_dim - conv_data = tl.load( - gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_state_req_idx * gpu_conv_stride_s + gpu_start_off, - mask=mask, - ) - dest_conv_ptr = ( - cpu_kv_conv_ptr - + big_page_buffer_idx * cpu_kv_conv_stride_s - + cur_layer * cpu_kv_conv_stride_l - + gpu_start_off - ) - tl.store(dest_conv_ptr, conv_data, mask=mask) + conv_src_slot = cur_req_idx + conv_off_bytes = canonical_off * gpu_conv_stride_d + gpu_conv_base = gpu_conv_ptr + cur_layer * gpu_conv_stride_l + conv_src_slot * gpu_conv_stride_s + conv_off_bytes + cpu_conv_base = cpu_kv_conv_ptr + big_page_buffer_idx * cpu_kv_conv_stride_s + cur_layer * cpu_kv_conv_stride_l + for d in range(conv_dim): + for i in range(tl.cdiv(conv_narrow_row_bytes, BLOCK)): + off = i * BLOCK + tl.arange(0, BLOCK) + mask = off < conv_narrow_row_bytes + conv_data = tl.load(gpu_conv_base + d * gpu_conv_row_bytes + off, mask=mask) + tl.store(cpu_conv_base + d * cpu_kv_conv_stride_d + off, conv_data, mask=mask) + ssm_src_slot = (cur_req_idx * (mtp_step + 1) + canonical_off).to(tl.int64) for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)): gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_off < gpu_ssm_tail_dim ssm_data = tl.load( - gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + cur_state_req_idx * gpu_ssm_stride_s + gpu_start_off, + gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + ssm_src_slot * gpu_ssm_stride_s + gpu_start_off, mask=mask, ) dest_ssm_ptr = ( @@ -75,32 +76,51 @@ def _copy_linear_att_state_to_kv_buffer( def copy_linear_att_state_to_kv_buffer( b_req_idx: torch.Tensor, big_page_buffer_ids: torch.Tensor, - gpu_conv_state: torch.Tensor, # [linear_layer_num, s, ...] - gpu_ssm_state: torch.Tensor, # [linear_layer_num, s, ...] - cpu_kv_conv_state: torch.Tensor, # [s, linear_layer_num, ...] - cpu_kv_ssm_state: torch.Tensor, # [s, linear_layer_num, ...] + gpu_conv_state: torch.Tensor, # [linear_layer_num, s_widened, conv_dim, gpu_widened_width] + gpu_ssm_state: torch.Tensor, # [linear_layer_num, s_block, ...] + cpu_kv_conv_state: torch.Tensor, # [size, linear_layer_num, conv_dim, width_narrow] + cpu_kv_ssm_state: torch.Tensor, # [size, linear_layer_num, ...] mtp_step: int, + b_num_accepted_tokens: torch.Tensor, # [batch_size,] per-req post-accept count (>=1) ): assert len(b_req_idx) == big_page_buffer_ids.shape[0] + assert len(b_req_idx) == b_num_accepted_tokens.shape[0] BLOCK = 4096 - gpu_conv_state = gpu_conv_state.view(gpu_conv_state.shape[0], gpu_conv_state.shape[1], -1).view(dtype=torch.uint8) - gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8) - cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], -1).view( - dtype=torch.uint8 + + assert gpu_conv_state.dim() >= 4, "gpu_conv_state must be [layer, s, conv_dim, widened_width]" + assert cpu_kv_conv_state.dim() >= 4, "cpu_kv_conv_state must be [size, layer, conv_dim, width_narrow]" + # #6: the byte snapshot hardcodes gpu_conv_stride_d=conv_itemsize, which is only valid when the + # widened-width axis is element-contiguous (stride 1). Fail fast instead of snapshotting wrong bytes. + assert gpu_conv_state.stride(3) == 1, ( + "gpu_conv_state widened-width axis must be element-contiguous (stride 1); " + "gpu_conv_stride_d=conv_itemsize assumes it" + ) + # #18: canonical_off = accept_len - 1 indexes into the widened slot; bound it to [0, mtp_step] + # (accept_len in [1, mtp_step+1]) so a stale/oversized accept-count can't slice past the slot. + assert int(b_num_accepted_tokens.min()) >= 1 and int(b_num_accepted_tokens.max()) <= mtp_step + 1, ( + f"b_num_accepted_tokens out of range [1, {mtp_step + 1}]: " + f"min={int(b_num_accepted_tokens.min())} max={int(b_num_accepted_tokens.max())}" ) + conv_itemsize = gpu_conv_state.element_size() + gpu_conv_state = gpu_conv_state.view( + gpu_conv_state.shape[0], gpu_conv_state.shape[1], gpu_conv_state.shape[2], -1 + ).view(dtype=torch.uint8) + cpu_kv_conv_state = cpu_kv_conv_state.view( + cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], cpu_kv_conv_state.shape[2], -1 + ).view(dtype=torch.uint8) + + gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8) cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], cpu_kv_ssm_state.shape[1], -1).view( dtype=torch.uint8 ) - assert gpu_conv_state.shape[-1] == cpu_kv_conv_state.shape[-1] + + assert gpu_conv_state.shape[2] == cpu_kv_conv_state.shape[2], "conv_dim mismatch between gpu and cpu conv buffers" assert gpu_ssm_state.shape[-1] == cpu_kv_ssm_state.shape[-1] - assert ( - gpu_conv_state.stride(-1) - == gpu_ssm_state.stride(-1) - == cpu_kv_conv_state.stride(-1) - == cpu_kv_ssm_state.stride(-1) - ) - gpu_conv_tail_dim = gpu_conv_state.shape[-1] + conv_dim = gpu_conv_state.shape[2] + gpu_conv_row_bytes = gpu_conv_state.shape[-1] # widened per-row byte length + conv_narrow_row_bytes = cpu_kv_conv_state.shape[-1] # narrow per-row byte length + assert conv_narrow_row_bytes <= gpu_conv_row_bytes gpu_ssm_tail_dim = gpu_ssm_state.shape[-1] layer_num = gpu_conv_state.shape[0] @@ -114,9 +134,10 @@ def copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_ptr=cpu_kv_ssm_state, b_req_idx=b_req_idx, big_page_buffer_ids=big_page_buffer_ids, + num_accepted_tokens_ptr=b_num_accepted_tokens, gpu_conv_stride_l=gpu_conv_state.stride(0), gpu_conv_stride_s=gpu_conv_state.stride(1), - gpu_conv_stride_d=gpu_conv_state.stride(2), + gpu_conv_stride_d=conv_itemsize, gpu_ssm_stride_l=gpu_ssm_state.stride(0), gpu_ssm_stride_s=gpu_ssm_state.stride(1), gpu_ssm_stride_d=gpu_ssm_state.stride(2), @@ -127,7 +148,9 @@ def copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_stride_l=cpu_kv_ssm_state.stride(1), cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(2), mtp_step=mtp_step, - gpu_conv_tail_dim=gpu_conv_tail_dim, + conv_dim=conv_dim, + gpu_conv_row_bytes=gpu_conv_row_bytes, + conv_narrow_row_bytes=conv_narrow_row_bytes, gpu_ssm_tail_dim=gpu_ssm_tail_dim, BLOCK=BLOCK, ) diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py index 37b27cadb2..1251dddc33 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py @@ -193,11 +193,7 @@ def copy_kv_buffer_to_cpu_cache( cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] full_att_layer_num = gpu_kv_full_att_state.shape[-2] - assert ( - full_att_layer_num - == (linear_config.all_layer_num // linear_config.full_attention_interval) - == (linear_config.all_layer_num - linear_config.linear_layer_num) - ) + assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num() assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] @@ -428,6 +424,7 @@ def copy_cpu_cache_to_kv_buffer( cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] full_att_layer_num = gpu_full_att_kv_state.shape[-2] + assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num() assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index 2d70a68c05..a020605c26 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -148,38 +148,6 @@ def mtp_scatter_next_token_ids( ) -@triton.jit -def _fwd_kernel_gen_b_req_mtp_start_loc( - b_mtp_index, - b_req_mtp_start_loc, - num_reqs: tl.constexpr, - batch_size: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - offset = tl.arange(0, BLOCK_SIZE) - cur_mtp_index = tl.load(b_mtp_index + offset, mask=offset < batch_size, other=-1) - non_zero_mask = tl.where(cur_mtp_index == 0, 1, 0) # 1 0 1 0 0 - output_offset = tl.cumsum(non_zero_mask) - 1 - tl.store(b_req_mtp_start_loc + output_offset, offset, mask=non_zero_mask == 1) - return - - -def gen_b_req_mtp_start_loc(b_mtp_index: torch.Tensor, num_reqs: int): - b_req_mtp_start_loc = torch.empty((num_reqs,), dtype=torch.int32, device=b_mtp_index.device) - BLOCK_SIZE = triton.next_power_of_2(b_mtp_index.shape[0]) - batch_size = b_mtp_index.shape[0] - grid = (1,) - _fwd_kernel_gen_b_req_mtp_start_loc[grid]( - b_mtp_index=b_mtp_index, - b_req_mtp_start_loc=b_req_mtp_start_loc, - num_reqs=num_reqs, - batch_size=batch_size, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=8, - ) - return b_req_mtp_start_loc - - def test_mtp_verify(): req_to_next_token_ids = torch.tensor( [[1, 2, -2, -1, -1], [1, 2, 0, -1, -1], [1, 3, 4, 4, 5]], dtype=torch.int32, device="cuda" @@ -201,13 +169,5 @@ def test_mtp_verify(): print(accepted_index) -def test_gen_b_req_mtp_start_loc(): - b_mtp_index = torch.tensor([0, 1, 0, 1, 2], dtype=torch.int32, device="cuda") - gt_output = torch.where(b_mtp_index == 0)[0] - b_req_mtp_start_loc = gen_b_req_mtp_start_loc(b_mtp_index, 2) - print(b_req_mtp_start_loc, gt_output) - - if __name__ == "__main__": test_mtp_verify() - # test_gen_b_req_mtp_start_loc() diff --git a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py index 109e813220..e3ae9493c7 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py +++ b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py @@ -24,6 +24,16 @@ def __init__(self, mem_manager): super().__init__(mem_manager) self.linear_config = LinearAttCacheConfig.load_from_args() + @staticmethod + def _get_persisted_full_att_layer_num(mem_manager) -> int: + persisted_full_att = getattr(mem_manager, "persisted_full_att_layer_num", None) + if persisted_full_att is None: + main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0]) + draft_full_att = getattr(mem_manager, "draft_full_att_layers", 0) + persisted_full_att = main_full_att + draft_full_att + assert 0 < persisted_full_att <= mem_manager.kv_buffer.shape[0] + return int(persisted_full_att) + def load_cpu_cache_to_gpu( self, mem_indexes: torch.Tensor, @@ -76,11 +86,14 @@ def load_cpu_cache_to_gpu( copy_cpu_cache_to_kv_buffer, ) + # Restore the persisted full-attn slice: main slots followed by MTP draft slots. + persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager) + copy_cpu_cache_to_kv_buffer( mem_indexes=mem_indexes, big_page_buffer_ids=big_page_buffer_ids_gpu, page_indexes=page_indexes, - gpu_full_att_kv_state=mem_manager.kv_buffer, + gpu_full_att_kv_state=mem_manager.kv_buffer[:persisted_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, @@ -169,12 +182,15 @@ def offload_gpu_kv_to_cpu_cache( copy_kv_buffer_to_cpu_cache, ) + # Persist the full-attn slice used for prefix reuse: main slots followed by MTP draft slots. + persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager) + copy_kv_buffer_to_cpu_cache( mem_indexes=mem_indexes, page_indexes=page_indexes, page_readies=page_readies, big_page_buffer_ids=big_page_buffer_ids_gpu, - gpu_kv_full_att_state=mem_manager.kv_buffer, + gpu_kv_full_att_state=mem_manager.kv_buffer[:persisted_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, diff --git a/lightllm/common/linear_att_cache_manager/config_objs.py b/lightllm/common/linear_att_cache_manager/config_objs.py index bc39067069..f533c71dbc 100644 --- a/lightllm/common/linear_att_cache_manager/config_objs.py +++ b/lightllm/common/linear_att_cache_manager/config_objs.py @@ -1,13 +1,18 @@ import torch import dataclasses import triton -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, _mtp_added_layer_num from lightllm.utils.log_utils import init_logger from lightllm.utils.torch_dtype_utils import get_torch_dtype logger = init_logger(__name__) +def get_mtp_draft_full_att_layer_num(args) -> int: + # Delegates to the single source of truth in envs_utils (#9). + return _mtp_added_layer_num(getattr(args, "mtp_mode", None), getattr(args, "mtp_step", 0)) + + @dataclasses.dataclass class LinearAttCacheConfig: tp_world_size: int @@ -30,6 +35,7 @@ class LinearAttCacheConfig: ssm_state_dtype: torch.dtype full_attention_interval: int all_layer_num: int # 包括 linear att 和 full att 的层加起来的层数 + draft_full_att_layer_num: int = 0 def get_conv_dim(self): # 第一项对应q的参数,第二项对应k的参数,第三项对应v的参数 @@ -41,9 +47,25 @@ def get_conv_dim(self): + self.head_linear_v_dim * self.num_linear_v_heads ) - def get_conv_state_shape(self): + def get_main_full_att_layer_num(self): + main_full_att_layer_num = self.all_layer_num - self.linear_layer_num + assert main_full_att_layer_num == self.all_layer_num // self.full_attention_interval + return main_full_att_layer_num + + def get_persisted_full_att_layer_num(self): + return self.get_main_full_att_layer_num() + self.draft_full_att_layer_num + + def get_persisted_conv_state_shape(self): + # NARROW shape used for the CPU/disk persisted page and ALL byte math. + # Persisted state is always the committed (narrow) sliding window. return (self.get_conv_dim(), self.conv_kernel_size - 1) + def get_gpu_conv_state_shape(self, mtp_step: int): + # WIDENED working shape for the GPU buffer: holds the tentatively + # rolled-in S speculative tokens before acceptance. width-1 + S, where + # S = mtp_step (a verify step has seqlen=S+1 -> width-1+(seqlen-1)). + return (self.get_conv_dim(), (self.conv_kernel_size - 1) + mtp_step) + def get_ssm_state_shape(self): return (self.num_linear_v_heads, self.head_linear_k_dim, self.head_linear_v_dim) @@ -66,7 +88,7 @@ def get_cpu_cache_full_att_bytes(self): ) assert big_page_token_num == get_env_start_args().cpu_cache_token_page_size full_att_bytes = 2 * self.full_att_all_num_kv_heads * self.full_att_head_dim * self.full_att_dtype.itemsize - a = full_att_bytes * (self.all_layer_num - self.linear_layer_num) * big_page_token_num + a = full_att_bytes * self.get_persisted_full_att_layer_num() * big_page_token_num return a def get_cpu_cache_conv_bytes(self): @@ -113,4 +135,5 @@ def load_from_args() -> "LinearAttCacheConfig": ssm_state_dtype=get_torch_dtype(args.linear_att_ssm_data_type), full_attention_interval=llm_config["full_attention_interval"], all_layer_num=n_layer, + draft_full_att_layer_num=get_mtp_draft_full_att_layer_num(args), ) diff --git a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py index 30dc4d937c..2ab4313e37 100644 --- a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py +++ b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py @@ -24,7 +24,7 @@ def __init__( self.conv_state_cache = LayerCache( size=self.size, dtype=self.linear_config.conv_state_dtype, - shape=self.linear_config.get_conv_state_shape(), + shape=self.linear_config.get_persisted_conv_state_shape(), layer_num=self.linear_config.linear_layer_num, device="cpu", size_first=True, diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..6673243c9f 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -19,6 +19,18 @@ logger = init_logger(__name__) +# Width of req_to_next_token_ids: holds the seed token + up to (WIDTH - 1) MTP draft tokens. +REQ_NEXT_TOKEN_IDS_WIDTH = 8 + + +def assert_mtp_step_within_next_token_ids_width(mtp_step: int) -> None: + assert mtp_step <= REQ_NEXT_TOKEN_IDS_WIDTH - 1, ( + f"mtp_step={mtp_step} exceeds {REQ_NEXT_TOKEN_IDS_WIDTH - 1}; " + f"req_to_next_token_ids width is {REQ_NEXT_TOKEN_IDS_WIDTH} " + "(widening it is an explicit follow-up, spec §9)" + ) + + class _ReqNode: def __init__(self, index): self.index = index @@ -117,7 +129,7 @@ def __init__(self, max_request_num): self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_next_token_ids = torch.zeros( - (max_request_num + 1, 8), + (max_request_num + 1, REQ_NEXT_TOKEN_IDS_WIDTH), dtype=torch.int64, device="cuda", ) @@ -236,15 +248,13 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con self.big_page_token_num = ( get_env_start_args().linear_att_page_block_num * get_env_start_args().linear_att_hash_page_size ) - assert ( - self.mtp_step == 0 - ), "currently only support mtp_step 0 for simplicity, more mtp_step support will be added in the future" + assert_mtp_step_within_next_token_ids_width(self.mtp_step) self.linear_config = linear_config self.req_to_conv_state = LayerCache( - size=(max_request_num + 1) * (self.mtp_step + 1), + size=(max_request_num + 1), dtype=self.linear_config.conv_state_dtype, - shape=self.linear_config.get_conv_state_shape(), + shape=self.linear_config.get_gpu_conv_state_shape(mtp_step=self.mtp_step), layer_num=self.linear_config.linear_layer_num, device="cuda", ) @@ -258,11 +268,13 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con return def init_linear_att_state(self, req: "InferReq"): - index = req.req_idx * (self.mtp_step + 1) - conv_state = self.req_to_conv_state.buffer[:, index, ...] - ssm_state = self.req_to_ssm_state.buffer[:, index, ...] - conv_state.fill_(0) - ssm_state.fill_(0) + conv_index = req.req_idx + ssm_start = req.req_idx * (self.mtp_step + 1) + self.req_to_conv_state.buffer[:, conv_index, ...].fill_(0) + # #17: zero the FULL (mtp_step + 1)-row SSM block, not just canonical row +0, so a future + # first-step verify reading offset>0 after fresh init never hits a never-written row (NaN). + self.req_to_ssm_state.buffer[:, ssm_start : ssm_start + (self.mtp_step + 1), ...].fill_(0) + req.mtp_accept_len = 1 return def get_mamba_cache(self, layer_idx_in_all: int): @@ -275,16 +287,17 @@ def get_mamba_cache(self, layer_idx_in_all: int): return conv_states, ssm_states def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req: "InferReq"): - from .linear_att_cache_manager import LinearAttCacheManager big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers conv_state, ssm_state = big_page_buffers.get_state_cache(buffer_idx=big_page_buffer_idx) - dest_req_idx = req.req_idx * (self.mtp_step + 1) - - self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state - self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + conv_dest = req.req_idx + ssm_dest = req.req_idx * (self.mtp_step + 1) + narrow_w = conv_state.shape[-1] # persisted (narrow) width + self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state + self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state + req.mtp_accept_len = 1 return def copy_small_page_buffer_to_linear_att_state( @@ -293,9 +306,12 @@ def copy_small_page_buffer_to_linear_att_state( conv_state, ssm_state = linear_att_small_page_buffers.get_state_cache( buffer_idx=req.shared_kv_node.small_page_buffer_idx ) - dest_req_idx = req.req_idx * (self.mtp_step + 1) + conv_dest = req.req_idx + ssm_dest = req.req_idx * (self.mtp_step + 1) + narrow_w = conv_state.shape[-1] # TODO 下面这个从 cpu cache 拷贝数据的 gpu的操作,是否是阻塞的操作。 # 同时,非连续对象的拷贝,可能存在效率问题。 - self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state - self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state + self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state + req.mtp_accept_len = 1 return diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py index 0006a682f1..b486bc6040 100644 --- a/lightllm/models/qwen3next/infer_struct.py +++ b/lightllm/models/qwen3next/infer_struct.py @@ -1,6 +1,4 @@ -import torch from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args class Qwen3NextInferStateInfo(LlamaInferStateInfo): @@ -10,7 +8,7 @@ def __init__(self): def init_some_extra_state(self, model): super().init_some_extra_state(model) - self.b_att_seq_len = self.b_seq_len - mtp_step = get_env_start_args().mtp_step - self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state + + init_mtp_verify_extra_state(self) return diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index bb48bfe49c..76c273c0e7 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -45,7 +45,6 @@ def __init__(self, layer_num, network_config): return def _init_linear_layer_metadata(self, layer_num, network_config): - # Linear attention specific dimensions self.num_v_heads = network_config["linear_num_value_heads"] self.num_k_heads = network_config["linear_num_key_heads"] @@ -121,7 +120,6 @@ def _compute_shared_expert( def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input.view(-1, self.embed_dim_) @@ -254,6 +252,18 @@ def gdn_forward( if is_prefill: core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight) + elif getattr(infer_state, "is_mtp_verify", False): + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) + conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) + core_attn_out = self._gdn_verify_kernel( + mixed_qkv, + conv_states, + ssm_states, + a, + b, + infer_state, + layer_weight, + ) else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) @@ -374,7 +384,7 @@ def _gdn_prefill_kernel( layer_weight.linear_conv1d.mm_param.weight, bias=layer_weight.linear_conv1d.bias, query_start_loc=infer_state.b1_cu_q_seq_len, - cache_indices=infer_state.b_buffer_idx, + cache_indices=infer_state.b_conv_buffer_idx, has_initial_state=infer_state.b_ready_cache_len > 0, conv_states=conv_states, activation=self.activation, @@ -419,7 +429,7 @@ def _gdn_decode_kernel( layer_weight.linear_conv1d.mm_param.weight, bias=layer_weight.linear_conv1d.bias, activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + conv_state_indices=infer_state.b_conv_buffer_idx, ) # Recurrent processing with fused gating @@ -439,3 +449,51 @@ def _gdn_decode_kernel( b_raw=b, ) return core_attn_out + + def _gdn_verify_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import ( + causal_conv1d_update as causal_conv1d_update_spec, + ) + + mixed_qkv = causal_conv1d_update_spec( + mixed_qkv, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + activation=self.activation, + conv_state_indices=infer_state.b_conv_buffer_idx, + num_accepted_tokens=infer_state.b_num_accepted_tokens, + query_start_loc=infer_state.b_gdn_verify_cu_seqlens, + ) + + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=False) + assert infer_state.b_ssm_index_rows.dim() == 2, "SSM index rows must be 2D [N, S+1]" + # #8b: b_num_accepted_tokens >= 1 is guaranteed upstream (init sets accept_len=1; the + # offload/snapshot guards bound it to [1, mtp_step+1]). The old per-layer per-step .all() + # D2H sync stalled the GPU on the eager decode hot path; it is redundant here. + core_attn_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + initial_state=ssm_states, + inplace_final_state=True, + cu_seqlens=infer_state.b_gdn_verify_cu_seqlens.to(torch.long), + ssm_state_indices=infer_state.b_ssm_index_rows, + ssm_state_write_indices=infer_state.b_ssm_index_rows, + num_accepted_tokens=infer_state.b_num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + A_log=layer_weight.linear_A_log.weight, + dt_bias=layer_weight.linear_dt_bias.weight, + a_raw=a, + b_raw=b, + ) + return core_attn_out diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 9b5e9b7a50..f61d9e4c6a 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -16,14 +16,16 @@ from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba -from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig +from lightllm.common.linear_att_cache_manager.config_objs import ( + LinearAttCacheConfig, + get_mtp_draft_full_att_layer_num, +) logger = init_logger(__name__) @ModelRegistry("qwen3_next") class Qwen3NextTpPartModel(Qwen3MOEModel): - # weight class pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight transformer_weight_class = Qwen3NextTransformerLayerWeight @@ -59,6 +61,7 @@ def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + draft_full_att_layers = get_mtp_draft_full_att_layer_num(start_args) self.linear_config = LinearAttCacheConfig( tp_world_size=self.tp_world_size_, full_att_all_num_kv_heads=self.config["num_key_value_heads"], @@ -78,17 +81,24 @@ def _init_mem_manager(self): ssm_state_dtype=ssm_dtype_dict[start_args.linear_att_ssm_data_type], full_attention_interval=self.config["full_attention_interval"], all_layer_num=self.config["n_layer"], + draft_full_att_layer_num=draft_full_att_layers, ) + main_full_att = self.linear_config.get_main_full_att_layer_num() + persisted_full_att = self.linear_config.get_persisted_full_att_layer_num() + self.mem_manager = Qwen3NextMemManager( size=self.max_total_token_num, dtype=self.data_type, num_kv_heads=self.num_kv_heads, head_dim=self.config["head_dim"], - full_att_layer_num=self.linear_config.all_layer_num - self.linear_config.linear_layer_num, + full_att_layer_num=persisted_full_att, linear_config=self.linear_config, mem_fraction=self.mem_fraction, ) + self.mem_manager.main_full_att_layer_num = main_full_att + self.mem_manager.draft_full_att_layers = draft_full_att_layers + self.mem_manager.persisted_full_att_layer_num = persisted_full_att def _init_req_manager(self): create_max_seq_len = 0 diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py new file mode 100644 index 0000000000..2f0e22fa3f --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py @@ -0,0 +1,468 @@ +# Vendored from vLLM v0.14.1 +# source: vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# commit: d7de043d55d1dd629554467e23874097e1c48993 +# Adapted for LightLLM: imports point at standard triton; the vLLM-specific +# block-table params (block_idx_last_scheduled_token, initial_state_idx, +# null_block_id) are dropped — LightLLM uses contiguous per-request slots. +# Supports spec-decode: writes per-position conv state to a single widened slot +# per request and reads from offset (num_accepted_tokens-1). +# +# Upstream copyright notice: +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Tri Dao. +# Adapted from +# https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +from typing import Optional + +import torch +import triton +import triton.language as tl + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + # LightLLM uses contiguous per-request slots, so the cache block for both + # the initial-state read and the final write is always conv_state_indices[idx_seq]. + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_states_input_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init).to( + tl.int64 + ) + + if USE_PAD_SLOT: # noqa + if conv_states_input_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_VARLEN: + query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) + # revise state_len and seqlen + state_len = state_len - (seqlen - (query_end_index - query_start_index)) + seqlen = query_end_index - query_start_index + x_offset = query_start_index * stride_x_token + o_offset = query_start_index * stride_o_token + else: + query_start_index = idx_seq * seqlen + query_end_index = query_start_index + seqlen + x_offset = idx_seq * stride_x_seq + o_offset = idx_seq * stride_o_seq + + if query_start_index == query_end_index: + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 6: + conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] + col4 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # With speculative decoding, the conv_state updates works in a sliding + # window manner, at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + # Write the updated state back. In LightLLM the read and write slots are the + # same contiguous per-request slot (current_last_index == conv_state_init == 0), + # so this resolves to the same conv_state_indices[idx_seq] used for the read. + conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index).to( + tl.int64 + ) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[ + None, : + ] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok + )[ + :, None + ] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 5: + w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor + w_col4 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 6: + w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor + w_col5 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + elif KERNEL_WIDTH == 5: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x + elif KERNEL_WIDTH == 6: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, +): + """Spec-decode capable conv1d update. When num_accepted_tokens/query_start_loc + are None it must behave like a single-token decode update. x may be (batch, dim) + single-token or (num_tokens, dim) flattened varlen with query_start_loc grouping + each request's S+1 candidates. conv_state is (num_slots, dim, state_len) with + state_len = (width-1)+S widened. Read offset = num_accepted_tokens-1; writes to + the same slot. + + Args: + x: input tensor of shape ``(batch, dim)`` (single-token decode), + ``(batch, dim, seqlen)`` (single/multi token), or ``(num_tokens, dim)`` + flattened varlen grouped by ``query_start_loc``. + conv_state: ``(num_slots, dim, state_len)`` with ``state_len >= width - 1``. + For spec decode the slot is widened to ``(width - 1) + S`` where ``S`` is + the number of speculative tokens (so ``seqlen == S + 1``). + weight: depthwise filter of shape ``(dim, width)``. + bias: optional ``(dim,)`` bias. + activation: ``None``, ``"silu"`` or ``"swish"``. + cache_seqlens: accepted for call-compatibility with the non-spec wrapper; + unused here. + conv_state_indices: ``(batch,)`` int32 mapping each request to its conv_state + slot. Required when ``query_start_loc`` is given. + num_accepted_tokens: ``(batch,)`` int32. When not None the conv_state read + offset for each request is ``num_accepted_tokens - 1`` (sliding window + spec-decode update). + query_start_loc: ``(batch + 1,)`` int32 varlen cumulative token offsets; when + None the call is a plain single-/multi-token decode update. + pad_slot_id: slot id that marks padded entries to skip. + + Returns: + Output tensor with the same shape as ``x`` (the kernel overwrites ``x`` in + place), one conv output per input token. + """ + if activation is not None: + assert activation in ["silu", "swish"] + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + if query_start_loc is None: + batch, dim, seqlen = x.shape + else: + assert conv_state_indices is not None + batch = conv_state_indices.size(0) + dim = x.size(1) + # The MTP verify layout is uniform (mtp_step+1) tokens per request, so seqlen is + # structurally x.size(0) // batch. Compute it without a D2H sync on query_start_loc on + # BOTH the capture and eager paths (#8a) — the eager .item() ran once per GDN layer per + # decode step. .item() is also illegal during CUDA-graph capture. + assert x.size(0) % batch == 0, "varlen conv update expects a uniform per-request length" + seqlen = x.size(0) // batch + _, width = weight.shape + # conv_state: (num_slots, dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + # adopt the strategy in vLLM that overwrites 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + if query_start_loc is None: + # X (batch, dim, seqlen) + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + # X (num_tokens, dim) + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = out.stride() + stride_o_seq = 0 + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + conv_state_indices, + num_accepted_tokens, + query_start_loc, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_VARLEN=query_start_loc is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + if unsqueeze: + out = out.squeeze(-1) + return out.to(original_x_dtype) From 7c672ac37abe2b88d2399549f45e3db120bc8524 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Jun 2026 09:30:08 +0800 Subject: [PATCH 05/19] feat(scheduler): MTP verify backend + accept-len transport Wire the verify path through the inference backends: a single draft-model factory keyed on (model_type, mtp_mode); build the (mtp_step+1)-expanded verify decode batch; run the eagle + vanilla draft decode; verify accepted tokens; and thread per-request accept-lengths (b_num_accepted_tokens) from the chunked-prefill and dp backends into the model verify forward. --- .../server/router/model_infer/infer_batch.py | 27 ++- .../model_infer/mode_backend/base_backend.py | 89 +++++-- .../mode_backend/chunked_prefill/impl.py | 76 +++--- .../mode_backend/dp_backend/impl.py | 227 +++++++++++++----- .../mode_backend/mtp_model_factory.py | 33 +++ 5 files changed, 340 insertions(+), 112 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 5c2d0d45fb..fc02095f85 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -361,6 +361,11 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L if not self.is_linear_att_mixed_model: return + # 当 dynamic prompt cache 被禁用时 radix_cache 为 None,没有大页/小页缓冲可写, + # 线性层状态仅存于 req_manager 的 GPU buffer 即可,直接跳过跨请求缓存拷贝。 + if self.radix_cache is None: + return + # 大页对应的 linear att 的拷贝 big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num big_page_buffer_ids = [] @@ -384,6 +389,10 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer + b_num_accepted_tokens = torch.tensor( + [req.mtp_accept_len for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu" + ).cuda(non_blocking=True) + copy_linear_att_state_to_kv_buffer( b_req_idx=b_req_idx, big_page_buffer_ids=big_page_buffer_ids, @@ -392,6 +401,7 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer, mtp_step=self.args.mtp_step, + b_num_accepted_tokens=b_num_accepted_tokens, ) assert not self.args.disable_chunked_prefill, "chunked prefill mode must be enabled for linear att mixed model" @@ -407,9 +417,18 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache() ) if req.tail_linear_att_small_page_buffer_id is not None: - src_buffer_idx = req.req_idx * (self.args.mtp_step + 1) - gpu_conv_state = self.req_manager.req_to_conv_state.buffer[:, src_buffer_idx, ...] - gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...] + assert 1 <= req.mtp_accept_len <= self.args.mtp_step + 1, ( + f"mtp_accept_len={req.mtp_accept_len} out of range " + f"[1, {self.args.mtp_step + 1}]; would slice past the widened conv slot" + ) + canonical_off = req.mtp_accept_len - 1 + conv_src_idx = req.req_idx + ssm_src_idx = req.req_idx * (self.args.mtp_step + 1) + canonical_off + narrow_w = self.req_manager.linear_config.get_persisted_conv_state_shape()[-1] + gpu_conv_state = self.req_manager.req_to_conv_state.buffer[ + :, conv_src_idx, ..., canonical_off : canonical_off + narrow_w + ] + gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, ssm_src_idx, ...] dst_buffer_idx = req.tail_linear_att_small_page_buffer_id dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache( @@ -559,6 +578,8 @@ def __init__( else: self.decode_need_token_num = self._normal_decode_need_token_num + self.mtp_accept_len: int = 1 + if g_infer_context.is_linear_att_mixed_model: self.get_chuncked_input_token_len = self.get_chuncked_input_token_len_for_linear_att self.get_chuncked_input_token_ids = self.get_chuncked_input_token_ids_for_linear_att diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a65dfb1bbb..cde5c03000 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -1,4 +1,5 @@ import os +import copy import numpy as np import torch import time @@ -41,10 +42,6 @@ ) from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack -from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel -from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel -from lightllm.models.mistral_mtp.model import MistralMTPModel -from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import PDChunckedTransTaskRet @@ -328,22 +325,11 @@ def init_mtp_draft_model(self, main_kvargs: dict): "mtp_previous_draft_models": self.draft_models.copy(), } - # Select MTP model class based on model type + # Select MTP model class based on model type (single source of truth: #10). + from lightllm.server.router.model_infer.mode_backend.mtp_model_factory import create_mtp_draft_model + model_type = mtp_model_cfg.get("model_type", "") - if model_type == "deepseek_v3": - assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] - self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) - elif model_type == "qwen3_moe": - assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] - self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) - elif model_type == "mistral": - assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] - self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "glm4_moe_lite": - assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] - self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) - else: - raise ValueError(f"Unsupported MTP model type: {model_type}") + self.draft_models.append(create_mtp_draft_model(model_type, self.args.mtp_mode, mtp_model_kvargs)) self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return @@ -584,7 +570,6 @@ def _get_classed_reqs( can_alloc_token_num = g_infer_context.get_can_alloc_token_num() for req_obj in ready_reqs: - if req_obj.filter_mark: finished_reqs.append(req_obj) continue @@ -761,11 +746,69 @@ def _verify_mtp_v2( ) return mtp_accept_len, accepted_index + def _build_eagle_accepted_draft_input( + self, + main_model_input: ModelInput, + main_model_output: ModelOutput, + next_token_ids: torch.Tensor, + mtp_accept_len: torch.Tensor, + b_req_mtp_start_loc: torch.Tensor, + ): + accepted_row_idx = b_req_mtp_start_loc + mtp_accept_len - 1 + accepted_row_idx_long = accepted_row_idx.long() + + draft_model_input = copy.copy(main_model_input) + draft_model_input.batch_size = accepted_row_idx.shape[0] + draft_model_input.total_token_num = draft_model_input.batch_size * main_model_input.max_kv_seq_len + draft_model_input.input_ids = next_token_ids.index_select(0, accepted_row_idx_long) + draft_model_input.mtp_draft_input_hiddens = main_model_output.mtp_main_output_hiddens.index_select( + 0, accepted_row_idx_long + ) + draft_model_input.b_req_idx = main_model_input.b_req_idx.index_select(0, accepted_row_idx_long) + draft_model_input.b_mtp_index = main_model_input.b_mtp_index.index_select(0, accepted_row_idx_long) + draft_model_input.b_seq_len = main_model_input.b_seq_len.index_select(0, accepted_row_idx_long) + draft_model_input.b_num_accepted_tokens = None + if main_model_input.mem_indexes is not None: + draft_model_input.mem_indexes = main_model_input.mem_indexes.index_select(0, accepted_row_idx_long) + draft_model_input.mem_indexes_cpu = None + if main_model_input.b_shared_seq_len is not None: + draft_model_input.b_shared_seq_len = main_model_input.b_shared_seq_len.index_select( + 0, accepted_row_idx_long + ) + if main_model_input.b_mark_shared_group is not None: + draft_model_input.b_mark_shared_group = main_model_input.b_mark_shared_group.index_select( + 0, accepted_row_idx_long + ) + + if accepted_row_idx.device.type == "cpu": + selected_rows = accepted_row_idx.tolist() + draft_model_input.multimodal_params = [main_model_input.multimodal_params[i] for i in selected_rows] + else: + draft_model_input.multimodal_params = [ + {"images": [], "audios": []} for _ in range(draft_model_input.batch_size) + ] + + accepted_next_token_ids = draft_model_input.input_ids + accepted_req_idx = draft_model_input.b_req_idx + return draft_model_input, accepted_next_token_ids, accepted_req_idx + + def _scatter_accepted_next_token_ids(self, accepted_req_idx: torch.Tensor, all_next_token_ids: torch.Tensor): + req_to_next_token_ids = self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids + width = all_next_token_ids.shape[1] + req_to_next_token_ids[:, :width].index_copy_( + 0, + accepted_req_idx.long(), + all_next_token_ids.to(dtype=req_to_next_token_ids.dtype), + ) + return + def _update_mtp_accept_ratio( self, decode_reqs: List[InferReq], mtp_accept_len_cpu: torch.Tensor, ): + # Master-only accept-ratio statistics. Unlike the phase-2 mtp_accept_len commit + # (inlined in decode_mtp) this only feeds metrics, so it may stay in phase 3. if self.is_master_in_dp: for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): req.update_mtp_accepted_token_num(accept_token_num=accept_len - 1) @@ -773,8 +816,9 @@ def _update_mtp_accept_ratio( def _gen_argmax_token_ids(self, model_output: ModelOutput): logits = model_output.logits - probs = torch.softmax(logits, dim=-1) - draft_next_token_ids_gpu = torch.argmax(probs, dim=-1) + # softmax is strictly monotonic, so argmax(softmax(logits)) == argmax(logits); + # skip the softmax to shorten the per-step MTP draft critical chain (need-to-fix #16). + draft_next_token_ids_gpu = torch.argmax(logits, dim=-1) return draft_next_token_ids_gpu def _sample_and_scatter_token( @@ -787,7 +831,6 @@ def _sample_and_scatter_token( b_prefill_has_output_cpu: torch.Tensor = None, mask_func: Optional[Callable] = None, ): - if mask_func is not None: assert len(run_reqs) == logits.shape[0] mask_func(run_reqs, logits) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 792a10a788..ed01c14a53 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -1,5 +1,6 @@ import torch import time +import copy from typing import List, Optional, Callable, Dict, Any from queue import Queue from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend @@ -240,17 +241,23 @@ def decode_mtp( """ model_input, run_reqs = prepare_decode_inputs(decode_reqs) + if self.mtp_step > 0: + accept_lens = [req.mtp_accept_len for req in decode_reqs] + model_input.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( + key="b_num_accepted_tokens", + data=accept_lens, + dtype=torch.int32, + ) + with torch.cuda.stream(g_infer_context.get_overlap_stream()): - b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) - # verify the next_token_ids - b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] - b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list( - key="b_req_mtp_start_loc", - data=b_req_mtp_start_loc, - dtype=torch.int32, - ).cuda(non_blocking=True) + # verify the next_token_ids. The chunked decode batch is the contiguous + # (mtp_step+1)-expanded layout, so request starts are structurally + # arange(n_real)*(mtp_step+1). Compute on device instead of a per-step Python + # list-comp + pinned pack + H2D (#22). + n_real = model_input.batch_size // (self.mtp_step + 1) + b_req_mtp_start_loc = torch.arange(n_real, dtype=torch.int32, device="cuda") * (self.mtp_step + 1) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, @@ -292,6 +299,8 @@ def decode_mtp( # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() verify_event.synchronize() + for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): + req.mtp_accept_len = int(accept_len) verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1] update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) @@ -344,15 +353,19 @@ def _draft_decode_vanilla( mtp_accept_len: torch.Tensor, b_req_mtp_start_loc: torch.Tensor, ): - # share some inference info with the main model - draft_model_input = main_model_input + # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, + # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型 decode_mtp 设置的 + # verify 布局,命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError + # (cudagraph 关闭时则会在扁平的 draft batch 上误用 S+1 分组的 verify attention)。 + # 镜像 eagle 路径 _build_eagle_accepted_draft_input 中清空 b_num_accepted_tokens 的处理。 + draft_model_input = copy.copy(main_model_input) + draft_model_input.b_num_accepted_tokens = None draft_model_output = main_model_output draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) # process the draft model output for draft_model_idx in range(self.mtp_step): - draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP @@ -379,44 +392,47 @@ def _draft_decode_eagle( mtp_accept_len: torch.Tensor, b_req_mtp_start_loc: torch.Tensor, ): - batch_size = main_model_input.batch_size - num_reqs = batch_size // (self.mtp_step + 1) + num_reqs = b_req_mtp_start_loc.shape[0] if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_reqs * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(num_reqs * self.mtp_step) eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) - # share some inference info with the main model - draft_model_input = main_model_input + (draft_model_input, draft_next_token_ids, accepted_req_idx,) = self._build_eagle_accepted_draft_input( + main_model_input=main_model_input, + main_model_output=main_model_output, + next_token_ids=next_token_ids, + mtp_accept_len=mtp_accept_len, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) draft_model_output = main_model_output - draft_next_token_ids = next_token_ids all_next_token_ids = [] - all_next_token_ids.append(next_token_ids) - # process the draft model output - for _step in range(self.mtp_step): + all_next_token_ids.append(draft_next_token_ids) + + mtp_size = self.mtp_step + 1 + main_mem_indexes = main_model_input.mem_indexes.view(num_reqs, mtp_size) + eagle_mem_indexes_by_req = eagle_mem_indexes.view(self.mtp_step, num_reqs).transpose(0, 1).contiguous() + mem_index_plan = torch.cat([main_mem_indexes, eagle_mem_indexes_by_req], dim=1) + accepted_offsets = mtp_accept_len.long() - 1 + req_offsets = torch.arange(num_reqs, dtype=torch.long, device=mtp_accept_len.device) + for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids - draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens + if _step > 0: + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens + draft_model_input.mem_indexes = mem_index_plan[req_offsets, accepted_offsets + _step] # spec decode: MTP draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 - eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] - draft_model_input.mem_indexes = torch.cat( - [draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], - dim=1, - ).view(-1) all_next_token_ids.append(draft_next_token_ids) all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - mtp_scatter_next_token_ids( - req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, + self._scatter_accepted_next_token_ids( + accepted_req_idx=accepted_req_idx, all_next_token_ids=all_next_token_ids, - b_req_idx=main_model_input.b_req_idx, - mtp_accept_len=mtp_accept_len, ) return eagle_mem_indexes_cpu diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index e6b9d1c18d..43ed89d691 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -1,3 +1,4 @@ +import copy import torch import time import torch.nn.functional as F @@ -263,7 +264,6 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0) if (req_num0 + req_num1) > 0: - _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=logits, b_req_idx=b_req_idx, @@ -405,7 +405,6 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] sync_event.record() if req_num > 0: - # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill) @@ -432,10 +431,22 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] return def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): - model_input, run_reqs, _ = padded_prepare_decode_inputs(decode_reqs) + model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(decode_reqs) b_mtp_index_cpu = model_input.b_mtp_index req_num = len(run_reqs) + if self.mtp_step > 0: + # 标记 verify decode 布局:每个 req 一个 accept 数量(padding 出来的 fake req 记为 1)。 + # 不设置 b_num_accepted_tokens 会让主模型的 verify forward 走非 verify 的 GDN/FA3 布局, + # 并命中 hybrid 主模型从未捕获的 cudagraph key (bs, False) -> KeyError。 + # 与 chunked_prefill/impl.py 的 decode_mtp 保持一致。 + accept_lens = [req.mtp_accept_len for req in decode_reqs] + [1] * padded_req_num + model_input.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( + key="b_num_accepted_tokens", + data=accept_lens, + dtype=torch.int32, + ) + with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None @@ -496,6 +507,11 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() verify_event.synchronize() + # 写回每个 req 的本步 accept 数量,供下一步 verify 经 b_num_accepted_tokens 传入 + # GDN/linear-att verify kernel(据此提交 conv/ssm 递归状态的正确偏移)。chunked 路径 + # 在 chunked_prefill/impl.py 同样写回;dp 缺失会让状态停留在 accept=1 -> 状态错乱、精度崩塌。 + for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): + req.mtp_accept_len = int(accept_len) verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1] update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) @@ -536,8 +552,11 @@ def _draft_decode_vanilla( req_num: int, ): all_next_token_ids = [] - # share some inference info with the main model - draft_model_input = model_input + # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, + # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型设置的 verify 布局, + # 命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + draft_model_input = copy.copy(model_input) + draft_model_input.b_num_accepted_tokens = None draft_model_output = model_output draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") if req_num > 0: @@ -547,7 +566,6 @@ def _draft_decode_vanilla( # process the draft model output for draft_model_idx in range(self.mtp_step): - draft_model_input.input_ids = draft_next_token_ids_gpu draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP @@ -567,6 +585,64 @@ def _draft_decode_vanilla( ) return None + def _build_padding_draft_input(self, model_input: ModelInput, model_output: ModelOutput, common_req_num: int): + """ + 构造一个纯 padding 的 draft 输入,用于本 rank 没有真实 decode 请求 (real_req_num == 0) + 但其它 dp rank 仍有请求、需要本 rank 同步参与 mtp_step 次 draft forward 的集合通信的场景。 + + 从已 padding 的 main model_input 中按 (mtp_step+1) 分组取每组首行 (mtp_index==0) 即可, + 这些行均为 HOLD_REQUEST_ID / HOLD_TOKEN_MEMINDEX 的占位行。step0 的 hiddens 沿用主模型 + 对应占位行的 mtp_main_output_hiddens, 与原 DP 实现 (step0 使用 model_output.mtp_main_output_hiddens) + 保持一致, 避免 None 触发 draft forward 崩溃。 + """ + mtp_size = self.mtp_step + 1 + select_idx = torch.arange(common_req_num, dtype=torch.long, device=model_input.b_req_idx.device) * mtp_size + + draft_model_input = copy.copy(model_input) + draft_model_input.batch_size = common_req_num + draft_model_input.total_token_num = common_req_num * model_input.max_kv_seq_len + draft_model_input.b_num_accepted_tokens = None + draft_model_input.b_req_idx = model_input.b_req_idx.index_select(0, select_idx) + draft_model_input.b_mtp_index = model_input.b_mtp_index.index_select(0, select_idx) + draft_model_input.b_seq_len = model_input.b_seq_len.index_select(0, select_idx) + draft_model_input.mem_indexes = model_input.mem_indexes.index_select(0, select_idx) + draft_model_input.mem_indexes_cpu = None + draft_model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens.index_select(0, select_idx) + draft_model_input.multimodal_params = [{"images": [], "audios": []} for _ in range(common_req_num)] + return draft_model_input + + def _pad_draft_input_to(self, draft_model_input: ModelInput, target_req_num: int): + """ + 将 shrink 到 real_req_num 行的 draft 输入再 padding 回 target_req_num (= common_req_num) 行, + 使本 rank 的 draft forward 行数与其它 dp rank 对齐,保证 MoE all-to-all / dp all-gather 的 + shape 一致。padding 行采用与 padded_prepare_decode_inputs 相同的占位约定: + b_req_idx -> HOLD_REQUEST_ID, mem_indexes -> HOLD_TOKEN_MEMINDEX。 + """ + cur_req_num = draft_model_input.batch_size + pad_num = target_req_num - cur_req_num + if pad_num <= 0: + return draft_model_input + + hold_req_id = g_infer_context.req_manager.HOLD_REQUEST_ID + hold_mem_idx = g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX + + draft_model_input.input_ids = F.pad(draft_model_input.input_ids, (0, pad_num), value=0) + draft_model_input.b_req_idx = F.pad(draft_model_input.b_req_idx, (0, pad_num), value=hold_req_id) + draft_model_input.b_mtp_index = F.pad(draft_model_input.b_mtp_index, (0, pad_num), value=0) + # padding 行用一个合法的小 seq_len (沿用 padded_prepare_decode_inputs 中 fake req 的约定值 2) + draft_model_input.b_seq_len = F.pad(draft_model_input.b_seq_len, (0, pad_num), value=2) + draft_model_input.mem_indexes = F.pad(draft_model_input.mem_indexes, (0, pad_num), value=hold_mem_idx) + # mtp_draft_input_hiddens 为 (rows, hidden),沿 dim0 在尾部补 0 行 + draft_model_input.mtp_draft_input_hiddens = F.pad( + draft_model_input.mtp_draft_input_hiddens, (0, 0, 0, pad_num), value=0 + ) + draft_model_input.multimodal_params = draft_model_input.multimodal_params + [ + {"images": [], "audios": []} for _ in range(pad_num) + ] + draft_model_input.batch_size = target_req_num + draft_model_input.total_token_num = target_req_num * draft_model_input.max_kv_seq_len + return draft_model_input + def _draft_decode_eagle( self, model_input: ModelInput, @@ -576,57 +652,65 @@ def _draft_decode_eagle( mtp_accept_len: torch.Tensor, req_num: int, ): - all_next_token_ids = [] - # share some inference info with the main model - draft_model_input = model_input - draft_model_output = model_output - all_next_token_ids.append(next_token_ids) - draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") - if req_num > 0: - draft_next_token_ids_gpu[:req_num].copy_(next_token_ids, non_blocking=True) + mtp_size = self.mtp_step + 1 + real_req_num = req_num // mtp_size + common_req_num = model_input.batch_size // mtp_size + padded_req_num = common_req_num - real_req_num - real_req_num = req_num // (self.mtp_step + 1) - padded_req_num = model_input.batch_size // (self.mtp_step + 1) - real_req_num - eagle_mem_indexes_cpu = None + # 即使本 rank 没有真实请求, 也要为其它 rank 同步运行 mtp_step 次 draft forward 的集合通信。 if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) - # process the draft model output - for _step in range(self.mtp_step): + if real_req_num > 0: + (draft_model_input, draft_next_token_ids, accepted_req_idx,) = self._build_eagle_accepted_draft_input( + main_model_input=model_input, + main_model_output=model_output, + next_token_ids=next_token_ids, + mtp_accept_len=mtp_accept_len, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + if padded_req_num > 0: + draft_model_input = self._pad_draft_input_to(draft_model_input, common_req_num) + draft_next_token_ids = F.pad(draft_next_token_ids, (0, padded_req_num), value=0) + + main_mem_indexes = model_input.mem_indexes.view(common_req_num, mtp_size) + eagle_padded = F.pad( + eagle_mem_indexes.view(self.mtp_step, real_req_num).transpose(0, 1).contiguous(), + (0, 0, 0, padded_req_num), + value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, + ) # (common_req_num, mtp_step) + mem_index_plan = torch.cat([main_mem_indexes, eagle_padded], dim=1) + accepted_offsets = F.pad(mtp_accept_len.long() - 1, (0, padded_req_num), value=0) + req_offsets = torch.arange(common_req_num, dtype=torch.long, device=mem_index_plan.device) + else: + # 本 rank 无真实请求: 纯 padding draft 输入, 仅用于跟随集合通信, 结果不写回。 + draft_model_input = self._build_padding_draft_input(model_input, model_output, common_req_num) + draft_next_token_ids = torch.zeros((common_req_num,), dtype=torch.int64, device="cuda") + mem_index_plan = model_input.mem_indexes.view(common_req_num, mtp_size) + accepted_offsets = torch.zeros((common_req_num,), dtype=torch.long, device=mem_index_plan.device) + req_offsets = torch.arange(common_req_num, dtype=torch.long, device=mem_index_plan.device) - draft_model_input.input_ids = draft_next_token_ids_gpu - draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens - # spec decode: MTP + draft_model_output = model_output + all_next_token_ids = [draft_next_token_ids] + for _step in range(self.mtp_step): + draft_model_input.input_ids = draft_next_token_ids + if _step > 0: + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens + draft_model_input.mem_indexes = mem_index_plan[req_offsets, accepted_offsets + _step] draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - # update the meta info of the inference + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 - eagle_mem_indexes_i = eagle_mem_indexes[_step * real_req_num : (_step + 1) * real_req_num] - eagle_mem_indexes_i = F.pad( - input=eagle_mem_indexes_i, - pad=(0, padded_req_num), - mode="constant", - value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - ) - draft_model_input.mem_indexes = torch.cat( - [draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], - dim=1, - ).view(-1) - draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output) - all_next_token_ids.append(draft_next_token_ids_gpu) + all_next_token_ids.append(draft_next_token_ids) - if req_num > 0: - all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - all_next_token_ids = all_next_token_ids[0:req_num, :] - mtp_scatter_next_token_ids( - req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, + if real_req_num > 0: + all_next_token_ids = torch.stack(all_next_token_ids, dim=1)[:real_req_num, :] + self._scatter_accepted_next_token_ids( + accepted_req_idx=accepted_req_idx[:real_req_num], all_next_token_ids=all_next_token_ids, - b_req_idx=model_input.b_req_idx[:req_num], - mtp_accept_len=mtp_accept_len, ) return eagle_mem_indexes_cpu @@ -680,7 +764,6 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I draft_model_output0, draft_model_output1 = model_output0, model_output1 for draft_model_idx in range(self.num_mtp_models): - draft_model_input0 = prepare_mtp_prefill_inputs( model_input=draft_model_input0, b_next_token_ids=draft_next_token_ids_gpu0, @@ -732,17 +815,36 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf ( model_input0, run_reqs0, - _, + padded_req_num0, model_input1, run_reqs1, - _, + padded_req_num1, ) = padded_overlap_prepare_decode_inputs(decode_reqs) req_num0, req_num1 = len(run_reqs0), len(run_reqs1) all_next_token_ids = [] b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = model_input1.b_mtp_index - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + if self.mtp_step > 0: + # 标记两个 micro-batch 的 verify decode 布局,每个 req 一个 accept 数量 + # (padding 出来的 fake req 记为 1)。run_reqs* 内每个真实 req 占 mtp_step+1 行, + # 取每组首行即可得到逐 req 的列表。不设置会让主模型 verify forward 走非 verify 布局, + # 命中 hybrid 主模型从未捕获的 cudagraph key (bs, False) -> KeyError。 + mtp_size = self.mtp_step + 1 + accept_lens0 = [r.mtp_accept_len for r in run_reqs0[::mtp_size]] + [1] * padded_req_num0 + accept_lens1 = [r.mtp_accept_len for r in run_reqs1[::mtp_size]] + [1] * padded_req_num1 + model_input0.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( + key="b_num_accepted_tokens_0", + data=accept_lens0, + dtype=torch.int32, + ) + model_input1.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( + key="b_num_accepted_tokens_1", + data=accept_lens1, + dtype=torch.int32, + ) + + with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -811,6 +913,11 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf if req_num0 + req_num1 > 0: event_pack.notify_post_handle_and_wait_pre_post_handle() verify_event.synchronize() + # 写回每个 req 的本步 accept 数量,供下一步 verify 经 b_num_accepted_tokens 传入 + # GDN/linear-att verify kernel(据此提交 conv/ssm 递归状态的正确偏移)。chunked 路径 + # 在 chunked_prefill/impl.py 同样写回;dp 缺失会让状态停留在 accept=1 -> 状态错乱、精度崩塌。 + for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): + req.mtp_accept_len = int(accept_len) verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1] update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) @@ -871,22 +978,26 @@ def _draft_decode_vanilla_overlap( ): all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - # share some inference info with the main model - draft_model_input0, draft_model_input1 = model_input0, model_input1 + # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, + # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型设置的 verify 布局, + # 命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + draft_model_input0 = copy.copy(model_input0) + draft_model_input1 = copy.copy(model_input1) + draft_model_input0.b_num_accepted_tokens = None + draft_model_input1.b_num_accepted_tokens = None draft_model_output0, draft_model_output1 = model_output0, model_output1 draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda") draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda") if req_num0 > 0: draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True) - if req_num1 > 1: + if req_num1 > 0: draft_next_token_ids_gpu1[0:req_num1].copy_( next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True ) # process the draft model output for draft_model_idx in range(self.mtp_step): - draft_model_input0.input_ids = draft_next_token_ids_gpu0 draft_model_input0.mtp_draft_input_hiddens = draft_model_output0.mtp_main_output_hiddens draft_model_input1.input_ids = draft_next_token_ids_gpu1 @@ -929,15 +1040,20 @@ def _draft_decode_eagle_overlap( ): all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - # share some inference info with the main model - draft_model_input0, draft_model_input1 = model_input0, model_input1 + # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, + # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型设置的 verify 布局, + # 命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + draft_model_input0 = copy.copy(model_input0) + draft_model_input1 = copy.copy(model_input1) + draft_model_input0.b_num_accepted_tokens = None + draft_model_input1.b_num_accepted_tokens = None draft_model_output0, draft_model_output1 = model_output0, model_output1 draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda") draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda") if req_num0 > 0: draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True) - if req_num1 > 1: + if req_num1 > 0: draft_next_token_ids_gpu1[0:req_num1].copy_( next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True ) @@ -955,7 +1071,6 @@ def _draft_decode_eagle_overlap( # process the draft model output for _step in range(self.mtp_step): - draft_model_input0.input_ids = draft_next_token_ids_gpu0 draft_model_input0.mtp_draft_input_hiddens = draft_model_output0.mtp_main_output_hiddens draft_model_input1.input_ids = draft_next_token_ids_gpu1 diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py b/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py new file mode 100644 index 0000000000..1b4ade1ac0 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py @@ -0,0 +1,33 @@ +from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel +from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel + + +def create_mtp_draft_model(model_type: str, mtp_mode: str, mtp_model_kvargs: dict): + """Single source of truth for (model_type, mtp_mode) -> MTP draft model (#10). + Shared by base_backend and the static MTP benchmark.""" + if model_type == "deepseek_v3": + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + return Deepseek3MTPModel(mtp_model_kvargs) + elif model_type == "qwen3_moe": + assert mtp_mode in ["vanilla_no_att", "eagle_no_att"] + return Qwen3MOEMTPModel(mtp_model_kvargs) + elif model_type == "mistral": + assert mtp_mode in ["vanilla_no_att", "eagle_no_att"] + return MistralMTPModel(mtp_model_kvargs) + elif model_type == "glm4_moe_lite": + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + return Glm4MoeLiteMTPModel(mtp_model_kvargs) + elif model_type in ("qwen3_5", "qwen3_5_text"): + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel + + return Qwen3_5MTPModel(mtp_model_kvargs) + elif model_type in ("qwen3_5_moe", "qwen3_5_moe_text"): + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + + return Qwen3_5MoeMTPModel(mtp_model_kvargs) + else: + raise ValueError(f"Unsupported MTP model type: {model_type}") From 76096064eba5ce34cfda5a9be908ef78d261ff9b Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Jun 2026 09:30:08 +0800 Subject: [PATCH 06/19] test(mtp): MTP unit tests + static benchmark Behavioural/CUDA coverage for the subtle MTP paths: verify-extra-state metadata, decode CUDA-graph verify layouts, fa3 fp8 verify narrowing, GDN verify equivalence, the spec causal_conv1d kernel and its prefill->decode roundtrip, and the linear-att conv/SSM widened-slot split + snapshot + CPU-cache persistence. Also extends the static-inference MTP benchmark and anchors the .gitignore benchmark-output rule to /benchmark. --- .gitignore | 4 +- .../benchmark/static_inference/model_infer.py | 2 +- .../static_inference/model_infer_mtp.py | 282 ++++++++++--- test/benchmark/static_inference/test_model.py | 21 +- test/cpu_cache_kernel/test_speed.py | 2 +- .../test_fp8_decode_verify_narrowed.py | 59 +++ .../basemodel/test_mtp_decode_cuda_graph.py | 393 ++++++++++++++++++ .../test_init_linear_att_state_zeros_block.py | 41 ++ .../common/test_linear_att_copy_guards.py | 39 ++ ...st_linear_att_mtp_cpu_cache_persistence.py | 219 ++++++++++ .../common/test_linear_att_snapshot_split.py | 41 ++ .../common/test_mtp_verify_extra_state.py | 36 ++ .../qwen3next/test_causal_conv1d_spec.py | 147 +++++++ .../test_conv_prefill_decode_roundtrip.py | 74 ++++ .../qwen3next/test_gdn_verify_equivalence.py | 194 +++++++++ 15 files changed, 1483 insertions(+), 71 deletions(-) create mode 100644 unit_tests/common/basemodel/test_fp8_decode_verify_narrowed.py create mode 100644 unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py create mode 100644 unit_tests/common/test_init_linear_att_state_zeros_block.py create mode 100644 unit_tests/common/test_linear_att_copy_guards.py create mode 100644 unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py create mode 100644 unit_tests/common/test_linear_att_snapshot_split.py create mode 100644 unit_tests/common/test_mtp_verify_extra_state.py create mode 100644 unit_tests/models/qwen3next/test_causal_conv1d_spec.py create mode 100644 unit_tests/models/qwen3next/test_conv_prefill_decode_roundtrip.py create mode 100644 unit_tests/models/qwen3next/test_gdn_verify_equivalence.py diff --git a/.gitignore b/.gitignore index 9b69e2eb4c..1156bab780 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,6 @@ dist .vscode tmp/ requirements-musa.txt -logs/ \ No newline at end of file +logs/ + +/benchmark/ diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index f2c900af09..b93c5fee55 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -36,7 +36,7 @@ def test_model_inference(args): "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, "mem_fraction": args.mem_fraction, - "max_req_num": 2048, + "max_req_num": 512, "batch_max_tokens": 1024, "run_mode": "normal", "max_seq_length": args.max_req_total_len, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 72f06a919c..ff31133ae2 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -1,4 +1,5 @@ import os +import copy import torch import numpy as np from multiprocessing import Queue @@ -9,42 +10,60 @@ from lightllm.models import get_model from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.server.core.objs.start_args_type import StartArgs -from torch.profiler import profile, record_function, ProfilerActivity +from torch.profiler import profile, ProfilerActivity from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel -import torch.cuda as cuda logger = init_logger(__name__) def init_mtp_model(args: StartArgs, kvargs, main_model): - mtp_step = args.mtp_step draft_models = [] os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - mtp_model_kvargs = kvargs - mtp_model_kvargs.update( - { - "weight_dir": args.mtp_draft_model_dir, + + if args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: + num_mtp_modules = args.mtp_step + elif args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: + num_mtp_modules = 1 + else: + assert False, f"error mtp mode {args.mtp_mode}" + + for i in range(num_mtp_modules): + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir[i]) + model_type = mtp_model_cfg.get("model_type", "") + mtp_model_kvargs = { + "weight_dir": args.mtp_draft_model_dir[i], "max_total_token_num": main_model.mem_manager.size, - "disable_chunked_prefill": True, - "mtp_mode": args.mtp_mode, + "load_way": kvargs["load_way"], + "max_req_num": kvargs.get("max_req_num", 1000), + "max_seq_length": kvargs.get("max_seq_length", 1024 * 5), + "is_token_healing": False, + "return_all_prompt_logics": False, + "disable_chunked_prefill": args.disable_chunked_prefill, + "data_type": kvargs.get("data_type", "float16"), + "graph_max_batch_size": kvargs.get("graph_max_batch_size", 16), + "graph_max_len_in_batch": kvargs.get("graph_max_len_in_batch", 8196), + "disable_cudagraph": kvargs.get("disable_cudagraph", False), + "mem_fraction": kvargs["mem_fraction"], + "batch_max_tokens": kvargs.get("batch_max_tokens", None), + "quant_type": kvargs.get("quant_type", None), + "quant_cfg": kvargs.get("quant_cfg", None), + "run_mode": "normal", + "llm_prefill_att_backend": kvargs.get("llm_prefill_att_backend", args.llm_prefill_att_backend), + "llm_decode_att_backend": kvargs.get("llm_decode_att_backend", args.llm_decode_att_backend), + "vit_att_backend": kvargs.get("vit_att_backend", args.vit_att_backend), + "llm_kv_type": kvargs.get("llm_kv_type", args.llm_kv_type), + "llm_kv_quant_group_size": kvargs.get("llm_kv_quant_group_size", args.llm_kv_quant_group_size), "main_model": main_model, + "mtp_previous_draft_models": draft_models.copy(), + "mtp_mode": args.mtp_mode, } - ) - for i in range(mtp_step): - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir) - mtp_model_kvargs.update( - { - "weight_dir": args.spec_model_dir, - "max_total_token_num": main_model.mem_manager.size, - "disable_chunked_prefill": True, - "mtp_mode": args.mtp_mode, - "main_model": main_model, - "mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], - } - ) - draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + + from lightllm.server.router.model_infer.mode_backend.mtp_model_factory import create_mtp_draft_model + + draft_models.append(create_mtp_draft_model(model_type, args.mtp_mode, mtp_model_kvargs)) + + logger.info(f"loaded mtp model class {draft_models[i].__class__}") return draft_models @@ -68,13 +87,22 @@ def test_model_inference_mtp(args): "max_total_token_num": args.max_total_token_num, "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, - "mem_faction": args.mem_fraction, - "max_req_num": 2000, + "mem_fraction": args.mem_fraction, + # Static bench runs explicit batch sizes (<= a few hundred). The hybrid Qwen3.5 + # GDN req-state cache is sized max_req_num * (mtp_step + 1) at ~34 MB/slot, so the + # old default of 2000 alloc'd ~140 GB and OOM'd under MTP. 512 covers any realistic + # static batch sweep while keeping the GDN cache small. + "max_req_num": 512, "batch_max_tokens": 2048, "run_mode": "normal", "max_seq_length": args.max_req_total_len, - "spec_algo": args.spec_algo, "disable_cudagraph": args.disable_cudagraph, + "quant_cfg": args.quant_cfg, + "llm_prefill_att_backend": args.llm_prefill_att_backend, + "llm_decode_att_backend": args.llm_decode_att_backend, + "vit_att_backend": args.vit_att_backend, + "llm_kv_type": args.llm_kv_type, + "llm_kv_quant_group_size": args.llm_kv_quant_group_size, } proc = multiprocessing.Process( target=tppart_model_infer, @@ -113,28 +141,36 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) test_data = test_data.reshape(-1) - test_data = torch.from_numpy(test_data).cuda() + test_data = torch.from_numpy(test_data) b_req_idx = torch.tensor( - [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cpu" ) - b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu") + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu") for i in range(batch_size): b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]) + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32) + b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len # Main model Prefill model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, + max_q_seq_len=input_len, + max_kv_seq_len=input_len, + max_cache_len=0, input_ids=test_data, - mem_indexes=mem_indexes, + mem_indexes_cpu=mem_indexes, b_req_idx=b_req_idx, + b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, is_prefill=True, b_ready_cache_len=b_ready_cache_len, + b_prefill_start_loc=b_prefill_start_loc, + prefix_total_token_num=0, multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], ) @@ -167,8 +203,22 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ torch.cuda.synchronize() + # Speculative width = args.mtp_step in BOTH modes (mirrors base_backend: self.mtp_step = + # args.mtp_step). The number of draft MODEL INSTANCES differs: vanilla loads mtp_step + # instances (each forwarded once), eagle loads ONE instance forwarded mtp_step times + # (chunked_prefill/impl.py: draft_models[_step % num_instances]). The verify batch always + # expands to (mtp_step + 1) rows per request. + spec_width = args.mtp_step + num_instances = len(draft_models) + # The draft prefill above produced (1 + num_instances) columns; pad/truncate to + # (spec_width + 1) so the decode verify batch matches the server's expand width. Only the + # SHAPE matters for throughput here (argmax over random inputs); token values do not. + while len(draft_ids) < spec_width + 1: + draft_ids.append(draft_ids[-1]) + draft_ids = draft_ids[: spec_width + 1] decode_input_ids = np.stack(draft_ids, axis=-1).reshape(-1) - decode_input_ids = torch.from_numpy(decode_input_ids).cuda() + decode_input_ids = torch.from_numpy(decode_input_ids) + mtp_step = spec_width # build main decode input: nopad_b_seq_idx = [] @@ -177,67 +227,167 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_max_len_in_batch = 0 for i in range(batch_size): - nopad_b_seq_idx.append(b_req_idx[i]) + nopad_b_seq_idx.append(b_req_idx[i].item()) seq_len = b_seq_len[i].item() nopad_b_seq_len.append(seq_len + 1) nopad_total_token_num += seq_len + 1 - nopad_max_len_in_batch = max(nopad_max_len_in_batch, b_seq_len[i] + 1) + nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + 1) - for step in range(len(draft_models)): - nopad_b_seq_idx.append(b_req_idx[i]) + for step in range(mtp_step): + nopad_b_seq_idx.append(b_req_idx[i].item()) nopad_b_seq_len.append(seq_len + step + 2) nopad_total_token_num += seq_len + step + 2 nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + step + 2) - nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() + nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cpu") + nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cpu") + b_mtp_index = torch.arange(mtp_step + 1, dtype=torch.int32).repeat(batch_size) + mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (mtp_step + 1)) model_input = ModelInput( - batch_size=batch_size * (len(draft_models) + 1), + batch_size=batch_size * (mtp_step + 1), total_token_num=nopad_total_token_num, + max_q_seq_len=1, + max_kv_seq_len=nopad_max_len_in_batch, input_ids=decode_input_ids, - mem_indexes=mem_indexes, + mem_indexes_cpu=mem_indexes, b_req_idx=nopad_b_seq_idx, + b_mtp_index=b_mtp_index, b_seq_len=nopad_b_seq_len, is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (len(draft_models) + 1))], + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (mtp_step + 1))], ) + # MTP verify layout. The main decode is a VERIFY forward over the (mtp_step+1)-expanded + # batch. Setting b_num_accepted_tokens (one entry per real request) flips is_mtp_verify=True + # so the hybrid GDN main model runs the fused spec-decode verify kernel — the production path. + # Without it the main decode silently takes the plain _gdn_decode_kernel on the S+1-expanded + # batch (whose rows share req_idx), colliding on the single widened conv slot and mismeasuring + # cost. accept_len is fixed at 1 (steady-state low-acceptance); the verify-forward COST is + # ~constant in accept_len (it always processes mtp_step+1 rows), so this faithfully measures + # per-step decode cost. Vary accept_len in [1, mtp_step+1] to sweep the acceptance regime. + accept_len = 1 + is_eagle = args.mtp_mode.startswith("eagle") + model_input.b_num_accepted_tokens = torch.full((batch_size,), accept_len, dtype=torch.int32, device="cuda") + req_offsets = torch.arange(batch_size, dtype=torch.long, device="cuda") + accepted_row_idx = req_offsets * (mtp_step + 1) + (accept_len - 1) + if is_eagle: + # EAGLE draft scratch slots (n_real * mtp_step), mirroring _draft_decode_eagle. Allocated + # once and reused across steps (throughput bench overwrites draft KV; no correctness check). + eagle_mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * mtp_step).cuda() + + # Prize-sizing profiler (need-to-fix #22): env-gated, eagle-only, additive. Times the verify + # forward vs the S-step draft chain to decide whether collapsing the chain into a CUDA graph is + # worth it. host_bound_ratio ~1 (or per_step flat across bs) => host/launch-bound => graph wins. + _mtp_profile = os.environ.get("MTP_PROFILE", "0") == "1" + _prof = {"verify_ms": 0.0, "draft_ms": 0.0, "draft_host_ms": 0.0, "n": 0, "per_step_ms": [0.0] * mtp_step} + # Main decode - for i in range(0, output_len, len(draft_models) + 1): + for i in range(0, output_len, mtp_step + 1): torch.cuda.synchronize() step_start_time = time.time() - model_output = main_model.forward( - model_input, - ) - prob_out = torch.softmax(model_output.logits, dim=-1) - predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - # draft decode - model_input.input_ids = predict_ids.reshape(-1) - model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens - - for draft_model_id in range(len(draft_models)): - draft_model = draft_models[draft_model_id] - model_output = draft_model.forward( - model_input, + # --- main VERIFY forward: mtp_step+1 rows/req through the fused GDN verify kernel --- + if _mtp_profile and not warmup: + _ev_v0 = torch.cuda.Event(enable_timing=True) + _ev_v1 = torch.cuda.Event(enable_timing=True) + _ev_v0.record() + model_output = main_model.forward(model_input) + if _mtp_profile and not warmup: + _ev_v1.record() + predict_ids = torch.argmax(model_output.logits, dim=1, keepdim=True) + + if is_eagle: + # EAGLE draft: shrink to the single accepted row per request (1 row/req), then run the + # draft model mtp_step times. The Qwen3.5 MTP draft is full-attention and takes the + # plain decode layout (b_num_accepted_tokens=None). Mirrors chunked_prefill + # _build_eagle_accepted_draft_input + _draft_decode_eagle so the measured draft cost is + # the real n_real-row cost, not the (mtp_step+1)x-inflated full-batch cost. + main_mem = model_input.mem_indexes.view(batch_size, mtp_step + 1) + eagle_mem_by_req = eagle_mem_indexes.view(mtp_step, batch_size).transpose(0, 1).contiguous() + mem_index_plan = torch.cat([main_mem, eagle_mem_by_req], dim=1) + + draft_model_input = copy.copy(model_input) + draft_model_input.batch_size = batch_size + draft_model_input.total_token_num = batch_size * model_input.max_kv_seq_len + draft_model_input.b_num_accepted_tokens = None + draft_model_input.mem_indexes_cpu = None + draft_model_input.b_req_idx = model_input.b_req_idx.index_select(0, accepted_row_idx) + draft_model_input.b_seq_len = model_input.b_seq_len.index_select(0, accepted_row_idx) + draft_model_input.b_mtp_index = model_input.b_mtp_index.index_select(0, accepted_row_idx) + draft_model_input.input_ids = predict_ids.reshape(-1).index_select(0, accepted_row_idx) + draft_model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens.index_select( + 0, accepted_row_idx ) - prob_out = torch.softmax(model_output.logits, dim=-1) - predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - model_input.input_ids = predict_ids.reshape(-1) - model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens + draft_model_input.multimodal_params = [{"images": [], "audios": []} for _ in range(batch_size)] + + if _mtp_profile and not warmup: + _step_evs = [] + _ev_d0 = torch.cuda.Event(enable_timing=True) + _ev_d1 = torch.cuda.Event(enable_timing=True) + _host_t0 = time.time() + _ev_d0.record() + for _step in range(mtp_step): + draft_model_input.mem_indexes = mem_index_plan[req_offsets, (accept_len - 1) + _step] + draft_model = draft_models[_step % num_instances] + if _mtp_profile and not warmup: + _es = torch.cuda.Event(enable_timing=True) + _ee = torch.cuda.Event(enable_timing=True) + _es.record() + draft_output = draft_model.forward(draft_model_input) + if _mtp_profile and not warmup: + _ee.record() + _step_evs.append((_es, _ee)) + draft_model_input.input_ids = torch.argmax(draft_output.logits, dim=1, keepdim=True).reshape(-1) + draft_model_input.mtp_draft_input_hiddens = draft_output.mtp_main_output_hiddens + draft_model_input.b_seq_len = draft_model_input.b_seq_len + 1 + draft_model_input.max_kv_seq_len += 1 + if _mtp_profile and not warmup: + _ev_d1.record() + _host_t1 = time.time() + else: + # VANILLA draft: full (mtp_step+1)-expanded batch, plain decode layout. Mirrors + # chunked_prefill _draft_decode_vanilla (b_num_accepted_tokens cleared on a copy so the + # MTP draft model does not inherit the main model's verify layout / cudagraph key). + draft_model_input = copy.copy(model_input) + draft_model_input.b_num_accepted_tokens = None + draft_model_input.input_ids = predict_ids.reshape(-1) + draft_model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens + for _step in range(mtp_step): + draft_model = draft_models[_step % num_instances] + draft_output = draft_model.forward(draft_model_input) + draft_model_input.input_ids = torch.argmax(draft_output.logits, dim=1, keepdim=True).reshape(-1) + draft_model_input.mtp_draft_input_hiddens = draft_output.mtp_main_output_hiddens - # accept all draft ids by default. - model_input.input_ids = predict_ids.reshape(-1) - model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens torch.cuda.synchronize() + if _mtp_profile and not warmup and is_eagle and i >= 3 * (mtp_step + 1): + # skip first 3 macro-steps (lazy cudagraph capture / cache warmup) + _prof["verify_ms"] += _ev_v0.elapsed_time(_ev_v1) + _prof["draft_ms"] += _ev_d0.elapsed_time(_ev_d1) + _prof["draft_host_ms"] += (_host_t1 - _host_t0) * 1000.0 + for _k, (_es, _ee) in enumerate(_step_evs): + _prof["per_step_ms"][_k] += _es.elapsed_time(_ee) + _prof["n"] += 1 if i % 100 == 0 or i == output_len - 1: step_end_time = time.time() if get_current_rank_in_dp() == 0 and not warmup: step_time = step_end_time - step_start_time print(i, " step cost time:", step_time * 1000) - print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s") + # Peak (all-accepted) throughput: mtp_step+1 candidate tokens per req per step. + print(f"Decode throughput: {batch_size * (mtp_step + 1) * args.dp / step_time} tokens/s") + + if _mtp_profile and is_eagle and _prof["n"] > 0 and get_current_rank_in_dp() == 0 and not warmup: + n = _prof["n"] + ps = ", ".join(f"{v / n:.3f}" for v in _prof["per_step_ms"]) + print(f"[MTP_PROFILE] bs={batch_size} S={mtp_step} steps={n}") + print(f"[MTP_PROFILE] verify_gpu_ms = {_prof['verify_ms'] / n:.3f}") + print(f"[MTP_PROFILE] draft_chain_gpu_ms = {_prof['draft_ms'] / n:.3f}") + print(f"[MTP_PROFILE] draft_chain_host_ms = {_prof['draft_host_ms'] / n:.3f} (host-enqueue, no sync)") + print(f"[MTP_PROFILE] per_draft_step_gpu_ms = [{ps}]") + print( + f"[MTP_PROFILE] host_bound_ratio = " + f"{_prof['draft_host_ms'] / max(_prof['draft_ms'], 1e-9):.3f} (~1 => host-bound => graph wins)" + ) main_model.mem_manager.free_all() main_model.req_manager.free_all() diff --git a/test/benchmark/static_inference/test_model.py b/test/benchmark/static_inference/test_model.py index 5b3751bcc3..7992c03743 100644 --- a/test/benchmark/static_inference/test_model.py +++ b/test/benchmark/static_inference/test_model.py @@ -11,12 +11,29 @@ from lightllm.utils.config_utils import get_config_json, get_dtype +def parse_batch_size(value): + parts = [part.strip() for part in value.split(",") if part.strip()] + if not parts: + raise ValueError("batch_size must contain at least one integer") + + batch_sizes = [] + for part in parts: + size = int(part) + if size <= 0: + raise ValueError("batch_size values must be positive integers") + batch_sizes.append(size) + + if len(batch_sizes) == 1: + return batch_sizes[0] + return batch_sizes + + class TestModelInfer(unittest.TestCase): def test_model_infer(self): args = get_env_start_args() if args.data_type is None: args.data_type = get_dtype(args.model_dir) - if args.mtp_mode == "deepseekv3": + if args.mtp_mode is not None: test_model_inference_mtp(args) else: test_model_inference(args) @@ -27,7 +44,7 @@ def test_model_infer(self): import torch parser = make_argument_parser() - parser.add_argument("--batch_size", type=int, default=None, help="batch size") + parser.add_argument("--batch_size", type=parse_batch_size, default=None, help="batch size, e.g. 8 or 1,2,4,8") parser.add_argument("--input_len", type=int, default=64, help="input sequence length") parser.add_argument("--output_len", type=int, default=128, help="output sequence length") parser.add_argument( diff --git a/test/cpu_cache_kernel/test_speed.py b/test/cpu_cache_kernel/test_speed.py index 254142050c..b1c761f4c2 100644 --- a/test/cpu_cache_kernel/test_speed.py +++ b/test/cpu_cache_kernel/test_speed.py @@ -104,7 +104,7 @@ buffer_count = triton.cdiv(SEQ_LEN, big_page_token_num) + 2 # matches Qwen3NextMemManager -conv_shape = linear_config.get_conv_state_shape() +conv_shape = linear_config.get_persisted_conv_state_shape() cpu_kv_conv_state = torch.empty( (buffer_count, linear_config.linear_layer_num, *conv_shape), dtype=linear_config.conv_state_dtype, diff --git a/unit_tests/common/basemodel/test_fp8_decode_verify_narrowed.py b/unit_tests/common/basemodel/test_fp8_decode_verify_narrowed.py new file mode 100644 index 0000000000..a671550d31 --- /dev/null +++ b/unit_tests/common/basemodel/test_fp8_decode_verify_narrowed.py @@ -0,0 +1,59 @@ +import types +import torch +import pytest + +import lightllm.common.basemodel.attention.fa3.fp8 as fp8_mod +from lightllm.common.basemodel.attention.fa3.fp8 import Fp8Fa3DecodeAttState + + +def _make_verify_state(n_real, mtp_size, head_num=2, head_dim=8): + """Build an Fp8Fa3DecodeAttState as init_state would leave it in MTP-verify mode, + bypassing init_state. b_att_seq_len/page_table are NARROW (n_real); infer_state.b_seq_len + is the FULL expanded tensor (n_real*mtp_size) that must NOT be used as cache_seqlens.""" + state = object.__new__(Fp8Fa3DecodeAttState) + batch = n_real * mtp_size + state.b_att_seq_len = torch.full((n_real,), 16, dtype=torch.int32) + state.page_table = torch.zeros((n_real, 16), dtype=torch.int32) + state.cu_seqlens_q = torch.arange(0, (n_real + 1) * mtp_size, mtp_size, dtype=torch.int32) + state.cu_seqlens_k = torch.zeros((n_real + 1,), dtype=torch.int32) + state.decode_max_q_seq_len = mtp_size + state.infer_state = types.SimpleNamespace( + b_seq_len=torch.full((batch,), 16, dtype=torch.int32), + batch_size=batch, + ) + # k/v descale sized per real request (att_batch_size), indexed by layer + state.k_descale = torch.ones((1, n_real, head_num)) + state.v_descale = torch.ones((1, n_real, head_num)) + state.backend = types.SimpleNamespace(_find_layer_index=lambda k, v, att_state: 0) + return state, batch + + +def test_fp8_decode_uses_narrowed_cache_seqlens_and_causal(monkeypatch): + n_real, mtp_size, head_num, head_dim = 3, 4, 2, 8 + state, batch = _make_verify_state(n_real, mtp_size, head_num, head_dim) + + captured = {} + + def fake_flash(**kwargs): + captured.update(kwargs) + q = kwargs["q"] + return torch.zeros((q.shape[0], q.shape[1], q.shape[2])) + + def fake_quant(x, use_per_token_if_dynamic=True): + return x, torch.ones((x.shape[0], 1)) + + monkeypatch.setattr(fp8_mod, "flash_attn_with_kvcache", fake_flash) + monkeypatch.setattr(fp8_mod, "scaled_fp8_quant", fake_quant) + + q = torch.randn((batch, head_num, head_dim)) + k = torch.randn((batch, head_num, head_dim)) + v = torch.randn((batch, head_num, head_dim)) + + state._fp8_decode_att(q=q, k=k, v=v) + + # The KV-side seqlens must be the NARROW per-real-request tensor, matching page_table rows. + assert captured["cache_seqlens"] is state.b_att_seq_len + assert captured["cache_seqlens"].shape[0] == n_real + assert captured["cache_seqlens"].shape[0] == captured["page_table"].shape[0] + # Verify decode must be causal, like the non-fp8 sibling. + assert captured["causal"] is True diff --git a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py new file mode 100644 index 0000000000..dd8fec1640 --- /dev/null +++ b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py @@ -0,0 +1,393 @@ +from types import SimpleNamespace + +import torch + +from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.batch_objs import ModelInput + + +def test_mtp_decode_cuda_graph_warmup_uses_verify_layout(): + from lightllm.common.basemodel.cuda_graph import CudaGraph + + graph = CudaGraph.__new__(CudaGraph) + graph.mtp_step = 2 + graph.graph_max_len_in_batch = 128 + + class FakeMemManager: + HOLD_TOKEN_MEMINDEX = -1 + + def alloc(self, size): + return torch.arange(size, dtype=torch.int32) + + model = SimpleNamespace( + req_manager=SimpleNamespace(HOLD_REQUEST_ID=99), + mem_manager=FakeMemManager(), + _gen_special_model_input=lambda token_num: {"mtp_draft_input_hiddens": None}, + ) + + model_input = graph._build_warmup_decode_model_input(model, batch_size=6, device="cpu") + + assert model_input.batch_size == 6 + assert model_input.b_mtp_index.tolist() == [0, 1, 2, 0, 1, 2] + assert model_input.b_seq_len.tolist() == [2, 3, 4, 2, 3, 4] + assert model_input.b_num_accepted_tokens.tolist() == [1, 1] + assert model_input.total_token_num == 18 + + +def test_mtp_decode_cuda_graph_warmup_supports_normal_layout_for_draft(): + from lightllm.common.basemodel.cuda_graph import CudaGraph + + graph = CudaGraph.__new__(CudaGraph) + graph.mtp_step = 2 + graph.graph_max_len_in_batch = 128 + + class FakeMemManager: + HOLD_TOKEN_MEMINDEX = -1 + + def alloc(self, size): + return torch.arange(size, dtype=torch.int32) + + model = SimpleNamespace( + req_manager=SimpleNamespace(HOLD_REQUEST_ID=99), + mem_manager=FakeMemManager(), + _gen_special_model_input=lambda token_num: {"mtp_draft_input_hiddens": torch.full((token_num, 4), 3.0)}, + ) + + model_input = graph._build_warmup_decode_model_input( + model, + batch_size=5, + device="cpu", + is_mtp_verify_decode=False, + ) + + assert model_input.batch_size == 5 + assert model_input.b_mtp_index.tolist() == [0, 0, 0, 0, 0] + assert model_input.b_seq_len.tolist() == [2, 2, 2, 2, 2] + assert model_input.b_num_accepted_tokens is None + assert model_input.total_token_num == 10 + assert model_input.mtp_draft_input_hiddens.shape == (5, 4) + + +def test_mtp_decode_cuda_graph_keys_verify_and_normal_layouts(): + from lightllm.common.basemodel.cuda_graph import CudaGraph + + graph = CudaGraph.__new__(CudaGraph) + graph.mtp_step = 2 + graph.graph = {} + graph.normal_cuda_graph_batch_sizes = [1, 2, 4, 8] + graph.mtp_verify_cuda_graph_batch_sizes = [3, 6, 9, 12] + + verify_state = SimpleNamespace( + input_ids=torch.ones(6, dtype=torch.int64), + b_num_accepted_tokens=torch.ones(2, dtype=torch.int32), + ) + normal_state = SimpleNamespace( + input_ids=torch.ones(6, dtype=torch.int64), + b_num_accepted_tokens=None, + ) + + assert graph._decode_graph_key(verify_state) == (6, True) + assert graph._decode_graph_key(normal_state) == (6, False) + assert graph.find_closest_graph_batch_size(5, is_mtp_verify_decode=True) == 6 + assert graph.find_closest_graph_batch_size(5, is_mtp_verify_decode=False) == 8 + + graph.graph[(6, True)] = "verify graph" + assert graph.need_capture(6, is_mtp_verify_decode=True) is False + assert graph.need_capture(5, is_mtp_verify_decode=False) is True + + +def test_mtp_decode_cuda_graph_warmup_layouts_split_main_and_draft_models(): + from lightllm.common.basemodel.cuda_graph import CudaGraph + + class Qwen3_5MOETpPartModel: + pass + + class Qwen3_5MoeMTPModel: + pass + + graph = CudaGraph.__new__(CudaGraph) + graph.mtp_step = 2 + graph.normal_cuda_graph_batch_sizes = [1, 2, 4, 8] + graph.mtp_verify_cuda_graph_batch_sizes = [3, 6, 9] + + assert list(graph._iter_warmup_graph_layouts(Qwen3_5MOETpPartModel())) == [(True, [3, 6, 9])] + assert list(graph._iter_warmup_graph_layouts(Qwen3_5MoeMTPModel())) == [(False, [1, 2, 4, 8])] + + +def test_mtp_decode_warmup_layout_marks_qwen3next_verify(monkeypatch): + import pytest + + if not torch.cuda.is_available(): + pytest.skip("needs CUDA for gen_decode_params") + + import lightllm.common.basemodel.mtp_verify_extra_state as mtp_verify_extra_state_mod + from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo + + monkeypatch.setattr(mtp_verify_extra_state_mod, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) + + state = Qwen3NextInferStateInfo() + state.is_prefill = False + state.b_req_idx = torch.tensor([5, 5, 5, 6, 6, 6], dtype=torch.int32, device="cuda") + state.b_mtp_index = torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32, device="cuda") + state.b_seq_len = torch.tensor([2, 3, 4, 2, 3, 4], dtype=torch.int32, device="cuda") + state.b_num_accepted_tokens = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + + model = SimpleNamespace( + _cos_cached=torch.zeros((16, 4), dtype=torch.float32, device="cuda"), + _sin_cached=torch.zeros((16, 4), dtype=torch.float32, device="cuda"), + ) + + state.init_some_extra_state(model) + + assert state.is_mtp_verify is True + assert state.b_gdn_verify_cu_seqlens.tolist() == [0, 3, 6] + assert state.b_conv_buffer_idx.tolist() == [5, 6] + assert state.b_ssm_index_rows.tolist() == [[15, 16, 17], [18, 19, 20]] + + +def test_mtp_decode_padding_preserves_verify_groups(monkeypatch): + import lightllm.common.basemodel.basemodel as basemodel_mod + + monkeypatch.setattr(basemodel_mod, "enable_diverse_mode_gqa_decode_fast_kernel", lambda: False) + + model = TpPartBaseModel.__new__(TpPartBaseModel) + model.args = SimpleNamespace(mtp_step=2) + model.req_manager = SimpleNamespace(HOLD_REQUEST_ID=99) + model.mem_manager = SimpleNamespace(HOLD_TOKEN_MEMINDEX=-1) + + model_input = ModelInput( + batch_size=3, + total_token_num=12, + max_q_seq_len=1, + max_kv_seq_len=4, + input_ids=torch.tensor([10, 11, 12], dtype=torch.int32), + mem_indexes=torch.tensor([20, 21, 22], dtype=torch.int32), + b_req_idx=torch.tensor([7, 7, 7], dtype=torch.int32), + b_mtp_index=torch.tensor([0, 1, 2], dtype=torch.int32), + b_seq_len=torch.tensor([2, 3, 4], dtype=torch.int32), + b_num_accepted_tokens=torch.tensor([2], dtype=torch.int32), + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(3)], + ) + + padded = model._create_padded_decode_model_input(model_input, new_batch_size=6) + + assert padded.batch_size == 6 + assert padded.b_req_idx.tolist() == [7, 7, 7, 99, 99, 99] + assert padded.b_mtp_index.tolist() == [0, 1, 2, 0, 1, 2] + assert padded.b_seq_len.tolist() == [2, 3, 4, 2, 3, 4] + assert padded.b_num_accepted_tokens.tolist() == [2, 1] + assert padded.mem_indexes.tolist() == [20, 21, 22, -1, -1, -1] + assert len(padded.multimodal_params) == 6 + + +def test_qwen3next_hybrid_mtp_keeps_decode_cuda_graph_enabled(monkeypatch): + import lightllm.models.qwen3next.model as qwen3next_model + from lightllm.models.qwen3next.model import Qwen3NextTpPartModel + + monkeypatch.setattr(qwen3next_model, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) + + called = {} + + def fake_base_init_cudagraph(self): + called["disable_cudagraph"] = self.disable_cudagraph + self.graph = "captured" + + monkeypatch.setattr(TpPartBaseModel, "_init_cudagraph", fake_base_init_cudagraph) + + model = Qwen3NextTpPartModel.__new__(Qwen3NextTpPartModel) + model.disable_cudagraph = False + + Qwen3NextTpPartModel._init_cudagraph(model) + + assert called["disable_cudagraph"] is False + assert model.disable_cudagraph is False + assert model.graph == "captured" + + +def test_fa3_decode_uses_normal_layout_for_narrowed_mtp_draft(monkeypatch): + import lightllm.common.basemodel.attention.fa3.fp as fa3_fp + from lightllm.common.basemodel.attention.fa3.fp import Fa3DecodeAttState + + monkeypatch.setattr(fa3_fp, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) + + copied = {} + + def fake_page_table_copy(page_table, req_to_token_indexs, b_req_idx): + copied["page_table_shape"] = tuple(page_table.shape) + copied["b_req_idx"] = b_req_idx.clone() + + monkeypatch.setattr(fa3_fp, "page_table_copy", fake_page_table_copy) + + model = SimpleNamespace( + graph_max_batch_size=16, + graph_max_len_in_batch=32, + req_manager=SimpleNamespace(req_to_token_indexs=torch.empty((8, 32), dtype=torch.int32)), + ) + backend = SimpleNamespace( + model=model, + get_page_table_buffer=lambda: [torch.empty(16 * 32, dtype=torch.int32)], + ) + infer_state = SimpleNamespace( + batch_size=2, + max_kv_seq_len=16, + input_ids=torch.ones(2, dtype=torch.int64), + b_seq_len=torch.tensor([5, 7], dtype=torch.int32), + b1_cu_q_seq_len=torch.tensor([0, 1, 2], dtype=torch.int32), + b1_cu_kv_seq_len=torch.tensor([0, 5, 12], dtype=torch.int32), + b_req_idx=torch.tensor([3, 4], dtype=torch.int32), + b_num_accepted_tokens=None, + microbatch_index=0, + ) + + state = Fa3DecodeAttState(backend=backend, infer_state=infer_state) + state.init_state() + + assert state.decode_max_q_seq_len == 1 + assert state.b_att_seq_len.tolist() == [5, 7] + assert copied["page_table_shape"] == (2, 16) + assert copied["b_req_idx"].tolist() == [3, 4] + + +def test_build_eagle_accepted_draft_input_narrows_to_accepted_rows(): + from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput + from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ( + ChunkedPrefillBackend, + ) + + backend = ChunkedPrefillBackend.__new__(ChunkedPrefillBackend) + backend.mtp_step = 2 + + main_input = ModelInput( + batch_size=6, + total_token_num=27, + max_q_seq_len=1, + max_kv_seq_len=9, + input_ids=torch.tensor([10, 11, 12, 20, 21, 22], dtype=torch.int64), + mem_indexes=torch.tensor([100, 101, 102, 200, 201, 202], dtype=torch.int32), + b_req_idx=torch.tensor([3, 3, 3, 4, 4, 4], dtype=torch.int32), + b_mtp_index=torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32), + b_seq_len=torch.tensor([5, 6, 7, 6, 7, 8], dtype=torch.int32), + b_num_accepted_tokens=torch.tensor([1, 1], dtype=torch.int32), + is_prefill=False, + multimodal_params=[ + {"row": 0}, + {"row": 1}, + {"row": 2}, + {"row": 3}, + {"row": 4}, + {"row": 5}, + ], + ) + hidden = torch.arange(6 * 4, dtype=torch.float32).view(6, 4) + main_output = ModelOutput(logits=torch.empty(6, 8), mtp_main_output_hiddens=hidden) + next_token_ids = torch.tensor([110, 111, 112, 220, 221, 222], dtype=torch.int64) + b_req_mtp_start_loc = torch.tensor([0, 3], dtype=torch.int32) + mtp_accept_len = torch.tensor([2, 3], dtype=torch.int32) + + (draft_input, accepted_next_tokens, accepted_req_idx,) = backend._build_eagle_accepted_draft_input( + main_model_input=main_input, + main_model_output=main_output, + next_token_ids=next_token_ids, + mtp_accept_len=mtp_accept_len, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + + assert draft_input.batch_size == 2 + assert draft_input.input_ids.tolist() == [111, 222] + assert draft_input.b_req_idx.tolist() == [3, 4] + assert draft_input.b_mtp_index.tolist() == [1, 2] + assert draft_input.b_seq_len.tolist() == [6, 8] + assert draft_input.mem_indexes.tolist() == [101, 202] + assert draft_input.b_num_accepted_tokens is None + assert draft_input.multimodal_params == [{"row": 1}, {"row": 5}] + assert accepted_next_tokens.tolist() == [111, 222] + assert accepted_req_idx.tolist() == [3, 4] + torch.testing.assert_close(draft_input.mtp_draft_input_hiddens, hidden[[1, 5]]) + + +def test_eagle_draft_decode_uses_narrowed_hidden_on_first_step(monkeypatch): + import lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl as chunked_impl + from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput + from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ( + ChunkedPrefillBackend, + ) + + class FakeMemManager: + HOLD_TOKEN_MEMINDEX = -1 + + def alloc(self, need_size): + return torch.arange(300, 300 + need_size, dtype=torch.int32) + + req_to_next_token_ids = torch.empty((8, 3), dtype=torch.int64) + monkeypatch.setattr( + chunked_impl, + "g_infer_context", + SimpleNamespace( + radix_cache=None, + req_manager=SimpleNamespace( + mem_manager=FakeMemManager(), + req_sampling_params_manager=SimpleNamespace(req_to_next_token_ids=req_to_next_token_ids), + ), + ), + ) + monkeypatch.setattr(torch.Tensor, "cuda", lambda self, non_blocking=False: self) + + backend = ChunkedPrefillBackend.__new__(ChunkedPrefillBackend) + backend.mtp_step = 2 + backend.num_mtp_models = 1 + + seen_hiddens = [] + + class FakeDraftModel: + def forward(self, model_input): + seen_hiddens.append(model_input.mtp_draft_input_hiddens.clone()) + logits = torch.zeros((model_input.batch_size, 8), dtype=torch.float32) + return ModelOutput( + logits=logits, + mtp_main_output_hiddens=model_input.mtp_draft_input_hiddens + 100, + ) + + backend.draft_models = [FakeDraftModel()] + + scattered = {} + + def fake_scatter(accepted_req_idx, all_next_token_ids): + scattered["accepted_req_idx"] = accepted_req_idx.clone() + scattered["all_next_token_ids"] = all_next_token_ids.clone() + + backend._scatter_accepted_next_token_ids = fake_scatter + + main_input = ModelInput( + batch_size=6, + total_token_num=27, + max_q_seq_len=1, + max_kv_seq_len=9, + input_ids=torch.tensor([10, 11, 12, 20, 21, 22], dtype=torch.int64), + mem_indexes=torch.tensor([100, 101, 102, 200, 201, 202], dtype=torch.int32), + b_req_idx=torch.tensor([3, 3, 3, 4, 4, 4], dtype=torch.int32), + b_mtp_index=torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32), + b_seq_len=torch.tensor([5, 6, 7, 6, 7, 8], dtype=torch.int32), + b_num_accepted_tokens=torch.tensor([1, 1], dtype=torch.int32), + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(6)], + ) + hidden = torch.arange(6 * 4, dtype=torch.float32).view(6, 4) + main_output = ModelOutput(logits=torch.empty(6, 8), mtp_main_output_hiddens=hidden) + next_token_ids = torch.tensor([110, 111, 112, 220, 221, 222], dtype=torch.int64) + b_req_mtp_start_loc = torch.tensor([0, 3], dtype=torch.int32) + mtp_accept_len = torch.tensor([2, 3], dtype=torch.int32) + + returned_mem = backend._draft_decode_eagle( + main_model_input=main_input, + main_model_output=main_output, + next_token_ids=next_token_ids, + mtp_accept_len=mtp_accept_len, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + + assert returned_mem.tolist() == [300, 301, 302, 303] + torch.testing.assert_close(seen_hiddens[0], hidden[[1, 5]]) + torch.testing.assert_close(seen_hiddens[1], hidden[[1, 5]] + 100) + assert scattered["accepted_req_idx"].tolist() == [3, 4] + assert scattered["all_next_token_ids"].shape == (2, 3) diff --git a/unit_tests/common/test_init_linear_att_state_zeros_block.py b/unit_tests/common/test_init_linear_att_state_zeros_block.py new file mode 100644 index 0000000000..b20e489bce --- /dev/null +++ b/unit_tests/common/test_init_linear_att_state_zeros_block.py @@ -0,0 +1,41 @@ +import types +import torch + +# NOTE: importing lightllm.common.req_manager *first* trips a pre-existing circular import +# (req_manager line-8 imports gen_sampling_params -> basemodel -> infer_struct, which re-enters +# the half-initialized req_manager before ReqManager is defined). Importing basemodel first +# fully resolves that chain, after which ReqManagerForMamba imports cleanly. This is an +# import-ordering fix only; it does not alter the method-under-test or the duck-typed call below. +import lightllm.common.basemodel # noqa: F401 (resolves circular import; must precede req_manager) +from lightllm.common.req_manager import ReqManagerForMamba + + +class _Buf: + def __init__(self, t): + self.buffer = t + + +def test_init_zeros_full_ssm_block(): + mtp_step = 3 + layer, n_req = 2, 4 + conv_dim, width = 8, 3 + conv_buf = torch.ones(layer, n_req, conv_dim, width) + ssm_buf = torch.ones(layer, n_req * (mtp_step + 1), 5) + + dummy = types.SimpleNamespace( + mtp_step=mtp_step, + req_to_conv_state=_Buf(conv_buf), + req_to_ssm_state=_Buf(ssm_buf), + ) + req = types.SimpleNamespace(req_idx=2, mtp_accept_len=None) + + ReqManagerForMamba.init_linear_att_state(dummy, req) + + start = 2 * (mtp_step + 1) + block = ssm_buf[:, start : start + (mtp_step + 1), ...] + assert torch.count_nonzero(block) == 0, "all S+1 SSM rows of the block must be zeroed on init" + # other requests' rows must be untouched + assert torch.count_nonzero(ssm_buf[:, :start, ...]) > 0 + # conv slot for this request zeroed; canonical accept-len reset + assert torch.count_nonzero(conv_buf[:, 2, ...]) == 0 + assert req.mtp_accept_len == 1 diff --git a/unit_tests/common/test_linear_att_copy_guards.py b/unit_tests/common/test_linear_att_copy_guards.py new file mode 100644 index 0000000000..b6c48a0c86 --- /dev/null +++ b/unit_tests/common/test_linear_att_copy_guards.py @@ -0,0 +1,39 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.linear_att_copy import ( + copy_linear_att_state_to_kv_buffer, +) + + +def _args(gpu_conv, accept_len, mtp_step): + layer_num = gpu_conv.shape[0] + dim_conv = gpu_conv.shape[2] + width_narrow = 3 + return dict( + b_req_idx=torch.tensor([0], dtype=torch.int32), + big_page_buffer_ids=torch.tensor([0], dtype=torch.int32), + gpu_conv_state=gpu_conv, + gpu_ssm_state=torch.zeros(layer_num, 1 * (mtp_step + 1), 8), + cpu_kv_conv_state=torch.zeros(1, layer_num, dim_conv, width_narrow), + cpu_kv_ssm_state=torch.zeros(1, layer_num, 8), + mtp_step=mtp_step, + b_num_accepted_tokens=torch.tensor([accept_len], dtype=torch.int32), + ) + + +def test_rejects_non_contiguous_width_axis(): + mtp_step = 2 + # widened slot allocated 2x, then strided ::2 along the width axis -> stride(3) == 2 + base = torch.zeros(2, 1, 32, (3 + mtp_step) * 2) + gpu_conv = base[:, :, :, ::2] + assert gpu_conv.stride(3) != 1 + with pytest.raises(AssertionError, match="width"): + copy_linear_att_state_to_kv_buffer(**_args(gpu_conv, accept_len=1, mtp_step=mtp_step)) + + +def test_rejects_out_of_range_accept_len(): + mtp_step = 2 + gpu_conv = torch.zeros(2, 1, 32, 3 + mtp_step) # contiguous, passes the #6 guard + with pytest.raises(AssertionError, match="b_num_accepted_tokens"): + copy_linear_att_state_to_kv_buffer(**_args(gpu_conv, accept_len=mtp_step + 2, mtp_step=mtp_step)) diff --git a/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py b/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py new file mode 100644 index 0000000000..fb22f0ed1f --- /dev/null +++ b/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py @@ -0,0 +1,219 @@ +from types import SimpleNamespace + +import pytest +import torch + + +def _make_start_args(**overrides): + base = dict( + model_dir="/tmp/qwen3_5", + tp=1, + dp=1, + data_type="bfloat16", + linear_att_ssm_data_type="bfloat16", + mtp_mode=None, + mtp_step=0, + linear_att_page_block_num=2, + linear_att_hash_page_size=4, + cpu_cache_token_page_size=8, + ) + base.update(overrides) + return SimpleNamespace(**base) + + +def _make_model_cfg(): + return { + "model_type": "qwen3_5", + "num_hidden_layers": 64, + "num_key_value_heads": 16, + "head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 48, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_conv_kernel_dim": 4, + "full_attention_interval": 4, + } + + +def _patch_linear_config_args(monkeypatch, args): + import lightllm.common.linear_att_cache_manager.config_objs as config_objs + + monkeypatch.setattr(config_objs, "get_env_start_args", lambda: args) + + +def _make_config(draft_full_att_layer_num=0): + from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + + return LinearAttCacheConfig( + tp_world_size=1, + full_att_all_num_kv_heads=16, + full_att_dtype=torch.bfloat16, + full_att_num_kv_heads=16, + full_att_head_dim=128, + num_linear_k_heads=16, + num_linear_v_heads=48, + head_linear_k_dim=128, + head_linear_v_dim=128, + conv_kernel_size=4, + linear_layer_num=48, + conv_state_dtype=torch.bfloat16, + ssm_state_dtype=torch.bfloat16, + full_attention_interval=4, + all_layer_num=64, + draft_full_att_layer_num=draft_full_att_layer_num, + ) + + +def test_load_from_args_includes_mtp_draft_full_att_layers(monkeypatch): + from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + from transformers.configuration_utils import PretrainedConfig + + args = _make_start_args(mtp_mode="vanilla_with_att", mtp_step=3) + _patch_linear_config_args(monkeypatch, args) + monkeypatch.setattr(PretrainedConfig, "get_config_dict", lambda _model_path: (_make_model_cfg(), None)) + + cfg = LinearAttCacheConfig.load_from_args() + + assert cfg.get_main_full_att_layer_num() == 16 + assert cfg.draft_full_att_layer_num == 3 + assert cfg.get_persisted_full_att_layer_num() == 19 + + +def test_cpu_cache_full_att_bytes_include_mtp_draft_layers(monkeypatch): + args = _make_start_args() + _patch_linear_config_args(monkeypatch, args) + main_only = _make_config(draft_full_att_layer_num=0) + with_draft = _make_config(draft_full_att_layer_num=2) + + bytes_per_full_att_layer = ( + args.cpu_cache_token_page_size + * 2 + * main_only.full_att_all_num_kv_heads + * main_only.full_att_head_dim + * main_only.full_att_dtype.itemsize + ) + + assert main_only.get_main_full_att_layer_num() == 16 + assert with_draft.get_persisted_full_att_layer_num() == 18 + assert with_draft.get_cpu_cache_full_att_bytes() == ( + main_only.get_cpu_cache_full_att_bytes() + 2 * bytes_per_full_att_layer + ) + + +def test_linear_operator_persisted_full_att_slice_includes_draft_slots(): + from lightllm.common.kv_cache_mem_manager.operator.linear_att import LinearAttMemOperator + + class MtpMemManager: + main_full_att_layer_num = 16 + draft_full_att_layers = 2 + kv_buffer = torch.empty((18, 1)) + + class MainOnlyMemManager: + main_full_att_layer_num = 16 + kv_buffer = torch.empty((18, 1)) + + class PlainMemManager: + kv_buffer = torch.empty((7, 1)) + + assert LinearAttMemOperator._get_persisted_full_att_layer_num(MtpMemManager()) == 18 + assert LinearAttMemOperator._get_persisted_full_att_layer_num(MainOnlyMemManager()) == 16 + assert LinearAttMemOperator._get_persisted_full_att_layer_num(PlainMemManager()) == 7 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") +def test_linear_cpu_cache_roundtrips_mtp_draft_full_att_slot(monkeypatch): + from lightllm.common.basemodel.triton_kernel.linear_att_cpu_cache_copy import ( + copy_cpu_cache_to_kv_buffer, + copy_kv_buffer_to_cpu_cache, + ) + from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + + args = _make_start_args( + linear_att_page_block_num=1, + linear_att_hash_page_size=2, + cpu_cache_token_page_size=2, + ) + _patch_linear_config_args(monkeypatch, args) + cfg = LinearAttCacheConfig( + tp_world_size=1, + full_att_all_num_kv_heads=2, + full_att_dtype=torch.float32, + full_att_num_kv_heads=2, + full_att_head_dim=8, + num_linear_k_heads=1, + num_linear_v_heads=1, + head_linear_k_dim=8, + head_linear_v_dim=8, + conv_kernel_size=2, + linear_layer_num=1, + conv_state_dtype=torch.float32, + ssm_state_dtype=torch.float32, + full_attention_interval=2, + all_layer_num=2, + draft_full_att_layer_num=1, + ) + + gpu_kv = torch.arange(2 * 2 * 4 * 8, dtype=torch.float32, device="cuda").reshape(2, 2, 4, 8) + cpu_cache_tensor = torch.zeros( + (1, 1, 1, 1, cfg.get_cpu_cache_big_page_bytes()), + dtype=torch.uint8, + device="cuda", + ) + conv_state = torch.zeros( + (1, cfg.linear_layer_num, cfg.get_conv_dim(), cfg.conv_kernel_size - 1), + dtype=torch.float32, + device="cuda", + ) + ssm_state = torch.zeros( + ( + 1, + cfg.linear_layer_num, + cfg.num_linear_v_heads, + cfg.head_linear_k_dim, + cfg.head_linear_v_dim, + ), + dtype=torch.float32, + device="cuda", + ) + mem_indexes = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + page_indexes = torch.tensor([0], dtype=torch.int32, device="cuda") + page_readies = torch.tensor([False], dtype=torch.bool, device="cuda") + big_page_buffer_ids = torch.tensor([0], dtype=torch.int64, device="cuda") + + copy_kv_buffer_to_cpu_cache( + mem_indexes=mem_indexes, + page_indexes=page_indexes, + page_readies=page_readies, + big_page_buffer_ids=big_page_buffer_ids, + gpu_kv_full_att_state=gpu_kv, + cpu_kv_conv_state=conv_state, + cpu_kv_ssm_state=ssm_state, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=1, + big_page_token_num=args.cpu_cache_token_page_size, + linear_config=cfg, + grid_num=1, + ) + + restored_gpu_kv = torch.full_like(gpu_kv, fill_value=-1) + restored_conv = torch.empty_like(conv_state) + restored_ssm = torch.empty_like(ssm_state) + copy_cpu_cache_to_kv_buffer( + mem_indexes=mem_indexes, + big_page_buffer_ids=big_page_buffer_ids, + page_indexes=page_indexes, + gpu_full_att_kv_state=restored_gpu_kv, + cpu_kv_conv_state=restored_conv, + cpu_kv_ssm_state=restored_ssm, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=1, + big_page_token_num=args.cpu_cache_token_page_size, + linear_config=cfg, + grid_num=1, + ) + torch.cuda.synchronize() + + torch.testing.assert_close(restored_gpu_kv, gpu_kv) diff --git a/unit_tests/common/test_linear_att_snapshot_split.py b/unit_tests/common/test_linear_att_snapshot_split.py new file mode 100644 index 0000000000..2ce2833bcf --- /dev/null +++ b/unit_tests/common/test_linear_att_snapshot_split.py @@ -0,0 +1,41 @@ +import pytest +import torch + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") + + +@pytest.mark.parametrize("S", [1, 2, 3]) +@pytest.mark.parametrize("accept_len", [1, 2]) +def test_snapshot_reads_committed_conv_and_ssm(S, accept_len): + from lightllm.common.basemodel.triton_kernel.linear_att_copy import ( + copy_linear_att_state_to_kv_buffer, + ) + + layer_num, dim_conv = 2, 32 + width_narrow = 3 + gpu_conv = torch.zeros(layer_num, 1, dim_conv, width_narrow + S, device="cuda") + off = accept_len - 1 + marker_conv = torch.arange(dim_conv * width_narrow, device="cuda").float().reshape(dim_conv, width_narrow) + gpu_conv[:, 0, :, off : off + width_narrow] = marker_conv + + hv, k, v = 4, 8, 8 + gpu_ssm = torch.zeros(layer_num, 1 * (S + 1), hv, k, v, device="cuda") + marker_ssm = torch.arange(hv * k * v, device="cuda").float().reshape(hv, k, v) + gpu_ssm[:, off, ...] = marker_ssm # block slot 0*(S+1)+off + + cpu_conv = torch.zeros(1, layer_num, dim_conv, width_narrow, device="cuda") + cpu_ssm = torch.zeros(1, layer_num, hv, k, v, device="cuda") + + copy_linear_att_state_to_kv_buffer( + b_req_idx=torch.tensor([0], dtype=torch.int32, device="cuda"), + big_page_buffer_ids=torch.tensor([0], dtype=torch.int32, device="cuda"), + gpu_conv_state=gpu_conv, + gpu_ssm_state=gpu_ssm, + cpu_kv_conv_state=cpu_conv, + cpu_kv_ssm_state=cpu_ssm, + mtp_step=S, + b_num_accepted_tokens=torch.tensor([accept_len], dtype=torch.int32, device="cuda"), + ) + + torch.testing.assert_close(cpu_conv[0], marker_conv.expand(layer_num, dim_conv, width_narrow)) + torch.testing.assert_close(cpu_ssm[0], marker_ssm.expand(layer_num, hv, k, v)) diff --git a/unit_tests/common/test_mtp_verify_extra_state.py b/unit_tests/common/test_mtp_verify_extra_state.py new file mode 100644 index 0000000000..7252af0736 --- /dev/null +++ b/unit_tests/common/test_mtp_verify_extra_state.py @@ -0,0 +1,36 @@ +import types +import torch + +import lightllm.common.basemodel.mtp_verify_extra_state as mod + + +def _state(n_real, mtp_step, is_prefill=False, with_accept=True): + step = mtp_step + 1 + s = types.SimpleNamespace() + s.b_seq_len = torch.arange(1, n_real * step + 1, dtype=torch.int32) + s.b_req_idx = torch.arange(n_real, dtype=torch.int32).repeat_interleave(step) + s.b_mtp_index = torch.arange(step, dtype=torch.int32).repeat(n_real) + s.is_prefill = is_prefill + s.b_num_accepted_tokens = torch.ones(n_real, dtype=torch.int32) if with_accept else None + return s + + +def test_verify_branch_sets_index_rows(monkeypatch): + monkeypatch.setattr(mod, "get_env_start_args", lambda: types.SimpleNamespace(mtp_step=2)) + n_real, mtp_step = 3, 2 + step = mtp_step + 1 + s = _state(n_real, mtp_step) + mod.init_mtp_verify_extra_state(s) + assert s.is_mtp_verify is True + assert s.b_ssm_index_rows.shape == (n_real, step) + assert s.b_gdn_verify_cu_seqlens.tolist() == [0, 3, 6, 9] + assert s.b_conv_buffer_idx.tolist() == [0, 1, 2] # one widened conv slot per req + + +def test_non_verify_branch_no_index_rows(monkeypatch): + monkeypatch.setattr(mod, "get_env_start_args", lambda: types.SimpleNamespace(mtp_step=2)) + s = _state(3, 2, with_accept=False) + mod.init_mtp_verify_extra_state(s) + assert s.is_mtp_verify is False + assert s.b_ssm_index_rows is None + assert s.b_gdn_verify_cu_seqlens is None diff --git a/unit_tests/models/qwen3next/test_causal_conv1d_spec.py b/unit_tests/models/qwen3next/test_causal_conv1d_spec.py new file mode 100644 index 0000000000..e99497ec33 --- /dev/null +++ b/unit_tests/models/qwen3next/test_causal_conv1d_spec.py @@ -0,0 +1,147 @@ +import pytest +import torch + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") + + +def _eager_conv_update(x_seq, conv_state, weight, bias, activation): + # x_seq: (dim, seqlen) tokens to roll in, conv_state: (dim, width-1) history + dim, width = weight.shape + state = conv_state.clone() # (dim, width-1) + outs = [] + for t in range(x_seq.shape[1]): + window = torch.cat([state, x_seq[:, t : t + 1]], dim=1) # (dim, width) + y = (window * weight).sum(dim=1) # depthwise conv + if bias is not None: + y = y + bias + if activation in ("silu", "swish"): + y = torch.nn.functional.silu(y) + outs.append(y) + state = window[:, 1:] # slide + return torch.stack(outs, dim=1), state + + +@pytest.mark.parametrize("S", [0, 1, 2, 3]) +def test_spec_conv_matches_eager_after_partial_accept(S): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update + + torch.manual_seed(0) + dim, width = 64, 4 + seqlen = S + 1 + state_len = (width - 1) + S + device = "cuda" + dtype = torch.float32 + + weight = torch.randn(dim, width, device=device, dtype=dtype) + bias = torch.randn(dim, device=device, dtype=dtype) + + conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) + committed_hist = torch.randn(dim, width - 1, device=device, dtype=dtype) + conv_state[0, :, : width - 1] = committed_hist + + x = torch.randn(seqlen, dim, device=device, dtype=dtype) # candidate tokens + + out = causal_conv1d_update( + x.clone(), + conv_state, + weight, + bias=bias, + activation="silu", + conv_state_indices=torch.zeros(1, dtype=torch.int32, device=device), + num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), # fresh: read offset 0 + query_start_loc=torch.tensor([0, seqlen], dtype=torch.int32, device=device), + ) + + ref_out, _ = _eager_conv_update(x.t(), committed_hist, weight, bias, "silu") + torch.testing.assert_close(out.t(), ref_out, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("S", [1, 2, 3]) +def test_spec_conv_reads_from_partial_accept_offset(S): + # Exercise the nonzero read offset: num_accepted_tokens=2 -> read offset 1. + # The widened slot front-loads a STALE token then the real committed history; + # the kernel must read history starting at (num_accepted_tokens-1)==1, i.e. + # conv_state[:, 1:width], NOT the stale token at index 0. + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update + + torch.manual_seed(0) + dim, width = 64, 4 + seqlen = S + 1 + state_len = (width - 1) + S + device = "cuda" + dtype = torch.float32 + + weight = torch.randn(dim, width, device=device, dtype=dtype) + bias = torch.randn(dim, device=device, dtype=dtype) + + conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) + # tokens [0 .. width-1] hold [stale, h1, h2, ...]: a stale front token then history + seed = torch.randn(dim, width, device=device, dtype=dtype) + conv_state[0, :, :width] = seed + stale_front = conv_state[0, :, :width].clone() # snapshot of the seeded window + + x = torch.randn(seqlen, dim, device=device, dtype=dtype) # candidate tokens + + out = causal_conv1d_update( + x.clone(), + conv_state, + weight, + bias=bias, + activation="silu", + conv_state_indices=torch.zeros(1, dtype=torch.int32, device=device), + num_accepted_tokens=2 * torch.ones(1, dtype=torch.int32, device=device), # read offset 1 + query_start_loc=torch.tensor([0, seqlen], dtype=torch.int32, device=device), + ) + + # Eager reference starts from the offset-1 window: committed history excluding + # the stale front token == conv_state[:, 1:width]. + committed_hist = stale_front[:, 1:width] + ref_out, _ = _eager_conv_update(x.t(), committed_hist, weight, bias, "silu") + torch.testing.assert_close(out.t(), ref_out, rtol=1e-3, atol=1e-3) + + +def test_spec_conv_varlen_update_is_cuda_graph_capturable(): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update + + torch.manual_seed(0) + dim, width, S = 64, 4, 1 + seqlen = S + 1 + state_len = (width - 1) + S + device = "cuda" + dtype = torch.float32 + + weight = torch.randn(dim, width, device=device, dtype=dtype) + bias = torch.randn(dim, device=device, dtype=dtype) + conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) + x = torch.randn(seqlen, dim, device=device, dtype=dtype) + conv_state_indices = torch.zeros(1, dtype=torch.int32, device=device) + num_accepted_tokens = torch.ones(1, dtype=torch.int32, device=device) + query_start_loc = torch.tensor([0, seqlen], dtype=torch.int32, device=device) + + # Compile/warm the Triton kernel before capture; the regression is the wrapper's + # host sync on query_start_loc during capture, not first-use compilation. + causal_conv1d_update( + x.clone(), + conv_state, + weight, + bias=bias, + activation="silu", + conv_state_indices=conv_state_indices, + num_accepted_tokens=num_accepted_tokens, + query_start_loc=query_start_loc, + ) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + static_x = x.clone() + with torch.cuda.graph(graph): + causal_conv1d_update( + static_x, + conv_state, + weight, + bias=bias, + activation="silu", + conv_state_indices=conv_state_indices, + num_accepted_tokens=num_accepted_tokens, + query_start_loc=query_start_loc, + ) diff --git a/unit_tests/models/qwen3next/test_conv_prefill_decode_roundtrip.py b/unit_tests/models/qwen3next/test_conv_prefill_decode_roundtrip.py new file mode 100644 index 0000000000..2fca8bfc57 --- /dev/null +++ b/unit_tests/models/qwen3next/test_conv_prefill_decode_roundtrip.py @@ -0,0 +1,74 @@ +import pytest +import torch + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") + + +def _eager_conv_update(x_seq, conv_state, weight, bias, activation): + # x_seq: (dim, seqlen) tokens to roll in, conv_state: (dim, width-1) history + state = conv_state.clone() + outs = [] + for t in range(x_seq.shape[1]): + window = torch.cat([state, x_seq[:, t : t + 1]], dim=1) # (dim, width) + y = (window * weight).sum(dim=1) + if bias is not None: + y = y + bias + if activation in ("silu", "swish"): + y = torch.nn.functional.silu(y) + outs.append(y) + state = window[:, 1:] + return torch.stack(outs, dim=1), state + + +@pytest.mark.parametrize("S", [1, 2, 3]) +def test_prefill_writes_first_columns_then_decode_reads_them(S): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update + + torch.manual_seed(0) + dim, width = 64, 4 + prefill_len = 7 + state_len = (width - 1) + S # widened slot + device, dtype = "cuda", torch.float32 + + weight = torch.randn(dim, width, device=device, dtype=dtype) + bias = torch.randn(dim, device=device, dtype=dtype) + + # ---- PREFILL: populate one widened conv slot from a fresh (no initial state) sequence ---- + conv_states = torch.zeros(1, dim, state_len, device=device, dtype=dtype) + x_prefill = torch.randn(dim, prefill_len, device=device, dtype=dtype) # (dim, total_tokens) + causal_conv1d_fn( + x_prefill.clone(), + weight, + bias=bias, + query_start_loc=torch.tensor([0, prefill_len], dtype=torch.int32, device=device), + cache_indices=torch.zeros(1, dtype=torch.int32, device=device), + has_initial_state=torch.zeros(1, dtype=torch.bool, device=device), + conv_states=conv_states, + activation="silu", + ) + + # Contract (a): committed state lands in the FIRST width-1 columns; widened tail untouched. + committed_hist = conv_states[0, :, : width - 1].clone() + expected_hist = x_prefill[:, -(width - 1) :] # trailing window for a fresh causal conv + torch.testing.assert_close(committed_hist, expected_hist, rtol=1e-3, atol=1e-3) + if state_len > width - 1: + assert torch.count_nonzero(conv_states[0, :, width - 1 :]) == 0, "widened tail must be untouched by prefill" + + # ---- FIRST DECODE: verify reads at offset accept_len-1 == 0 -> columns [0:width-1] ---- + seqlen = S + 1 + x_decode = torch.randn(seqlen, dim, device=device, dtype=dtype) + out = causal_conv1d_update( + x_decode.clone(), + conv_states, + weight, + bias=bias, + activation="silu", + conv_state_indices=torch.zeros(1, dtype=torch.int32, device=device), + num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), # offset 0 + query_start_loc=torch.tensor([0, seqlen], dtype=torch.int32, device=device), + ) + + # Contract (b): decode output must match an eager conv seeded from the prefill-written history. + ref_out, _ = _eager_conv_update(x_decode.t(), committed_hist, weight, bias, "silu") + torch.testing.assert_close(out.t(), ref_out, rtol=1e-3, atol=1e-3) diff --git a/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py b/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py new file mode 100644 index 0000000000..7481607d54 --- /dev/null +++ b/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py @@ -0,0 +1,194 @@ +import pytest +import torch + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") + + +@pytest.mark.parametrize("S", [1, 2, 3]) +def test_gdn_verify_state_equals_sequential_decode(S): + from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( + fused_recurrent_gated_delta_rule, + ) + + torch.manual_seed(0) + HV, K, V = 4, 16, 16 + T = S + 1 + device = "cuda" + + def rand_qkv(t): + q = torch.randn(1, t, HV, K, device=device) + k = torch.nn.functional.normalize(torch.randn(1, t, HV, K, device=device), dim=-1) + v = torch.randn(1, t, HV, V, device=device) + g = torch.nn.functional.logsigmoid(torch.rand(1, t, HV, device=device)) + beta = torch.rand(1, t, HV, device=device).sigmoid() + return q, k, v, g, beta + + q, k, v, g, beta = rand_qkv(T) + + ref_state = torch.zeros(1, HV, K, V, device=device) + for t in range(T): + _, ref_state = fused_recurrent_gated_delta_rule( + q=q[:, t : t + 1], + k=k[:, t : t + 1], + v=v[:, t : t + 1], + g=g[:, t : t + 1], + beta=beta[:, t : t + 1], + initial_state=ref_state, + inplace_final_state=False, + ) + + block = torch.zeros(T, HV, K, V, device=device) + ssm_idx = torch.arange(T, device=device).view(1, T) + fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=block, + inplace_final_state=True, + cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), + ssm_state_indices=ssm_idx, + ssm_state_write_indices=ssm_idx, + num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), + ) + torch.testing.assert_close(block[T - 1], ref_state[0], rtol=2e-2, atol=2e-2) + + +@pytest.mark.parametrize("S", [1, 2, 3]) +def test_gdn_verify_output_equals_sequential_decode_fused(S): + """H1: the LIVE verify combination - varlen + FUSED gating (A_log/dt_bias/a_raw/b_raw) + + spec-decode - must produce per-position OUTPUT o[t] identical to running the proven + T=1 decode recurrence sequentially. The original test only checked the final SSM state + with EXPLICIT g/beta; it never verified o[t] nor the fused-gating path that + _gdn_verify_kernel actually uses.""" + from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( + fused_recurrent_gated_delta_rule, + ) + + torch.manual_seed(0) + HV, K, V = 4, 16, 16 + H = HV + T = S + 1 + device = "cuda" + + q = torch.randn(1, T, H, K, device=device) + k = torch.nn.functional.normalize(torch.randn(1, T, H, K, device=device), dim=-1) + v = torch.randn(1, T, HV, V, device=device) + # Raw gating inputs (pre-activation), exactly as the model feeds the fused path. + a_raw = torch.randn(T, HV, device=device) + b_raw = torch.randn(T, HV, device=device) + A_log = torch.randn(HV, device=device) + dt_bias = torch.randn(HV, device=device) + + # Reference: sequential T=1 decode through the proven non-varlen fused path. + ref_state = torch.zeros(1, HV, K, V, device=device) + ref_o = torch.zeros(T, HV, V, device=device) + for t in range(T): + o_t, ref_state = fused_recurrent_gated_delta_rule( + q=q[:, t : t + 1], + k=k[:, t : t + 1], + v=v[:, t : t + 1], + initial_state=ref_state, + inplace_final_state=False, + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw[t : t + 1], + b_raw=b_raw[t : t + 1], + ) + ref_o[t] = o_t[0, 0] + + # Verify path: single varlen call with fused gating + spec-decode indices, + # mirroring _gdn_verify_kernel for a single request, num_accepted=1. + block = torch.zeros(T, HV, K, V, device=device) + ssm_idx = torch.arange(T, device=device).view(1, T) + o, _ = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + initial_state=block, + inplace_final_state=True, + cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), + ssm_state_indices=ssm_idx, + ssm_state_write_indices=ssm_idx, + num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw, + b_raw=b_raw, + ) + o = o.view(T, HV, V) + torch.testing.assert_close(o, ref_o, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(block[T - 1], ref_state[0], rtol=2e-2, atol=2e-2) + + +@pytest.mark.parametrize("num_accepted", [1, 2]) +def test_gdn_verify_reads_committed_slot_by_num_accepted(num_accepted): + """The verify kernel must read the per-request initial state from the SSM block + slot at offset (num_accepted-1) -- i.e. the state committed after the previous + step's last accepted token. This is the read path exercised by the FIRST decode + after an accept-`num_accepted` step. A decoy is written into the OTHER block slot + to prove the kernel reads the correct one and ignores the rest of the (S+1) block.""" + from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( + fused_recurrent_gated_delta_rule, + ) + + torch.manual_seed(0) + HV, K, V = 4, 16, 16 + S = 1 + T = S + 1 + device = "cuda" + + q = torch.randn(1, T, HV, K, device=device) + k = torch.nn.functional.normalize(torch.randn(1, T, HV, K, device=device), dim=-1) + v = torch.randn(1, T, HV, V, device=device) + a_raw = torch.randn(T, HV, device=device) + b_raw = torch.randn(T, HV, device=device) + A_log = torch.randn(HV, device=device) + dt_bias = torch.randn(HV, device=device) + + # (S+1) block: the committed slot is (num_accepted-1); the others hold decoys + # that MUST NOT be read. + block = torch.randn(T, HV, K, V, device=device) * 5.0 + committed = torch.randn(1, HV, K, V, device=device) + block[num_accepted - 1] = committed[0] + + ref_state = committed.clone() + ref_o = torch.zeros(T, HV, V, device=device) + for t in range(T): + o_t, ref_state = fused_recurrent_gated_delta_rule( + q=q[:, t : t + 1], + k=k[:, t : t + 1], + v=v[:, t : t + 1], + initial_state=ref_state, + inplace_final_state=False, + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw[t : t + 1], + b_raw=b_raw[t : t + 1], + ) + ref_o[t] = o_t[0, 0] + + blk = block.clone() + ssm_idx = torch.arange(T, device=device).view(1, T) + o, _ = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + initial_state=blk, + inplace_final_state=True, + cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), + ssm_state_indices=ssm_idx, + ssm_state_write_indices=ssm_idx, + num_accepted_tokens=torch.tensor([num_accepted], dtype=torch.int32, device=device), + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw, + b_raw=b_raw, + ) + o = o.view(T, HV, V) + torch.testing.assert_close(o, ref_o, rtol=2e-2, atol=2e-2) From 28879595980473caf5a095a15224e492bb965b96 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Jun 2026 19:06:09 +0800 Subject: [PATCH 07/19] Fix Qwen3Next MTP linear-att page moves --- .../qwen3next_mem_manager.py | 27 +++-- .../basemodel/test_mtp_decode_cuda_graph.py | 2 +- .../test_qwen3next_linear_att_page_helper.py | 112 ++++++++++++++++++ 3 files changed, 130 insertions(+), 11 deletions(-) create mode 100644 unit_tests/common/test_qwen3next_linear_att_page_helper.py diff --git a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py index c7ce9d96ba..566ce5ea3f 100644 --- a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py @@ -208,9 +208,9 @@ def write_req_to_page( dp_mems: List["Qwen3NextMemManager"], ): conv_page, ssm_page = self.view_page_to_linear_att_state(page_index) - req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1) + conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx) for tp_index, mem in enumerate(dp_mems): - self._write_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page) + self._write_one_rank(mem, tp_index, conv_req_idx, ssm_req_idx, conv_page, ssm_page) return def read_page_to_req( @@ -220,21 +220,27 @@ def read_page_to_req( dp_mems: List["Qwen3NextMemManager"], ): conv_page, ssm_page = self.view_page_to_linear_att_state(page_index) - req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1) + conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx) for tp_index, mem in enumerate(dp_mems): - self._read_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page) + self._read_one_rank(mem, tp_index, conv_req_idx, ssm_req_idx, conv_page, ssm_page) return + def _get_req_state_indexes(self, req_idx: int): + mtp_size = get_env_start_args().mtp_step + 1 + # Conv is one widened slot per request; SSM keeps the historical S+1 block layout. + return req_idx, req_idx * mtp_size + def _write_one_rank( self, mem: "Qwen3NextMemManager", tp_index: int, - req_buffer_idx: int, + conv_req_idx: int, + ssm_req_idx: int, conv_page: torch.Tensor, ssm_page: torch.Tensor, ): - conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...] - ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...] + conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]] + ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...] self._copy_conv_state_to_page(conv_state, conv_page, mem, tp_index) self._copy_ssm_state_to_page(ssm_state, ssm_page, mem, tp_index) return @@ -408,12 +414,13 @@ def _read_one_rank( self, mem: "Qwen3NextMemManager", tp_index: int, - req_buffer_idx: int, + conv_req_idx: int, + ssm_req_idx: int, conv_page: torch.Tensor, ssm_page: torch.Tensor, ): - conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...] - ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...] + conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]] + ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...] self._copy_page_to_conv_state(conv_page, conv_state, mem, tp_index) self._copy_page_to_ssm_state(ssm_page, ssm_state, mem, tp_index) return diff --git a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py index dd8fec1640..54d61f18c7 100644 --- a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py +++ b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py @@ -103,7 +103,7 @@ class Qwen3_5MOETpPartModel: pass class Qwen3_5MoeMTPModel: - pass + is_mtp_draft_model = True graph = CudaGraph.__new__(CudaGraph) graph.mtp_step = 2 diff --git a/unit_tests/common/test_qwen3next_linear_att_page_helper.py b/unit_tests/common/test_qwen3next_linear_att_page_helper.py new file mode 100644 index 0000000000..e4c2e71c76 --- /dev/null +++ b/unit_tests/common/test_qwen3next_linear_att_page_helper.py @@ -0,0 +1,112 @@ +from types import SimpleNamespace + +import torch + + +class _Buf: + def __init__(self, tensor): + self.buffer = tensor + + +def _make_config(): + return SimpleNamespace( + tp_world_size=1, + linear_layer_num=1, + conv_kernel_size=4, + global_linear_k_heads=1, + global_linear_v_heads=1, + num_linear_k_heads=1, + num_linear_v_heads=1, + head_linear_k_dim=2, + head_linear_v_dim=3, + ) + + +def _make_mem(mtp_step=2, req_slots=4): + config = _make_config() + conv_dim = ( + 2 * config.num_linear_k_heads * config.head_linear_k_dim + + config.num_linear_v_heads * config.head_linear_v_dim + ) + narrow_w = config.conv_kernel_size - 1 + conv = torch.full( + (config.linear_layer_num, req_slots, conv_dim, narrow_w + mtp_step), + -9.0, + dtype=torch.float32, + ) + ssm = torch.full( + ( + config.linear_layer_num, + req_slots * (mtp_step + 1), + config.num_linear_v_heads, + config.head_linear_k_dim, + config.head_linear_v_dim, + ), + -11.0, + dtype=torch.float32, + ) + return SimpleNamespace( + linear_config=config, + req_to_conv_state=_Buf(conv), + req_to_ssm_state=_Buf(ssm), + kv_move_buffer=torch.zeros((1, 4096), dtype=torch.uint8), + ) + + +def test_page_helper_writes_req_conv_slot_and_narrow_width(monkeypatch): + import lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager as qwen3next_mem_manager + from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextLinearAttPageHelper + + mtp_step = 2 + req_idx = 2 + monkeypatch.setattr(qwen3next_mem_manager, "get_env_start_args", lambda: SimpleNamespace(mtp_step=mtp_step)) + + mem = _make_mem(mtp_step=mtp_step) + helper = Qwen3NextLinearAttPageHelper(mem) + mem.kv_move_buffer = torch.zeros((1, helper.state_nbytes), dtype=torch.uint8) + + narrow_w = helper.conv_shape[-1] + marker_conv = torch.arange( + helper.conv_shape[0] * helper.conv_shape[1] * narrow_w, + dtype=torch.float32, + ).view(helper.conv_shape) + marker_ssm = torch.arange( + helper.ssm_shape[0] * helper.ssm_shape[1] * helper.ssm_shape[2] * helper.ssm_shape[3], + dtype=torch.float32, + ).view(helper.ssm_shape) + + mem.req_to_conv_state.buffer[:, req_idx, :, :narrow_w] = marker_conv + mem.req_to_conv_state.buffer[:, req_idx, :, narrow_w:] = 999.0 + mem.req_to_ssm_state.buffer[:, req_idx * (mtp_step + 1), ...] = marker_ssm + + helper.write_req_to_page(page_index=0, req_idx=req_idx, dp_mems=[mem]) + + conv_page, ssm_page = helper.view_page_to_linear_att_state(page_index=0) + torch.testing.assert_close(conv_page, marker_conv) + torch.testing.assert_close(ssm_page, marker_ssm) + + +def test_page_helper_restores_narrow_conv_to_req_slot(monkeypatch): + import lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager as qwen3next_mem_manager + from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextLinearAttPageHelper + + mtp_step = 2 + req_idx = 2 + monkeypatch.setattr(qwen3next_mem_manager, "get_env_start_args", lambda: SimpleNamespace(mtp_step=mtp_step)) + + mem = _make_mem(mtp_step=mtp_step) + helper = Qwen3NextLinearAttPageHelper(mem) + mem.kv_move_buffer = torch.zeros((1, helper.state_nbytes), dtype=torch.uint8) + conv_page, ssm_page = helper.view_page_to_linear_att_state(page_index=0) + + marker_conv = torch.arange(conv_page.numel(), dtype=torch.float32).view_as(conv_page) + marker_ssm = torch.arange(ssm_page.numel(), dtype=torch.float32).view_as(ssm_page) + conv_page.copy_(marker_conv) + ssm_page.copy_(marker_ssm) + + helper.read_page_to_req(page_index=0, req_idx=req_idx, dp_mems=[mem]) + + narrow_w = helper.conv_shape[-1] + torch.testing.assert_close(mem.req_to_conv_state.buffer[:, req_idx, :, :narrow_w], marker_conv) + assert torch.all(mem.req_to_conv_state.buffer[:, req_idx, :, narrow_w:] == -9.0) + torch.testing.assert_close(mem.req_to_ssm_state.buffer[:, req_idx * (mtp_step + 1), ...], marker_ssm) From faacfca893eeecdaf1b48b294ade80cab02cb487 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 15 Jun 2026 21:20:42 +0800 Subject: [PATCH 08/19] revert formatting churn on pre-existing code Restore blank lines that were stripped from pre-existing definitions (black-induced reformatting of upstream code that this PR didn't functionally change). Keeps the diff focused on the MTP feature; fixing historical formatting is out of scope for this PR. --- lightllm/common/req_manager.py | 1 + .../models/qwen3next/layer_infer/transformer_layer_infer.py | 2 ++ lightllm/models/qwen3next/model.py | 1 + 3 files changed, 4 insertions(+) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 6673243c9f..c01f2d7c0e 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -287,6 +287,7 @@ def get_mamba_cache(self, layer_idx_in_all: int): return conv_states, ssm_states def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req: "InferReq"): + from .linear_att_cache_manager import LinearAttCacheManager big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 76c273c0e7..504e340452 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -45,6 +45,7 @@ def __init__(self, layer_num, network_config): return def _init_linear_layer_metadata(self, layer_num, network_config): + # Linear attention specific dimensions self.num_v_heads = network_config["linear_num_value_heads"] self.num_k_heads = network_config["linear_num_key_heads"] @@ -120,6 +121,7 @@ def _compute_shared_expert( def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input.view(-1, self.embed_dim_) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index f61d9e4c6a..5d60bb28ff 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -26,6 +26,7 @@ @ModelRegistry("qwen3_next") class Qwen3NextTpPartModel(Qwen3MOEModel): + # weight class pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight transformer_weight_class = Qwen3NextTransformerLayerWeight From c2dc76b44ec6d1a651d464a5b22895baabb47b9d Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 15 Jun 2026 21:22:08 +0800 Subject: [PATCH 09/19] revert(mtp): drop eagle reduced-batch draft optimization Scope this branch to Qwen3.5 MTP support only by rolling back the EAGLE-mode draft optimization. The draft model again runs the full (mtp_step+1)-expanded verify layout instead of being narrowed to the single accepted row per request. - dp/chunked _draft_decode_eagle: restore full-layout draft (copy.copy + b_num_accepted_tokens=None so it routes to the (bs, False) graph); drop the per-rank padding helpers and accepted-row narrowing. - base_backend: remove _build_eagle_accepted_draft_input / _scatter_accepted_next_token_ids. - cuda_graph: the draft runs at multiples of (mtp_step+1) again, so collapse the dual batch-size sets to one and delete the now-redundant _get_graph_batch_sizes routing. Keep the (bs, is_mtp_verify_decode) graph key + verify-layout warmup (core GDN verify support, not the optimization). - static benchmark: eagle path now measures the full-layout draft cost. - tests: drop the two narrowed-draft tests; rewrite the dual-set tests to the single-set model (still cover the verify/normal key distinction). --- lightllm/common/basemodel/basemodel.py | 10 +- lightllm/common/basemodel/cuda_graph.py | 47 ++--- .../model_infer/mode_backend/base_backend.py | 59 +------ .../mode_backend/chunked_prefill/impl.py | 44 +++-- .../mode_backend/dp_backend/impl.py | 143 +++++----------- .../static_inference/model_infer_mtp.py | 34 ++-- .../basemodel/test_mtp_decode_cuda_graph.py | 162 +----------------- 7 files changed, 100 insertions(+), 399 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 54e3be1512..4d307ec2dd 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -683,10 +683,7 @@ def _decode( if self.graph is not None and self.graph.can_run( batch_size=infer_batch_size, max_len_in_batch=model_input.max_kv_seq_len ): - infer_batch_size = self.graph.find_closest_graph_batch_size( - batch_size=infer_batch_size, - is_mtp_verify_decode=is_mtp_verify_decode, - ) + infer_batch_size = self.graph.find_closest_graph_batch_size(batch_size=infer_batch_size) model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) @@ -936,10 +933,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode is_mtp_verify_decode = is_mtp_verify_decode_fn(self.args.mtp_step, model_input0.b_num_accepted_tokens) if self.graph is not None and self.graph.can_run(infer_batch_size, max_len_in_batch): - infer_batch_size = self.graph.find_closest_graph_batch_size( - infer_batch_size, - is_mtp_verify_decode=is_mtp_verify_decode, - ) + infer_batch_size = self.graph.find_closest_graph_batch_size(infer_batch_size) # TODO 如果支持动态步数的 mtp,在不同的mtp步上,model_input0 和 model_input1 的内部batch size可能不 # 一致,需要按照较高 batch size 进行graph的寻找,同时,进行有效的恢复。 padded_model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index d20de7afb8..ddf78f4c6f 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -29,21 +29,15 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap - self.normal_cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes(batch_size_multiple=1) - if self.mtp_step > 0: - self.mtp_verify_cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes( - batch_size_multiple=self.mtp_step + 1 - ) - logger.info(f"normal cuda graph batch_sizes: {self.normal_cuda_graph_batch_sizes}") - logger.info(f"mtp verify cuda graph batch_sizes: {self.mtp_verify_cuda_graph_batch_sizes}") - else: - self.mtp_verify_cuda_graph_batch_sizes = self.normal_cuda_graph_batch_sizes - logger.info(f"cuda graph batch_sizes: {self.normal_cuda_graph_batch_sizes}") + # With MTP enabled, both the main-model verify forward and the draft (MTP) forward run over + # the (mtp_step+1)-expanded decode layout, so all decode batch sizes are multiples of + # (mtp_step+1); a single graph batch-size set serves both. Verify vs normal graphs are told + # apart by the is_mtp_verify_decode component of the graph key, not by a separate set. + batch_size_multiple = self.mtp_step + 1 if self.mtp_step > 0 else 1 + self.cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes(batch_size_multiple=batch_size_multiple) + logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") def _build_cuda_graph_batch_sizes(self, batch_size_multiple: int): - # gen cuda graph batch_sizes - # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] - # and [graph_split_batch_size + graph_grow_step_size, ...] graph_split_batch_size = self.args.graph_split_batch_size * batch_size_multiple graph_grow_step_size = self.args.graph_grow_step_size * batch_size_multiple @@ -75,22 +69,16 @@ def _decode_graph_key(self, infer_state: InferStateInfo): return (infer_state.input_ids.shape[0], is_mtp_verify_decode) def need_capture(self, batch_size, is_mtp_verify_decode=False): - find_batch_size = self.find_closest_graph_batch_size(batch_size, is_mtp_verify_decode=is_mtp_verify_decode) + find_batch_size = self.find_closest_graph_batch_size(batch_size) if find_batch_size is not None: return (find_batch_size, is_mtp_verify_decode) not in self.graph else: assert False, "dead code" - def _get_graph_batch_sizes(self, is_mtp_verify_decode=False): - if is_mtp_verify_decode: - return self.mtp_verify_cuda_graph_batch_sizes - return self.normal_cuda_graph_batch_sizes - - def find_closest_graph_batch_size(self, batch_size, is_mtp_verify_decode=False): - graph_batch_sizes = self._get_graph_batch_sizes(is_mtp_verify_decode=is_mtp_verify_decode) - index = bisect.bisect_left(graph_batch_sizes, batch_size) - if index < len(graph_batch_sizes): - find_batch_size = graph_batch_sizes[index] + def find_closest_graph_batch_size(self, batch_size): + index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size) + if index < len(self.cuda_graph_batch_sizes): + find_batch_size = self.cuda_graph_batch_sizes[index] return find_batch_size else: return None @@ -150,13 +138,12 @@ def _is_mtp_draft_model(self, model): return getattr(model, "is_mtp_draft_model", False) def _iter_warmup_graph_layouts(self, model): - if self.mtp_step > 0: - if self._is_mtp_draft_model(model): - yield False, self.normal_cuda_graph_batch_sizes - else: - yield True, self.mtp_verify_cuda_graph_batch_sizes + # main-model decode is a verify forward; the draft (MTP) model takes the normal layout. + # Both warm up over the same batch-size set; only the verify flag (graph key + layout) differs. + if self.mtp_step > 0 and not self._is_mtp_draft_model(model): + yield True, self.cuda_graph_batch_sizes else: - yield False, self.normal_cuda_graph_batch_sizes + yield False, self.cuda_graph_batch_sizes def _capture_decode(self, decode_func, infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index cde5c03000..e7fa58712a 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -1,5 +1,4 @@ import os -import copy import numpy as np import torch import time @@ -17,7 +16,7 @@ from lightllm.common.linear_att_cache_manager import LinearAttCacheManager from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache -from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput +from lightllm.common.basemodel.batch_objs import ModelOutput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name @@ -746,62 +745,6 @@ def _verify_mtp_v2( ) return mtp_accept_len, accepted_index - def _build_eagle_accepted_draft_input( - self, - main_model_input: ModelInput, - main_model_output: ModelOutput, - next_token_ids: torch.Tensor, - mtp_accept_len: torch.Tensor, - b_req_mtp_start_loc: torch.Tensor, - ): - accepted_row_idx = b_req_mtp_start_loc + mtp_accept_len - 1 - accepted_row_idx_long = accepted_row_idx.long() - - draft_model_input = copy.copy(main_model_input) - draft_model_input.batch_size = accepted_row_idx.shape[0] - draft_model_input.total_token_num = draft_model_input.batch_size * main_model_input.max_kv_seq_len - draft_model_input.input_ids = next_token_ids.index_select(0, accepted_row_idx_long) - draft_model_input.mtp_draft_input_hiddens = main_model_output.mtp_main_output_hiddens.index_select( - 0, accepted_row_idx_long - ) - draft_model_input.b_req_idx = main_model_input.b_req_idx.index_select(0, accepted_row_idx_long) - draft_model_input.b_mtp_index = main_model_input.b_mtp_index.index_select(0, accepted_row_idx_long) - draft_model_input.b_seq_len = main_model_input.b_seq_len.index_select(0, accepted_row_idx_long) - draft_model_input.b_num_accepted_tokens = None - if main_model_input.mem_indexes is not None: - draft_model_input.mem_indexes = main_model_input.mem_indexes.index_select(0, accepted_row_idx_long) - draft_model_input.mem_indexes_cpu = None - if main_model_input.b_shared_seq_len is not None: - draft_model_input.b_shared_seq_len = main_model_input.b_shared_seq_len.index_select( - 0, accepted_row_idx_long - ) - if main_model_input.b_mark_shared_group is not None: - draft_model_input.b_mark_shared_group = main_model_input.b_mark_shared_group.index_select( - 0, accepted_row_idx_long - ) - - if accepted_row_idx.device.type == "cpu": - selected_rows = accepted_row_idx.tolist() - draft_model_input.multimodal_params = [main_model_input.multimodal_params[i] for i in selected_rows] - else: - draft_model_input.multimodal_params = [ - {"images": [], "audios": []} for _ in range(draft_model_input.batch_size) - ] - - accepted_next_token_ids = draft_model_input.input_ids - accepted_req_idx = draft_model_input.b_req_idx - return draft_model_input, accepted_next_token_ids, accepted_req_idx - - def _scatter_accepted_next_token_ids(self, accepted_req_idx: torch.Tensor, all_next_token_ids: torch.Tensor): - req_to_next_token_ids = self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids - width = all_next_token_ids.shape[1] - req_to_next_token_ids[:, :width].index_copy_( - 0, - accepted_req_idx.long(), - all_next_token_ids.to(dtype=req_to_next_token_ids.dtype), - ) - return - def _update_mtp_accept_ratio( self, decode_reqs: List[InferReq], diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index ed01c14a53..2a9a2e7492 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -357,7 +357,6 @@ def _draft_decode_vanilla( # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型 decode_mtp 设置的 # verify 布局,命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError # (cudagraph 关闭时则会在扁平的 draft batch 上误用 S+1 分组的 verify attention)。 - # 镜像 eagle 路径 _build_eagle_accepted_draft_input 中清空 b_num_accepted_tokens 的处理。 draft_model_input = copy.copy(main_model_input) draft_model_input.b_num_accepted_tokens = None draft_model_output = main_model_output @@ -392,47 +391,46 @@ def _draft_decode_eagle( mtp_accept_len: torch.Tensor, b_req_mtp_start_loc: torch.Tensor, ): - num_reqs = b_req_mtp_start_loc.shape[0] + batch_size = main_model_input.batch_size + num_reqs = batch_size // (self.mtp_step + 1) if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_reqs * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(num_reqs * self.mtp_step) eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) - (draft_model_input, draft_next_token_ids, accepted_req_idx,) = self._build_eagle_accepted_draft_input( - main_model_input=main_model_input, - main_model_output=main_model_output, - next_token_ids=next_token_ids, - mtp_accept_len=mtp_accept_len, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) + # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, + # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型 decode_mtp 设置的 + # verify 布局,命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + draft_model_input = copy.copy(main_model_input) + draft_model_input.b_num_accepted_tokens = None draft_model_output = main_model_output + draft_next_token_ids = next_token_ids all_next_token_ids = [] - all_next_token_ids.append(draft_next_token_ids) - - mtp_size = self.mtp_step + 1 - main_mem_indexes = main_model_input.mem_indexes.view(num_reqs, mtp_size) - eagle_mem_indexes_by_req = eagle_mem_indexes.view(self.mtp_step, num_reqs).transpose(0, 1).contiguous() - mem_index_plan = torch.cat([main_mem_indexes, eagle_mem_indexes_by_req], dim=1) - accepted_offsets = mtp_accept_len.long() - 1 - req_offsets = torch.arange(num_reqs, dtype=torch.long, device=mtp_accept_len.device) - + all_next_token_ids.append(next_token_ids) + # process the draft model output for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids - if _step > 0: - draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens - draft_model_input.mem_indexes = mem_index_plan[req_offsets, accepted_offsets + _step] + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 + eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] + draft_model_input.mem_indexes = torch.cat( + [draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], + dim=1, + ).view(-1) all_next_token_ids.append(draft_next_token_ids) all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - self._scatter_accepted_next_token_ids( - accepted_req_idx=accepted_req_idx, + mtp_scatter_next_token_ids( + req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, all_next_token_ids=all_next_token_ids, + b_req_idx=main_model_input.b_req_idx, + mtp_accept_len=mtp_accept_len, ) return eagle_mem_indexes_cpu diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 43ed89d691..b245ac33d8 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -585,64 +585,6 @@ def _draft_decode_vanilla( ) return None - def _build_padding_draft_input(self, model_input: ModelInput, model_output: ModelOutput, common_req_num: int): - """ - 构造一个纯 padding 的 draft 输入,用于本 rank 没有真实 decode 请求 (real_req_num == 0) - 但其它 dp rank 仍有请求、需要本 rank 同步参与 mtp_step 次 draft forward 的集合通信的场景。 - - 从已 padding 的 main model_input 中按 (mtp_step+1) 分组取每组首行 (mtp_index==0) 即可, - 这些行均为 HOLD_REQUEST_ID / HOLD_TOKEN_MEMINDEX 的占位行。step0 的 hiddens 沿用主模型 - 对应占位行的 mtp_main_output_hiddens, 与原 DP 实现 (step0 使用 model_output.mtp_main_output_hiddens) - 保持一致, 避免 None 触发 draft forward 崩溃。 - """ - mtp_size = self.mtp_step + 1 - select_idx = torch.arange(common_req_num, dtype=torch.long, device=model_input.b_req_idx.device) * mtp_size - - draft_model_input = copy.copy(model_input) - draft_model_input.batch_size = common_req_num - draft_model_input.total_token_num = common_req_num * model_input.max_kv_seq_len - draft_model_input.b_num_accepted_tokens = None - draft_model_input.b_req_idx = model_input.b_req_idx.index_select(0, select_idx) - draft_model_input.b_mtp_index = model_input.b_mtp_index.index_select(0, select_idx) - draft_model_input.b_seq_len = model_input.b_seq_len.index_select(0, select_idx) - draft_model_input.mem_indexes = model_input.mem_indexes.index_select(0, select_idx) - draft_model_input.mem_indexes_cpu = None - draft_model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens.index_select(0, select_idx) - draft_model_input.multimodal_params = [{"images": [], "audios": []} for _ in range(common_req_num)] - return draft_model_input - - def _pad_draft_input_to(self, draft_model_input: ModelInput, target_req_num: int): - """ - 将 shrink 到 real_req_num 行的 draft 输入再 padding 回 target_req_num (= common_req_num) 行, - 使本 rank 的 draft forward 行数与其它 dp rank 对齐,保证 MoE all-to-all / dp all-gather 的 - shape 一致。padding 行采用与 padded_prepare_decode_inputs 相同的占位约定: - b_req_idx -> HOLD_REQUEST_ID, mem_indexes -> HOLD_TOKEN_MEMINDEX。 - """ - cur_req_num = draft_model_input.batch_size - pad_num = target_req_num - cur_req_num - if pad_num <= 0: - return draft_model_input - - hold_req_id = g_infer_context.req_manager.HOLD_REQUEST_ID - hold_mem_idx = g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX - - draft_model_input.input_ids = F.pad(draft_model_input.input_ids, (0, pad_num), value=0) - draft_model_input.b_req_idx = F.pad(draft_model_input.b_req_idx, (0, pad_num), value=hold_req_id) - draft_model_input.b_mtp_index = F.pad(draft_model_input.b_mtp_index, (0, pad_num), value=0) - # padding 行用一个合法的小 seq_len (沿用 padded_prepare_decode_inputs 中 fake req 的约定值 2) - draft_model_input.b_seq_len = F.pad(draft_model_input.b_seq_len, (0, pad_num), value=2) - draft_model_input.mem_indexes = F.pad(draft_model_input.mem_indexes, (0, pad_num), value=hold_mem_idx) - # mtp_draft_input_hiddens 为 (rows, hidden),沿 dim0 在尾部补 0 行 - draft_model_input.mtp_draft_input_hiddens = F.pad( - draft_model_input.mtp_draft_input_hiddens, (0, 0, 0, pad_num), value=0 - ) - draft_model_input.multimodal_params = draft_model_input.multimodal_params + [ - {"images": [], "audios": []} for _ in range(pad_num) - ] - draft_model_input.batch_size = target_req_num - draft_model_input.total_token_num = target_req_num * draft_model_input.max_kv_seq_len - return draft_model_input - def _draft_decode_eagle( self, model_input: ModelInput, @@ -652,65 +594,58 @@ def _draft_decode_eagle( mtp_accept_len: torch.Tensor, req_num: int, ): - mtp_size = self.mtp_step + 1 - real_req_num = req_num // mtp_size - common_req_num = model_input.batch_size // mtp_size - padded_req_num = common_req_num - real_req_num + all_next_token_ids = [] + # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, + # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型设置的 verify 布局, + # 命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + draft_model_input = copy.copy(model_input) + draft_model_input.b_num_accepted_tokens = None + draft_model_output = model_output + all_next_token_ids.append(next_token_ids) + draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") + if req_num > 0: + draft_next_token_ids_gpu[:req_num].copy_(next_token_ids, non_blocking=True) - # 即使本 rank 没有真实请求, 也要为其它 rank 同步运行 mtp_step 次 draft forward 的集合通信。 + real_req_num = req_num // (self.mtp_step + 1) + padded_req_num = model_input.batch_size // (self.mtp_step + 1) - real_req_num if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) - if real_req_num > 0: - (draft_model_input, draft_next_token_ids, accepted_req_idx,) = self._build_eagle_accepted_draft_input( - main_model_input=model_input, - main_model_output=model_output, - next_token_ids=next_token_ids, - mtp_accept_len=mtp_accept_len, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) - if padded_req_num > 0: - draft_model_input = self._pad_draft_input_to(draft_model_input, common_req_num) - draft_next_token_ids = F.pad(draft_next_token_ids, (0, padded_req_num), value=0) - - main_mem_indexes = model_input.mem_indexes.view(common_req_num, mtp_size) - eagle_padded = F.pad( - eagle_mem_indexes.view(self.mtp_step, real_req_num).transpose(0, 1).contiguous(), - (0, 0, 0, padded_req_num), - value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - ) # (common_req_num, mtp_step) - mem_index_plan = torch.cat([main_mem_indexes, eagle_padded], dim=1) - accepted_offsets = F.pad(mtp_accept_len.long() - 1, (0, padded_req_num), value=0) - req_offsets = torch.arange(common_req_num, dtype=torch.long, device=mem_index_plan.device) - else: - # 本 rank 无真实请求: 纯 padding draft 输入, 仅用于跟随集合通信, 结果不写回。 - draft_model_input = self._build_padding_draft_input(model_input, model_output, common_req_num) - draft_next_token_ids = torch.zeros((common_req_num,), dtype=torch.int64, device="cuda") - mem_index_plan = model_input.mem_indexes.view(common_req_num, mtp_size) - accepted_offsets = torch.zeros((common_req_num,), dtype=torch.long, device=mem_index_plan.device) - req_offsets = torch.arange(common_req_num, dtype=torch.long, device=mem_index_plan.device) - - draft_model_output = model_output - all_next_token_ids = [draft_next_token_ids] + # process the draft model output for _step in range(self.mtp_step): - draft_model_input.input_ids = draft_next_token_ids - if _step > 0: - draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens - draft_model_input.mem_indexes = mem_index_plan[req_offsets, accepted_offsets + _step] + draft_model_input.input_ids = draft_next_token_ids_gpu + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens + # spec decode: MTP draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + # update the meta info of the inference draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 - all_next_token_ids.append(draft_next_token_ids) + eagle_mem_indexes_i = eagle_mem_indexes[_step * real_req_num : (_step + 1) * real_req_num] + eagle_mem_indexes_i = F.pad( + input=eagle_mem_indexes_i, + pad=(0, padded_req_num), + mode="constant", + value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, + ) + draft_model_input.mem_indexes = torch.cat( + [draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], + dim=1, + ).view(-1) + draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output) + all_next_token_ids.append(draft_next_token_ids_gpu) - if real_req_num > 0: - all_next_token_ids = torch.stack(all_next_token_ids, dim=1)[:real_req_num, :] - self._scatter_accepted_next_token_ids( - accepted_req_idx=accepted_req_idx[:real_req_num], + if req_num > 0: + all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] + all_next_token_ids = all_next_token_ids[0:req_num, :] + mtp_scatter_next_token_ids( + req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, all_next_token_ids=all_next_token_ids, + b_req_idx=model_input.b_req_idx[:req_num], + mtp_accept_len=mtp_accept_len, ) return eagle_mem_indexes_cpu diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index ff31133ae2..f2c21ea261 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -269,8 +269,6 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ accept_len = 1 is_eagle = args.mtp_mode.startswith("eagle") model_input.b_num_accepted_tokens = torch.full((batch_size,), accept_len, dtype=torch.int32, device="cuda") - req_offsets = torch.arange(batch_size, dtype=torch.long, device="cuda") - accepted_row_idx = req_offsets * (mtp_step + 1) + (accept_len - 1) if is_eagle: # EAGLE draft scratch slots (n_real * mtp_step), mirroring _draft_decode_eagle. Allocated # once and reused across steps (throughput bench overwrites draft KV; no correctness check). @@ -298,28 +296,14 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ predict_ids = torch.argmax(model_output.logits, dim=1, keepdim=True) if is_eagle: - # EAGLE draft: shrink to the single accepted row per request (1 row/req), then run the - # draft model mtp_step times. The Qwen3.5 MTP draft is full-attention and takes the - # plain decode layout (b_num_accepted_tokens=None). Mirrors chunked_prefill - # _build_eagle_accepted_draft_input + _draft_decode_eagle so the measured draft cost is - # the real n_real-row cost, not the (mtp_step+1)x-inflated full-batch cost. - main_mem = model_input.mem_indexes.view(batch_size, mtp_step + 1) - eagle_mem_by_req = eagle_mem_indexes.view(mtp_step, batch_size).transpose(0, 1).contiguous() - mem_index_plan = torch.cat([main_mem, eagle_mem_by_req], dim=1) - + # EAGLE draft: full (mtp_step+1)-expanded batch, plain decode layout (the Qwen3.5 MTP + # draft is full-attention and takes b_num_accepted_tokens=None). Mirrors chunked_prefill + # _draft_decode_eagle: run the draft model mtp_step times, allocating fresh KV slots and + # shifting mem_indexes one column per step. draft_model_input = copy.copy(model_input) - draft_model_input.batch_size = batch_size - draft_model_input.total_token_num = batch_size * model_input.max_kv_seq_len draft_model_input.b_num_accepted_tokens = None - draft_model_input.mem_indexes_cpu = None - draft_model_input.b_req_idx = model_input.b_req_idx.index_select(0, accepted_row_idx) - draft_model_input.b_seq_len = model_input.b_seq_len.index_select(0, accepted_row_idx) - draft_model_input.b_mtp_index = model_input.b_mtp_index.index_select(0, accepted_row_idx) - draft_model_input.input_ids = predict_ids.reshape(-1).index_select(0, accepted_row_idx) - draft_model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens.index_select( - 0, accepted_row_idx - ) - draft_model_input.multimodal_params = [{"images": [], "audios": []} for _ in range(batch_size)] + draft_model_input.input_ids = predict_ids.reshape(-1) + draft_model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens if _mtp_profile and not warmup: _step_evs = [] @@ -328,7 +312,6 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ _host_t0 = time.time() _ev_d0.record() for _step in range(mtp_step): - draft_model_input.mem_indexes = mem_index_plan[req_offsets, (accept_len - 1) + _step] draft_model = draft_models[_step % num_instances] if _mtp_profile and not warmup: _es = torch.cuda.Event(enable_timing=True) @@ -342,6 +325,11 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ draft_model_input.mtp_draft_input_hiddens = draft_output.mtp_main_output_hiddens draft_model_input.b_seq_len = draft_model_input.b_seq_len + 1 draft_model_input.max_kv_seq_len += 1 + eagle_mem_indexes_i = eagle_mem_indexes[_step * batch_size : (_step + 1) * batch_size] + draft_model_input.mem_indexes = torch.cat( + [draft_model_input.mem_indexes.view(-1, mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], + dim=1, + ).view(-1) if _mtp_profile and not warmup: _ev_d1.record() _host_t1 = time.time() diff --git a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py index 54d61f18c7..61908d37c0 100644 --- a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py +++ b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py @@ -68,14 +68,13 @@ def alloc(self, size): assert model_input.mtp_draft_input_hiddens.shape == (5, 4) -def test_mtp_decode_cuda_graph_keys_verify_and_normal_layouts(): +def test_mtp_decode_cuda_graph_keys_distinguish_verify_and_normal(): from lightllm.common.basemodel.cuda_graph import CudaGraph graph = CudaGraph.__new__(CudaGraph) graph.mtp_step = 2 graph.graph = {} - graph.normal_cuda_graph_batch_sizes = [1, 2, 4, 8] - graph.mtp_verify_cuda_graph_batch_sizes = [3, 6, 9, 12] + graph.cuda_graph_batch_sizes = [3, 6, 9, 12] verify_state = SimpleNamespace( input_ids=torch.ones(6, dtype=torch.int64), @@ -86,14 +85,15 @@ def test_mtp_decode_cuda_graph_keys_verify_and_normal_layouts(): b_num_accepted_tokens=None, ) + # Same batch size, but the verify and normal decodes get distinct graph keys. assert graph._decode_graph_key(verify_state) == (6, True) assert graph._decode_graph_key(normal_state) == (6, False) - assert graph.find_closest_graph_batch_size(5, is_mtp_verify_decode=True) == 6 - assert graph.find_closest_graph_batch_size(5, is_mtp_verify_decode=False) == 8 + assert graph.find_closest_graph_batch_size(5) == 6 + # A captured verify graph does not satisfy a normal-graph capture need at the same batch size. graph.graph[(6, True)] = "verify graph" assert graph.need_capture(6, is_mtp_verify_decode=True) is False - assert graph.need_capture(5, is_mtp_verify_decode=False) is True + assert graph.need_capture(6, is_mtp_verify_decode=False) is True def test_mtp_decode_cuda_graph_warmup_layouts_split_main_and_draft_models(): @@ -107,11 +107,11 @@ class Qwen3_5MoeMTPModel: graph = CudaGraph.__new__(CudaGraph) graph.mtp_step = 2 - graph.normal_cuda_graph_batch_sizes = [1, 2, 4, 8] - graph.mtp_verify_cuda_graph_batch_sizes = [3, 6, 9] + graph.cuda_graph_batch_sizes = [3, 6, 9] + # Same batch-size set for both; the main model warms up the verify layout, the draft the normal. assert list(graph._iter_warmup_graph_layouts(Qwen3_5MOETpPartModel())) == [(True, [3, 6, 9])] - assert list(graph._iter_warmup_graph_layouts(Qwen3_5MoeMTPModel())) == [(False, [1, 2, 4, 8])] + assert list(graph._iter_warmup_graph_layouts(Qwen3_5MoeMTPModel())) == [(False, [3, 6, 9])] def test_mtp_decode_warmup_layout_marks_qwen3next_verify(monkeypatch): @@ -247,147 +247,3 @@ def fake_page_table_copy(page_table, req_to_token_indexs, b_req_idx): assert state.b_att_seq_len.tolist() == [5, 7] assert copied["page_table_shape"] == (2, 16) assert copied["b_req_idx"].tolist() == [3, 4] - - -def test_build_eagle_accepted_draft_input_narrows_to_accepted_rows(): - from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput - from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ( - ChunkedPrefillBackend, - ) - - backend = ChunkedPrefillBackend.__new__(ChunkedPrefillBackend) - backend.mtp_step = 2 - - main_input = ModelInput( - batch_size=6, - total_token_num=27, - max_q_seq_len=1, - max_kv_seq_len=9, - input_ids=torch.tensor([10, 11, 12, 20, 21, 22], dtype=torch.int64), - mem_indexes=torch.tensor([100, 101, 102, 200, 201, 202], dtype=torch.int32), - b_req_idx=torch.tensor([3, 3, 3, 4, 4, 4], dtype=torch.int32), - b_mtp_index=torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32), - b_seq_len=torch.tensor([5, 6, 7, 6, 7, 8], dtype=torch.int32), - b_num_accepted_tokens=torch.tensor([1, 1], dtype=torch.int32), - is_prefill=False, - multimodal_params=[ - {"row": 0}, - {"row": 1}, - {"row": 2}, - {"row": 3}, - {"row": 4}, - {"row": 5}, - ], - ) - hidden = torch.arange(6 * 4, dtype=torch.float32).view(6, 4) - main_output = ModelOutput(logits=torch.empty(6, 8), mtp_main_output_hiddens=hidden) - next_token_ids = torch.tensor([110, 111, 112, 220, 221, 222], dtype=torch.int64) - b_req_mtp_start_loc = torch.tensor([0, 3], dtype=torch.int32) - mtp_accept_len = torch.tensor([2, 3], dtype=torch.int32) - - (draft_input, accepted_next_tokens, accepted_req_idx,) = backend._build_eagle_accepted_draft_input( - main_model_input=main_input, - main_model_output=main_output, - next_token_ids=next_token_ids, - mtp_accept_len=mtp_accept_len, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) - - assert draft_input.batch_size == 2 - assert draft_input.input_ids.tolist() == [111, 222] - assert draft_input.b_req_idx.tolist() == [3, 4] - assert draft_input.b_mtp_index.tolist() == [1, 2] - assert draft_input.b_seq_len.tolist() == [6, 8] - assert draft_input.mem_indexes.tolist() == [101, 202] - assert draft_input.b_num_accepted_tokens is None - assert draft_input.multimodal_params == [{"row": 1}, {"row": 5}] - assert accepted_next_tokens.tolist() == [111, 222] - assert accepted_req_idx.tolist() == [3, 4] - torch.testing.assert_close(draft_input.mtp_draft_input_hiddens, hidden[[1, 5]]) - - -def test_eagle_draft_decode_uses_narrowed_hidden_on_first_step(monkeypatch): - import lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl as chunked_impl - from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput - from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ( - ChunkedPrefillBackend, - ) - - class FakeMemManager: - HOLD_TOKEN_MEMINDEX = -1 - - def alloc(self, need_size): - return torch.arange(300, 300 + need_size, dtype=torch.int32) - - req_to_next_token_ids = torch.empty((8, 3), dtype=torch.int64) - monkeypatch.setattr( - chunked_impl, - "g_infer_context", - SimpleNamespace( - radix_cache=None, - req_manager=SimpleNamespace( - mem_manager=FakeMemManager(), - req_sampling_params_manager=SimpleNamespace(req_to_next_token_ids=req_to_next_token_ids), - ), - ), - ) - monkeypatch.setattr(torch.Tensor, "cuda", lambda self, non_blocking=False: self) - - backend = ChunkedPrefillBackend.__new__(ChunkedPrefillBackend) - backend.mtp_step = 2 - backend.num_mtp_models = 1 - - seen_hiddens = [] - - class FakeDraftModel: - def forward(self, model_input): - seen_hiddens.append(model_input.mtp_draft_input_hiddens.clone()) - logits = torch.zeros((model_input.batch_size, 8), dtype=torch.float32) - return ModelOutput( - logits=logits, - mtp_main_output_hiddens=model_input.mtp_draft_input_hiddens + 100, - ) - - backend.draft_models = [FakeDraftModel()] - - scattered = {} - - def fake_scatter(accepted_req_idx, all_next_token_ids): - scattered["accepted_req_idx"] = accepted_req_idx.clone() - scattered["all_next_token_ids"] = all_next_token_ids.clone() - - backend._scatter_accepted_next_token_ids = fake_scatter - - main_input = ModelInput( - batch_size=6, - total_token_num=27, - max_q_seq_len=1, - max_kv_seq_len=9, - input_ids=torch.tensor([10, 11, 12, 20, 21, 22], dtype=torch.int64), - mem_indexes=torch.tensor([100, 101, 102, 200, 201, 202], dtype=torch.int32), - b_req_idx=torch.tensor([3, 3, 3, 4, 4, 4], dtype=torch.int32), - b_mtp_index=torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32), - b_seq_len=torch.tensor([5, 6, 7, 6, 7, 8], dtype=torch.int32), - b_num_accepted_tokens=torch.tensor([1, 1], dtype=torch.int32), - is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(6)], - ) - hidden = torch.arange(6 * 4, dtype=torch.float32).view(6, 4) - main_output = ModelOutput(logits=torch.empty(6, 8), mtp_main_output_hiddens=hidden) - next_token_ids = torch.tensor([110, 111, 112, 220, 221, 222], dtype=torch.int64) - b_req_mtp_start_loc = torch.tensor([0, 3], dtype=torch.int32) - mtp_accept_len = torch.tensor([2, 3], dtype=torch.int32) - - returned_mem = backend._draft_decode_eagle( - main_model_input=main_input, - main_model_output=main_output, - next_token_ids=next_token_ids, - mtp_accept_len=mtp_accept_len, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) - - assert returned_mem.tolist() == [300, 301, 302, 303] - torch.testing.assert_close(seen_hiddens[0], hidden[[1, 5]]) - torch.testing.assert_close(seen_hiddens[1], hidden[[1, 5]] + 100) - assert scattered["accepted_req_idx"].tolist() == [3, 4] - assert scattered["all_next_token_ids"].shape == (2, 3) From 7c08d910065a47619e6d79327fa7a9054b38e05a Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 15 Jun 2026 23:23:13 +0800 Subject: [PATCH 10/19] revert(mtp): run the MTP draft on upstream's grouped verify layout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the remaining draft-side divergence from upstream so this branch is scoped to Qwen3.5 MTP support only. The draft decode no longer clears b_num_accepted_tokens to force a flat/normal layout; it reuses the main model_input (still copy.copy'd to isolate per-step input_ids/b_seq_len/ mem_indexes mutations) and runs the same (mtp_step+1)-grouped verify decode layout as the main model — exactly as upstream does. For the pure-full-attention draft (qwen3_5_mtp: full_attention_interval=1, no GDN) grouped and flat are numerically identical: each position k sees KV [0, s+k) either way, same page-table entries, same RoPE positions; the main verify forward already uses this geometry and is the validated path. The earlier flat-draft only added an unnecessary (bs, False) cudagraph layout + b_num_accepted_tokens gating; nothing the draft computes needs it. - chunked_prefill/dp_backend: 6 draft fns (vanilla/eagle + dp overlap variants) stop clearing b_num_accepted_tokens. - cuda_graph: draft warms up the verify graph key too (mtp_step>0 -> verify for both main and draft); delete the now-dead _is_mtp_draft_model. - tests: rewrite the warmup-layout test (main+draft both verify; mtp_step==0 -> normal) and drop the stale "draft uses normal layout" framing. Keep is_mtp_verify_decode (main-model GDN verify still needs it) and the committed fp8.py causal=True fix. Verified live (QW35-122B-A10B, eagle_with_att, mtp_step=1, tp4): GSM8K acc 0.964 / Invalid 0.000, accept 1.956/2.0 — matches pre-revert baseline (no regression). Codex independent pass concurred (high confidence). --- lightllm/common/basemodel/cuda_graph.py | 11 +++--- .../mode_backend/chunked_prefill/impl.py | 17 +++++----- .../mode_backend/dp_backend/impl.py | 34 +++++++++---------- .../basemodel/test_mtp_decode_cuda_graph.py | 16 ++++++--- 4 files changed, 40 insertions(+), 38 deletions(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index ddf78f4c6f..5b0972f860 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -134,13 +134,12 @@ def _build_warmup_decode_model_input( **model._gen_special_model_input(batch_size), ) - def _is_mtp_draft_model(self, model): - return getattr(model, "is_mtp_draft_model", False) - def _iter_warmup_graph_layouts(self, model): - # main-model decode is a verify forward; the draft (MTP) model takes the normal layout. - # Both warm up over the same batch-size set; only the verify flag (graph key + layout) differs. - if self.mtp_step > 0 and not self._is_mtp_draft_model(model): + # Under MTP both the main verify forward and the (pure full-attention) draft forward run the + # (mtp_step+1)-grouped verify decode layout, so both warm up the verify graph key; only + # mtp_step == 0 models use the normal layout. (Matches upstream: the draft reuses the main + # model_input and keeps b_num_accepted_tokens, so its decode is a verify forward too.) + if self.mtp_step > 0: yield True, self.cuda_graph_batch_sizes else: yield False, self.cuda_graph_batch_sizes diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 2a9a2e7492..7b7242ad0f 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -353,12 +353,11 @@ def _draft_decode_vanilla( mtp_accept_len: torch.Tensor, b_req_mtp_start_loc: torch.Tensor, ): - # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, - # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型 decode_mtp 设置的 - # verify 布局,命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError - # (cudagraph 关闭时则会在扁平的 draft batch 上误用 S+1 分组的 verify attention)。 + # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, + # 避免污染之后仍要用到的 main_model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, + # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 + # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 draft_model_input = copy.copy(main_model_input) - draft_model_input.b_num_accepted_tokens = None draft_model_output = main_model_output draft_next_token_ids = next_token_ids all_next_token_ids = [] @@ -398,11 +397,11 @@ def _draft_decode_eagle( eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(num_reqs * self.mtp_step) eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) - # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, - # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型 decode_mtp 设置的 - # verify 布局,命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, + # 避免污染之后仍要用到的 main_model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, + # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 + # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 draft_model_input = copy.copy(main_model_input) - draft_model_input.b_num_accepted_tokens = None draft_model_output = main_model_output draft_next_token_ids = next_token_ids all_next_token_ids = [] diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index b245ac33d8..18f60b4934 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -552,11 +552,11 @@ def _draft_decode_vanilla( req_num: int, ): all_next_token_ids = [] - # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, - # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型设置的 verify 布局, - # 命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, + # 避免污染之后仍要用到的 model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, + # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 + # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 draft_model_input = copy.copy(model_input) - draft_model_input.b_num_accepted_tokens = None draft_model_output = model_output draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") if req_num > 0: @@ -595,11 +595,11 @@ def _draft_decode_eagle( req_num: int, ): all_next_token_ids = [] - # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, - # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型设置的 verify 布局, - # 命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, + # 避免污染之后仍要用到的 model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, + # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 + # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 draft_model_input = copy.copy(model_input) - draft_model_input.b_num_accepted_tokens = None draft_model_output = model_output all_next_token_ids.append(next_token_ids) draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") @@ -913,13 +913,12 @@ def _draft_decode_vanilla_overlap( ): all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, - # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型设置的 verify 布局, - # 命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, + # 避免污染之后仍要用到的 model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, + # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 + # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 draft_model_input0 = copy.copy(model_input0) draft_model_input1 = copy.copy(model_input1) - draft_model_input0.b_num_accepted_tokens = None - draft_model_input1.b_num_accepted_tokens = None draft_model_output0, draft_model_output1 = model_output0, model_output1 draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda") @@ -975,13 +974,12 @@ def _draft_decode_eagle_overlap( ): all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - # share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens, - # 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型设置的 verify 布局, - # 命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError。 + # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, + # 避免污染之后仍要用到的 model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, + # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 + # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 draft_model_input0 = copy.copy(model_input0) draft_model_input1 = copy.copy(model_input1) - draft_model_input0.b_num_accepted_tokens = None - draft_model_input1.b_num_accepted_tokens = None draft_model_output0, draft_model_output1 = model_output0, model_output1 draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda") diff --git a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py index 61908d37c0..9d35cb53cc 100644 --- a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py +++ b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py @@ -34,7 +34,7 @@ def alloc(self, size): assert model_input.total_token_num == 18 -def test_mtp_decode_cuda_graph_warmup_supports_normal_layout_for_draft(): +def test_mtp_decode_cuda_graph_warmup_builds_normal_layout_when_not_verify(): from lightllm.common.basemodel.cuda_graph import CudaGraph graph = CudaGraph.__new__(CudaGraph) @@ -96,7 +96,7 @@ def test_mtp_decode_cuda_graph_keys_distinguish_verify_and_normal(): assert graph.need_capture(6, is_mtp_verify_decode=False) is True -def test_mtp_decode_cuda_graph_warmup_layouts_split_main_and_draft_models(): +def test_mtp_decode_cuda_graph_warmup_layouts_use_verify_for_main_and_draft(): from lightllm.common.basemodel.cuda_graph import CudaGraph class Qwen3_5MOETpPartModel: @@ -109,9 +109,15 @@ class Qwen3_5MoeMTPModel: graph.mtp_step = 2 graph.cuda_graph_batch_sizes = [3, 6, 9] - # Same batch-size set for both; the main model warms up the verify layout, the draft the normal. + # Under MTP both the main verify forward and the pure-full-attention draft forward run the + # (mtp_step+1)-grouped verify decode layout (the draft reuses the main model_input and keeps + # b_num_accepted_tokens), so both warm up the verify graph key over the same batch-size set. assert list(graph._iter_warmup_graph_layouts(Qwen3_5MOETpPartModel())) == [(True, [3, 6, 9])] - assert list(graph._iter_warmup_graph_layouts(Qwen3_5MoeMTPModel())) == [(False, [3, 6, 9])] + assert list(graph._iter_warmup_graph_layouts(Qwen3_5MoeMTPModel())) == [(True, [3, 6, 9])] + + # A non-MTP model (mtp_step == 0) warms up the normal layout instead. + graph.mtp_step = 0 + assert list(graph._iter_warmup_graph_layouts(Qwen3_5MOETpPartModel())) == [(False, [3, 6, 9])] def test_mtp_decode_warmup_layout_marks_qwen3next_verify(monkeypatch): @@ -205,7 +211,7 @@ def fake_base_init_cudagraph(self): assert model.graph == "captured" -def test_fa3_decode_uses_normal_layout_for_narrowed_mtp_draft(monkeypatch): +def test_fa3_decode_uses_normal_layout_when_no_accept_tensor(monkeypatch): import lightllm.common.basemodel.attention.fa3.fp as fa3_fp from lightllm.common.basemodel.attention.fa3.fp import Fa3DecodeAttState From 067c4c61b395b8469dd7ae5cb2424ea80595d4cf Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 16 Jun 2026 08:59:36 +0800 Subject: [PATCH 11/19] clean code --- .../common/basemodel/attention/fa3/fp8.py | 18 ++---- .../layer_weights/transformer_layer_weight.py | 62 ++++++++++++++++++- .../layer_weights/mtp_retarget_mixin.py | 61 ------------------ .../layer_weights/transformer_layer_weight.py | 61 +++++++++++++++++- .../mode_backend/chunked_prefill/impl.py | 5 +- 5 files changed, 126 insertions(+), 81 deletions(-) delete mode 100644 lightllm/models/qwen3_5_mtp/layer_weights/mtp_retarget_mixin.py diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index d85a1caf33..c1861aad2c 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -3,6 +3,7 @@ from ..base_att import AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops from typing import Union @@ -44,12 +45,9 @@ def init_state(self): torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len ) # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = ( - offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - ) - self.v_descale = ( - offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - ) + self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + def prefill_att( self, @@ -125,12 +123,8 @@ def init_state(self): head_num = mem_manager.head_num # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = ( - offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - ) - self.v_descale = ( - offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - ) + self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) return diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py index 554db359f6..3367645e83 100644 --- a/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py @@ -2,15 +2,73 @@ COLMMWeight, FusedMoeWeight, ROWMMWeight, + QKVROWNMMWeight, ) from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import ( Qwen35MOETransformerLayerWeight, ) -from lightllm.models.qwen3_5_mtp.layer_weights.mtp_retarget_mixin import MTPRetargetMixin from lightllm.utils.envs_utils import get_env_start_args -class Qwen3_5MoeMTPTransformerLayerWeight(MTPRetargetMixin, Qwen35MOETransformerLayerWeight): +class Qwen3_5MoeMTPTransformerLayerWeight(Qwen35MOETransformerLayerWeight): + # MTP draft-model weights live under the `mtp.layers.*` checkpoint namespace; the + # main-model attention/norm names (`model.layers.*`) are retargeted to it, while the + # MoE expert / shared-expert names are built directly with the mtp prefix below. + + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + _ATTN_NORM_NAME_ATTRS = ( + "_q_weight_name", + "_q_norm_name", + "_q_bias_name", + "_k_weight_name", + "_k_norm_name", + "_k_bias_name", + "_v_weight_name", + "_v_bias_name", + "_kv_weight_name", + "_kv_bias_name", + "_o_weight_name", + "_o_bias_name", + "_att_norm_weight_name", + "_att_norm_bias_name", + "_ffn_norm_weight_name", + "_ffn_norm_bias_name", + ) + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _retarget_attn_norm_names(self): + for attr in self._ATTN_NORM_NAME_ATTRS: + setattr(self, attr, self._retarget(getattr(self, attr))) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.kv_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + def _init_weight_names(self): super()._init_weight_names() self._retarget_attn_norm_names() diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/mtp_retarget_mixin.py b/lightllm/models/qwen3_5_mtp/layer_weights/mtp_retarget_mixin.py deleted file mode 100644 index cf9da94887..0000000000 --- a/lightllm/models/qwen3_5_mtp/layer_weights/mtp_retarget_mixin.py +++ /dev/null @@ -1,61 +0,0 @@ -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, QKVROWNMMWeight - - -class MTPRetargetMixin: - """Shared MTP weight-name retargeting (model.layers.* -> mtp.layers.*) and qkv/o_gate wiring, - used by both the dense and MoE Qwen3.5 MTP layer-weight classes (#11). The dense subclass adds - its dense-MLP retargets on top; the MoE subclass must not (it uses fused experts).""" - - _MAIN_PREFIX = "model.layers." - _MTP_PREFIX = "mtp.layers." - - _ATTN_NORM_NAME_ATTRS = ( - "_q_weight_name", - "_q_norm_name", - "_q_bias_name", - "_k_weight_name", - "_k_norm_name", - "_k_bias_name", - "_v_weight_name", - "_v_bias_name", - "_kv_weight_name", - "_kv_bias_name", - "_o_weight_name", - "_o_bias_name", - "_att_norm_weight_name", - "_att_norm_bias_name", - "_ffn_norm_weight_name", - "_ffn_norm_bias_name", - ) - - def _retarget(self, name): - if name is None: - return None - return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) - - def _retarget_attn_norm_names(self): - for attr in self._ATTN_NORM_NAME_ATTRS: - setattr(self, attr, self._retarget(getattr(self, attr))) - - def _init_qkv(self): - in_dim = self.n_embed - q_out_dim = self.q_head_num_ * self.head_dim - self.qkv_proj = QKVROWNMMWeight( - in_dim=in_dim, - q_head_num=self.q_head_num_, - kv_head_num=self.k_head_num_, - head_dim=self.head_dim, - weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], - data_type=self.data_type_, - bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("qkv_proj"), - ) - self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" - self._o_gate_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=[self._o_gate_weight_name], - data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("o_gate_proj"), - ) diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py index 5aa0724580..31f6df3f8c 100644 --- a/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py @@ -1,13 +1,70 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, QKVROWNMMWeight from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( Qwen35TransformerLayerWeight, ) -from lightllm.models.qwen3_5_mtp.layer_weights.mtp_retarget_mixin import MTPRetargetMixin from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) -class Qwen3_5MTPTransformerLayerWeight(MTPRetargetMixin, Qwen35TransformerLayerWeight): +class Qwen3_5MTPTransformerLayerWeight(Qwen35TransformerLayerWeight): + # MTP draft-model weights live under the `mtp.layers.*` checkpoint namespace, so every + # main-model layer name (`model.layers.*`) is retargeted to it at load time. + + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + _ATTN_NORM_NAME_ATTRS = ( + "_q_weight_name", + "_q_norm_name", + "_q_bias_name", + "_k_weight_name", + "_k_norm_name", + "_k_bias_name", + "_v_weight_name", + "_v_bias_name", + "_kv_weight_name", + "_kv_bias_name", + "_o_weight_name", + "_o_bias_name", + "_att_norm_weight_name", + "_att_norm_bias_name", + "_ffn_norm_weight_name", + "_ffn_norm_bias_name", + ) + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _retarget_attn_norm_names(self): + for attr in self._ATTN_NORM_NAME_ATTRS: + setattr(self, attr, self._retarget(getattr(self, attr))) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.kv_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + def _init_weight_names(self): super()._init_weight_names() # Retarget all main-model layer key names to the mtp.* namespace. diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 7b7242ad0f..cd1e14be73 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -252,10 +252,7 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) - # verify the next_token_ids. The chunked decode batch is the contiguous - # (mtp_step+1)-expanded layout, so request starts are structurally - # arange(n_real)*(mtp_step+1). Compute on device instead of a per-step Python - # list-comp + pinned pack + H2D (#22). + # verify the next_token_ids n_real = model_input.batch_size // (self.mtp_step + 1) b_req_mtp_start_loc = torch.arange(n_real, dtype=torch.int32, device="cuda") * (self.mtp_step + 1) From f71bcc96074e315a6e2f7306cc0018c00a252786 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 16 Jun 2026 11:06:10 +0800 Subject: [PATCH 12/19] clean code --- lightllm/common/basemodel/attention/fa3/fp.py | 11 +- .../common/basemodel/attention/fa3/mla.py | 11 +- lightllm/common/basemodel/basemodel.py | 107 ++++-------------- lightllm/common/basemodel/batch_objs.py | 7 -- lightllm/common/basemodel/cuda_graph.py | 103 +++++++---------- .../layer_weights/transformer_layer_weight.py | 2 +- .../layer_weights/transformer_layer_weight.py | 2 +- .../basemodel/test_mtp_decode_cuda_graph.py | 105 ++--------------- 8 files changed, 82 insertions(+), 266 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index a7395faebf..952bb39d91 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -7,7 +7,6 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor -from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn class Fa3AttBackend(BaseAttBackend): @@ -126,9 +125,8 @@ class Fa3DecodeAttState(BaseDecodeAttState): def init_state(self): self.backend: Fa3AttBackend = self.backend args_mtp_step = get_env_start_args().mtp_step - is_mtp_verify_decode = is_mtp_verify_decode_fn(args_mtp_step, self.infer_state.b_num_accepted_tokens) - if is_mtp_verify_decode: + if args_mtp_step > 0: # 修正 mtp 在 fa3 下的输入。 mtp_size = args_mtp_step + 1 b_q_seq_len = torch.full( @@ -145,9 +143,8 @@ def init_state(self): self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1 - att_batch_size = self.infer_state.batch_size // mtp_size - assert self.infer_state.batch_size % mtp_size == 0 + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 model = self.backend.model # 可以使用 cuda graph的时候从 buffer中申请 @@ -166,7 +163,7 @@ def init_state(self): device=self.infer_state.input_ids.device, ) - if is_mtp_verify_decode: + if args_mtp_step > 0: page_table_copy( page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], req_to_token_indexs=model.req_manager.req_to_token_indexs, diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index 982bd117c3..9a10457b12 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -8,7 +8,6 @@ from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor from lightllm.utils.sgl_utils import flash_attn_varlen_func -from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn class MlaFa3AttBackend(BaseAttBackend): @@ -109,9 +108,8 @@ class MlaFa3DecodeAttState(BaseDecodeAttState): def init_state(self): self.backend: MlaFa3AttBackend = self.backend args_mtp_step = get_env_start_args().mtp_step - is_mtp_verify_decode = is_mtp_verify_decode_fn(args_mtp_step, self.infer_state.b_num_accepted_tokens) - if is_mtp_verify_decode: + if args_mtp_step > 0: # 修正 mtp 在 fa3 下的输入。 mtp_size = args_mtp_step + 1 b_q_seq_len = torch.full( @@ -128,9 +126,8 @@ def init_state(self): self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1 - att_batch_size = self.infer_state.batch_size // mtp_size - assert self.infer_state.batch_size % mtp_size == 0 + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 model = self.backend.model # 可以使用 cuda graph的时候从 buffer中申请 @@ -149,7 +146,7 @@ def init_state(self): device=self.infer_state.input_ids.device, ) - if is_mtp_verify_decode: + if args_mtp_step > 0: page_table_copy( page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], req_to_token_indexs=model.req_manager.req_to_token_indexs, diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4d307ec2dd..1e1090fca0 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -18,17 +18,12 @@ from lightllm.common.req_manager import ReqManager from lightllm.common.infer_utils import init_req_to_token_indexes from lightllm.common.build_utils import repair_config -from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import ( - copy_kv_index_to_req, -) +from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg -from lightllm.common.basemodel.triton_kernel.gather_token_id import ( - gather_token, - gather_token_prefill_decode_mixed, -) +from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import ( @@ -38,13 +33,9 @@ ) from lightllm.distributed.communication_op import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput -from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn from lightllm.common.triton_utils.autotuner import AutotuneLevel from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch -from lightllm.utils.envs_utils import ( - set_model_init_status, - enable_diverse_mode_gqa_decode_fast_kernel, -) +from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache from .attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -369,9 +360,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) def _get_decode_padding_unit(self, model_input: ModelInput) -> int: padding_unit = self.tp_world_size_ if self.args.enable_tpsp_mix_mode else 1 - if (not model_input.is_prefill) and is_mtp_verify_decode_fn( - self.args.mtp_step, model_input.b_num_accepted_tokens - ): + if (not model_input.is_prefill) and self.args.mtp_step > 0: padding_unit = math.lcm(padding_unit, self.args.mtp_step + 1) return padding_unit @@ -389,10 +378,8 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input = copy.copy(model_input) new_model_input.batch_size = new_batch_size - is_mtp_verify_decode = (not model_input.is_prefill) and is_mtp_verify_decode_fn( - self.args.mtp_step, model_input.b_num_accepted_tokens - ) - if is_mtp_verify_decode: + is_mtp_grouped_decode = (not model_input.is_prefill) and self.args.mtp_step > 0 + if is_mtp_grouped_decode: mtp_size = self.args.mtp_step + 1 assert model_input.batch_size % mtp_size == 0 assert new_batch_size % mtp_size == 0 @@ -500,17 +487,11 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s if enable_diverse_mode_gqa_decode_fast_kernel(): if new_model_input.b_shared_seq_len is not None: new_model_input.b_shared_seq_len = F.pad( - new_model_input.b_shared_seq_len, - (0, padded_batch_size), - mode="constant", - value=0, + new_model_input.b_shared_seq_len, (0, padded_batch_size), mode="constant", value=0 ) if new_model_input.b_mark_shared_group is not None: new_model_input.b_mark_shared_group = F.pad( - new_model_input.b_mark_shared_group, - (0, padded_batch_size), - mode="constant", - value=1, + new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=1 ) # 特殊模型,特殊模式的特殊变量的特殊 padding @@ -545,10 +526,7 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle value=self.mem_manager.HOLD_TOKEN_MEMINDEX, ) new_model_input.b_req_idx = F.pad( - new_model_input.b_req_idx, - (0, 1), - mode="constant", - value=self.req_manager.HOLD_REQUEST_ID, + new_model_input.b_req_idx, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID ) new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0) new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num) @@ -588,10 +566,7 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba return new_model_output def _create_unpad_prefill_model_output( - self, - padded_model_output: ModelOutput, - origin_handle_token_num: int, - origin_batch_size: int, + self, padded_model_output: ModelOutput, origin_handle_token_num: int, origin_batch_size: int ): if self.return_all_prompt_logics: new_model_output = copy.copy(padded_model_output) @@ -678,7 +653,6 @@ def _decode( origin_batch_size = model_input.batch_size infer_batch_size = self._get_decode_infer_batch_size(model_input) - is_mtp_verify_decode = is_mtp_verify_decode_fn(self.args.mtp_step, model_input.b_num_accepted_tokens) if self.graph is not None and self.graph.can_run( batch_size=infer_batch_size, max_len_in_batch=model_input.max_kv_seq_len @@ -697,7 +671,7 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(infer_batch_size, is_mtp_verify_decode=is_mtp_verify_decode): + if self.graph.need_capture(infer_batch_size): infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: @@ -930,7 +904,6 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode origin_batch_size = model_input0.batch_size max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) infer_batch_size = self._get_decode_infer_batch_size(model_input0) - is_mtp_verify_decode = is_mtp_verify_decode_fn(self.args.mtp_step, model_input0.b_num_accepted_tokens) if self.graph is not None and self.graph.can_run(infer_batch_size, max_len_in_batch): infer_batch_size = self.graph.find_closest_graph_batch_size(infer_batch_size) @@ -958,7 +931,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.init_some_extra_state(self) infer_state1.init_att_state() - if self.graph.need_capture(infer_batch_size, is_mtp_verify_decode=is_mtp_verify_decode): + if self.graph.need_capture(infer_batch_size): infer_state0.is_cuda_graph = True infer_state1.is_cuda_graph = True @@ -1010,11 +983,7 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state g_cache_manager.cache_env_in() input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( - infer_state.input_ids, - infer_state1.input_ids, - infer_state, - infer_state1, - self.pre_post_weight, + infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight ) # 决定是否进行 dp balance 优化,可以提升dp > 1 时的 prefill 效率。 @@ -1030,11 +999,7 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( - input_embs, - input_embs1, - infer_state, - infer_state1, - self.trans_layers_weight[i], + input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] ) # 折叠模式调用完infer_state 和 infer_state1 上的hook函数后,input_embs 和 input_embs1 才具备正确的运算数据。 @@ -1048,11 +1013,7 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state last_input_embs1 = infer_state1._all_to_all_unbalance_get(data=last_input_embs1) predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( - last_input_embs, - last_input_embs1, - infer_state, - infer_state1, - self.pre_post_weight, + last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight ) g_cache_manager.cache_env_out() @@ -1073,22 +1034,14 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state @final def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: InferStateInfo): input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward( - infer_state.input_ids, - infer_state1.input_ids, - infer_state, - infer_state1, - self.pre_post_weight, + infer_state.input_ids, infer_state1.input_ids, infer_state, infer_state1, self.pre_post_weight ) input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) input_embs1 = self.pre_infer._tpsp_sp_split(input=input_embs1, infer_state=infer_state1) for i in range(self.layers_num): input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward( - input_embs, - input_embs1, - infer_state, - infer_state1, - self.trans_layers_weight[i], + input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] ) # 折叠模式调用完infer_state 上的hook函数后,input_embs 和 input_embs 才具备正确的运算数据。 @@ -1099,11 +1052,7 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: last_input_embs1 = self.post_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( - last_input_embs, - last_input_embs1, - infer_state, - infer_state1, - self.pre_post_weight, + last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight ) model_output = ModelOutput(logits=predict_logits.contiguous()) @@ -1210,12 +1159,7 @@ def _autotune_warmup(self): rand_gen = torch.Generator(device="cuda") rand_gen.manual_seed(input_len) dummy_input_ids = torch.randint( - 0, - 10000, - (input_len,), - dtype=torch.int32, - device="cuda", - generator=rand_gen, + 0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen ) b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() @@ -1279,14 +1223,10 @@ def _init_padded_req(self): batch_size = 1 dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") b_req_idx = torch.tensor( - [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], - dtype=torch.int32, - device="cuda", + [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) mem_indexes = torch.tensor( - [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], - dtype=torch.int32, - device="cuda", + [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") @@ -1333,10 +1273,7 @@ def _gen_special_model_input(self, token_num: int): is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( - token_num, - self.config["hidden_size"], - dtype=self.data_type, - device="cuda", + token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" ) else: special_model_input["mtp_draft_input_hiddens"] = None diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 6104022733..03cb36d28d 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -6,13 +6,6 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor -def is_mtp_verify_decode(mtp_step: int, b_num_accepted_tokens) -> bool: - """Single source of truth for the MTP verify-decode predicate (#21). - A decode forward is a verify pass iff MTP is enabled and the per-real-request accept tensor is - present — decode_mtp sets it on the main verify and clears it (None) on every draft forward.""" - return mtp_step > 0 and b_num_accepted_tokens is not None - - @dataclass class ModelInput: # 通用变量 diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 5b0972f860..e2ba362f45 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -9,7 +9,6 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput -from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn from .infer_struct import InferStateInfo @@ -30,9 +29,8 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap # With MTP enabled, both the main-model verify forward and the draft (MTP) forward run over - # the (mtp_step+1)-expanded decode layout, so all decode batch sizes are multiples of - # (mtp_step+1); a single graph batch-size set serves both. Verify vs normal graphs are told - # apart by the is_mtp_verify_decode component of the graph key, not by a separate set. + # the (mtp_step+1)-expanded decode layout, so every decode batch size is a multiple of + # (mtp_step+1) and there is a single decode layout — the graph is keyed by batch size alone. batch_size_multiple = self.mtp_step + 1 if self.mtp_step > 0 else 1 self.cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes(batch_size_multiple=batch_size_multiple) logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") @@ -65,13 +63,12 @@ def can_run(self, batch_size, max_len_in_batch): return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch def _decode_graph_key(self, infer_state: InferStateInfo): - is_mtp_verify_decode = is_mtp_verify_decode_fn(self.mtp_step, infer_state.b_num_accepted_tokens) - return (infer_state.input_ids.shape[0], is_mtp_verify_decode) + return infer_state.input_ids.shape[0] - def need_capture(self, batch_size, is_mtp_verify_decode=False): + def need_capture(self, batch_size): find_batch_size = self.find_closest_graph_batch_size(batch_size) if find_batch_size is not None: - return (find_batch_size, is_mtp_verify_decode) not in self.graph + return find_batch_size not in self.graph else: assert False, "dead code" @@ -88,11 +85,7 @@ def _build_warmup_decode_model_input( model, batch_size: int, device: str = "cuda", - is_mtp_verify_decode: Optional[bool] = None, ) -> ModelInput: - if is_mtp_verify_decode is None: - is_mtp_verify_decode = self.mtp_step > 0 - mtp_size = self.mtp_step + 1 input_ids = torch.ones(batch_size, dtype=torch.int32, device=device) mem_indexes = model.mem_manager.alloc(batch_size).to(device) @@ -104,7 +97,7 @@ def _build_warmup_decode_model_input( ) b_num_accepted_tokens = None - if self.mtp_step > 0 and is_mtp_verify_decode: + if self.mtp_step > 0: assert batch_size % mtp_size == 0, "MTP decode CUDA graph batch size must be a multiple of mtp_step + 1" real_batch_size = batch_size // mtp_size b_mtp_index = torch.arange(mtp_size, dtype=torch.int32, device=device).repeat(real_batch_size) @@ -134,16 +127,6 @@ def _build_warmup_decode_model_input( **model._gen_special_model_input(batch_size), ) - def _iter_warmup_graph_layouts(self, model): - # Under MTP both the main verify forward and the (pure full-attention) draft forward run the - # (mtp_step+1)-grouped verify decode layout, so both warm up the verify graph key; only - # mtp_step == 0 models use the normal layout. (Matches upstream: the draft reuses the main - # model_input and keeps b_num_accepted_tokens, so its decode is a verify forward too.) - if self.mtp_step > 0: - yield True, self.cuda_graph_batch_sizes - else: - yield False, self.cuda_graph_batch_sizes - def _capture_decode(self, decode_func, infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids @@ -274,23 +257,18 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for is_mtp_verify_decode, batch_sizes in self._iter_warmup_graph_layouts(model): - for batch_size in batch_sizes[::-1]: - model_input = self._build_warmup_decode_model_input( - model, - batch_size, - is_mtp_verify_decode=is_mtp_verify_decode, - ) - model_output: ModelOutput = model.forward(model_input) - del model_output - - model.mem_manager.free_all() - model.req_manager.free_all() - # release local tensors - for var_name, var_value in list(locals().items()): - if isinstance(var_value, torch.Tensor): - del locals()[var_name] - torch.cuda.empty_cache() + for batch_size in self.cuda_graph_batch_sizes[::-1]: + model_input = self._build_warmup_decode_model_input(model, batch_size) + model_output: ModelOutput = model.forward(model_input) + del model_output + + model.mem_manager.free_all() + model.req_manager.free_all() + # release local tensors + for var_name, var_value in list(locals().items()): + if isinstance(var_value, torch.Tensor): + del locals()[var_name] + torch.cuda.empty_cache() logger.info( f"Capture cudagraph success, batch_size <={self.max_batch_size} " @@ -305,37 +283,32 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for is_mtp_verify_decode, batch_sizes in self._iter_warmup_graph_layouts(model): - for batch_size in batch_sizes[::-1]: - decode_batches = [] - for micro_batch_index in [0, 1]: - # dummy decoding, capture the cudagraph - micro_batch = self._build_warmup_decode_model_input( - model, - batch_size, - is_mtp_verify_decode=is_mtp_verify_decode, - ) - decode_batches.append(micro_batch) - del micro_batch + for batch_size in self.cuda_graph_batch_sizes[::-1]: + decode_batches = [] + for micro_batch_index in [0, 1]: + # dummy decoding, capture the cudagraph + micro_batch = self._build_warmup_decode_model_input(model, batch_size) + decode_batches.append(micro_batch) + del micro_batch - for var_name, var_value in list(locals().items()): - if isinstance(var_value, torch.Tensor): - del locals()[var_name] - torch.cuda.empty_cache() - - _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) - - model.mem_manager.free_all() - model.req_manager.free_all() - - del decode_batches - - # release local tensors for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): del locals()[var_name] torch.cuda.empty_cache() + _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) + + model.mem_manager.free_all() + model.req_manager.free_all() + + del decode_batches + + # release local tensors + for var_name, var_value in list(locals().items()): + if isinstance(var_value, torch.Tensor): + del locals()[var_name] + torch.cuda.empty_cache() + logger.info( f"Capture overlap cudagraph success, batch_size <={self.max_batch_size} " f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph." diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py index 3367645e83..b2700aa0bd 100644 --- a/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py @@ -52,7 +52,7 @@ def _init_qkv(self): self.qkv_proj = QKVROWNMMWeight( in_dim=in_dim, q_head_num=self.q_head_num_, - kv_head_num=self.kv_head_num_, + kv_head_num=self.k_head_num_, head_dim=self.head_dim, weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], data_type=self.data_type_, diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py index 31f6df3f8c..4c76b4bce0 100644 --- a/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py @@ -48,7 +48,7 @@ def _init_qkv(self): self.qkv_proj = QKVROWNMMWeight( in_dim=in_dim, q_head_num=self.q_head_num_, - kv_head_num=self.kv_head_num_, + kv_head_num=self.k_head_num_, head_dim=self.head_dim, weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], data_type=self.data_type_, diff --git a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py index 9d35cb53cc..7c3e751022 100644 --- a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py +++ b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py @@ -34,11 +34,12 @@ def alloc(self, size): assert model_input.total_token_num == 18 -def test_mtp_decode_cuda_graph_warmup_builds_normal_layout_when_not_verify(): +def test_mtp_decode_cuda_graph_warmup_builds_normal_layout_for_non_mtp(): from lightllm.common.basemodel.cuda_graph import CudaGraph + # A non-MTP model (mtp_step == 0) has a single, ungrouped decode layout. graph = CudaGraph.__new__(CudaGraph) - graph.mtp_step = 2 + graph.mtp_step = 0 graph.graph_max_len_in_batch = 128 class FakeMemManager: @@ -53,12 +54,7 @@ def alloc(self, size): _gen_special_model_input=lambda token_num: {"mtp_draft_input_hiddens": torch.full((token_num, 4), 3.0)}, ) - model_input = graph._build_warmup_decode_model_input( - model, - batch_size=5, - device="cpu", - is_mtp_verify_decode=False, - ) + model_input = graph._build_warmup_decode_model_input(model, batch_size=5, device="cpu") assert model_input.batch_size == 5 assert model_input.b_mtp_index.tolist() == [0, 0, 0, 0, 0] @@ -68,56 +64,23 @@ def alloc(self, size): assert model_input.mtp_draft_input_hiddens.shape == (5, 4) -def test_mtp_decode_cuda_graph_keys_distinguish_verify_and_normal(): +def test_mtp_decode_cuda_graph_key_is_batch_size(): from lightllm.common.basemodel.cuda_graph import CudaGraph + # Under MTP there is a single (mtp_step+1)-grouped decode layout, so the graph is keyed by + # batch size alone — no verify/normal distinction in the key. graph = CudaGraph.__new__(CudaGraph) graph.mtp_step = 2 graph.graph = {} graph.cuda_graph_batch_sizes = [3, 6, 9, 12] - verify_state = SimpleNamespace( - input_ids=torch.ones(6, dtype=torch.int64), - b_num_accepted_tokens=torch.ones(2, dtype=torch.int32), - ) - normal_state = SimpleNamespace( - input_ids=torch.ones(6, dtype=torch.int64), - b_num_accepted_tokens=None, - ) - - # Same batch size, but the verify and normal decodes get distinct graph keys. - assert graph._decode_graph_key(verify_state) == (6, True) - assert graph._decode_graph_key(normal_state) == (6, False) + state = SimpleNamespace(input_ids=torch.ones(6, dtype=torch.int64)) + assert graph._decode_graph_key(state) == 6 assert graph.find_closest_graph_batch_size(5) == 6 - # A captured verify graph does not satisfy a normal-graph capture need at the same batch size. - graph.graph[(6, True)] = "verify graph" - assert graph.need_capture(6, is_mtp_verify_decode=True) is False - assert graph.need_capture(6, is_mtp_verify_decode=False) is True - - -def test_mtp_decode_cuda_graph_warmup_layouts_use_verify_for_main_and_draft(): - from lightllm.common.basemodel.cuda_graph import CudaGraph - - class Qwen3_5MOETpPartModel: - pass - - class Qwen3_5MoeMTPModel: - is_mtp_draft_model = True - - graph = CudaGraph.__new__(CudaGraph) - graph.mtp_step = 2 - graph.cuda_graph_batch_sizes = [3, 6, 9] - - # Under MTP both the main verify forward and the pure-full-attention draft forward run the - # (mtp_step+1)-grouped verify decode layout (the draft reuses the main model_input and keeps - # b_num_accepted_tokens), so both warm up the verify graph key over the same batch-size set. - assert list(graph._iter_warmup_graph_layouts(Qwen3_5MOETpPartModel())) == [(True, [3, 6, 9])] - assert list(graph._iter_warmup_graph_layouts(Qwen3_5MoeMTPModel())) == [(True, [3, 6, 9])] - - # A non-MTP model (mtp_step == 0) warms up the normal layout instead. - graph.mtp_step = 0 - assert list(graph._iter_warmup_graph_layouts(Qwen3_5MOETpPartModel())) == [(False, [3, 6, 9])] + graph.graph[6] = "decode graph" + assert graph.need_capture(6) is False + assert graph.need_capture(3) is True def test_mtp_decode_warmup_layout_marks_qwen3next_verify(monkeypatch): @@ -209,47 +172,3 @@ def fake_base_init_cudagraph(self): assert called["disable_cudagraph"] is False assert model.disable_cudagraph is False assert model.graph == "captured" - - -def test_fa3_decode_uses_normal_layout_when_no_accept_tensor(monkeypatch): - import lightllm.common.basemodel.attention.fa3.fp as fa3_fp - from lightllm.common.basemodel.attention.fa3.fp import Fa3DecodeAttState - - monkeypatch.setattr(fa3_fp, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) - - copied = {} - - def fake_page_table_copy(page_table, req_to_token_indexs, b_req_idx): - copied["page_table_shape"] = tuple(page_table.shape) - copied["b_req_idx"] = b_req_idx.clone() - - monkeypatch.setattr(fa3_fp, "page_table_copy", fake_page_table_copy) - - model = SimpleNamespace( - graph_max_batch_size=16, - graph_max_len_in_batch=32, - req_manager=SimpleNamespace(req_to_token_indexs=torch.empty((8, 32), dtype=torch.int32)), - ) - backend = SimpleNamespace( - model=model, - get_page_table_buffer=lambda: [torch.empty(16 * 32, dtype=torch.int32)], - ) - infer_state = SimpleNamespace( - batch_size=2, - max_kv_seq_len=16, - input_ids=torch.ones(2, dtype=torch.int64), - b_seq_len=torch.tensor([5, 7], dtype=torch.int32), - b1_cu_q_seq_len=torch.tensor([0, 1, 2], dtype=torch.int32), - b1_cu_kv_seq_len=torch.tensor([0, 5, 12], dtype=torch.int32), - b_req_idx=torch.tensor([3, 4], dtype=torch.int32), - b_num_accepted_tokens=None, - microbatch_index=0, - ) - - state = Fa3DecodeAttState(backend=backend, infer_state=infer_state) - state.init_state() - - assert state.decode_max_q_seq_len == 1 - assert state.b_att_seq_len.tolist() == [5, 7] - assert copied["page_table_shape"] == (2, 16) - assert copied["b_req_idx"].tolist() == [3, 4] From 0d15236fac5d13273f2a99e2940befd80b7df85a Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 16 Jun 2026 14:02:06 +0800 Subject: [PATCH 13/19] refactor(mtp): GPU-resident req_to_accept_len + simplify verify-decode plumbing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - is_mtp_verify: drop the redundant `b_num_accepted_tokens is not None` clause (post grouped-revert it's implied by mtp_step>0 ∧ ¬prefill). - Replace the per-step host round-trip for b_num_accepted_tokens with a GPU-resident ReqManager.req_to_accept_len: a triton scatter_mtp_accept_len after verify + a GDN-only gather in init_mtp_verify_extra_state. Removes the gen_from_list H2D rebuild, the phase-2 req.mtp_accept_len writeback, and the host attr (linear-att offload + resets now read/write the buffer). - Drop the redundant `if mtp_step>0` guard inside decode_mtp/decode_overlap_mtp. - config_objs: inline the mtp draft-layer count, dropping the _mtp_added_layer_num helper (kept get_added_mtp_kv_layer_num inlined in envs_utils). - cpu_cache_meta: don't bump layer_num for linear-att models (the draft full-att slots are already in LinearAttCacheConfig.get_cpu_cache_big_page_bytes()). Static checks pass (ast, flake8). The req_to_accept_len refactor is not yet runtime-verified; pending a hybrid GSM8K + cudagraph-ON parity run. --- .gitignore | 1 + lightllm/common/basemodel/basemodel.py | 116 ++++-------------- lightllm/common/basemodel/cuda_graph.py | 3 +- .../basemodel/mtp_verify_extra_state.py | 27 +--- .../basemodel/triton_kernel/mtp_utils.py | 45 +++++++ .../linear_att_cache_manager/config_objs.py | 11 +- lightllm/common/req_manager.py | 18 ++- lightllm/models/qwen3_5/infer_struct.py | 2 +- lightllm/models/qwen3next/infer_struct.py | 2 +- .../server/router/model_infer/infer_batch.py | 16 +-- .../model_infer/mode_backend/base_backend.py | 6 +- .../mode_backend/chunked_prefill/impl.py | 30 ++--- .../mode_backend/dp_backend/impl.py | 93 ++++---------- lightllm/utils/envs_utils.py | 18 ++- lightllm/utils/kv_cache_utils.py | 6 +- 15 files changed, 152 insertions(+), 242 deletions(-) diff --git a/.gitignore b/.gitignore index 1156bab780..67a0db0b4c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ requirements-musa.txt logs/ /benchmark/ +artifacts/ diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 1e1090fca0..84303d4095 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -26,11 +26,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size -from lightllm.utils.envs_utils import ( - get_env_start_args, - get_llm_data_type, - get_added_mtp_kv_layer_num, -) +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.common.triton_utils.autotuner import AutotuneLevel @@ -381,105 +377,36 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s is_mtp_grouped_decode = (not model_input.is_prefill) and self.args.mtp_step > 0 if is_mtp_grouped_decode: mtp_size = self.args.mtp_step + 1 - assert model_input.batch_size % mtp_size == 0 - assert new_batch_size % mtp_size == 0 assert padded_batch_size % mtp_size == 0 padded_req_num = padded_batch_size // mtp_size - - pad_mtp_index = torch.arange( - mtp_size, - dtype=new_model_input.b_mtp_index.dtype, - device=new_model_input.b_mtp_index.device, - ).repeat(padded_req_num) - pad_seq_len = torch.arange( - 2, - mtp_size + 2, - dtype=new_model_input.b_seq_len.dtype, - device=new_model_input.b_seq_len.device, - ).repeat(padded_req_num) new_model_input.total_token_num += padded_req_num * (mtp_size * (mtp_size + 3) // 2) new_model_input.max_kv_seq_len = max(mtp_size + 1, model_input.max_kv_seq_len) - new_model_input.input_ids = torch.cat( - ( - new_model_input.input_ids, - torch.ones( - padded_batch_size, - dtype=new_model_input.input_ids.dtype, - device=new_model_input.input_ids.device, - ), - ), - dim=0, - ) - new_model_input.b_req_idx = torch.cat( - ( - new_model_input.b_req_idx, - torch.full( - (padded_batch_size,), - self.req_manager.HOLD_REQUEST_ID, - dtype=new_model_input.b_req_idx.dtype, - device=new_model_input.b_req_idx.device, - ), - ), - dim=0, - ) - new_model_input.b_mtp_index = torch.cat((new_model_input.b_mtp_index, pad_mtp_index), dim=0) + pad_seq_len = torch.arange( + 2, mtp_size + 2, dtype=new_model_input.b_seq_len.dtype, device=new_model_input.b_seq_len.device + ).repeat(padded_req_num) new_model_input.b_seq_len = torch.cat((new_model_input.b_seq_len, pad_seq_len), dim=0) - new_model_input.mem_indexes = torch.cat( - ( - new_model_input.mem_indexes, - torch.full( - (padded_batch_size,), - self.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=new_model_input.mem_indexes.dtype, - device=new_model_input.mem_indexes.device, - ), - ), - dim=0, - ) - new_model_input.b_num_accepted_tokens = torch.cat( - ( - new_model_input.b_num_accepted_tokens, - torch.ones( - padded_req_num, - dtype=new_model_input.b_num_accepted_tokens.dtype, - device=new_model_input.b_num_accepted_tokens.device, - ), - ), - dim=0, - ) + # b_num_accepted_tokens 不再随 model_input 流转/补齐:它在 GDN 的 init_mtp_verify_extra_state + # 里按 req_first 从 req_to_accept_len gather,padding 组 req_first=HOLD(槽恒为 1)自然得 1。 else: new_model_input.total_token_num += padded_batch_size * 2 new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) - new_model_input.input_ids = F.pad( - new_model_input.input_ids, - (0, padded_batch_size), - mode="constant", - value=1, - ) - new_model_input.b_req_idx = F.pad( - new_model_input.b_req_idx, - (0, padded_batch_size), - mode="constant", - value=self.req_manager.HOLD_REQUEST_ID, - ) - new_model_input.b_mtp_index = F.pad( - new_model_input.b_mtp_index, - (0, padded_batch_size), - mode="constant", - value=0, - ) new_model_input.b_seq_len = F.pad( - new_model_input.b_seq_len, - (0, padded_batch_size), - mode="constant", - value=2, - ) - new_model_input.mem_indexes = F.pad( - new_model_input.mem_indexes, - (0, padded_batch_size), - mode="constant", - value=self.mem_manager.HOLD_TOKEN_MEMINDEX, + new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2 ) + + new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) + new_model_input.b_req_idx = F.pad( + new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID + ) + new_model_input.b_mtp_index = F.pad( + new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 + ) + new_model_input.mem_indexes = F.pad( + new_model_input.mem_indexes, + (0, padded_batch_size), + mode="constant", + value=self.mem_manager.HOLD_TOKEN_MEMINDEX, + ) new_model_input.multimodal_params = new_model_input.multimodal_params + [ {"images": [], "audios": []} for _ in range(padded_batch_size) ] @@ -698,6 +625,7 @@ def _decode( @final def _context_forward(self, infer_state: InferStateInfo): + input_embs = self.pre_infer.context_forward(infer_state.input_ids, infer_state, self.pre_post_weight) if self.args.enable_dp_prefill_balance: assert not self.args.enable_prefill_cudagraph, "not support now" diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index e2ba362f45..001e6299e8 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -102,7 +102,8 @@ def _build_warmup_decode_model_input( real_batch_size = batch_size // mtp_size b_mtp_index = torch.arange(mtp_size, dtype=torch.int32, device=device).repeat(real_batch_size) b_seq_len = torch.arange(2, mtp_size + 2, dtype=torch.int32, device=device).repeat(real_batch_size) - b_num_accepted_tokens = torch.ones(real_batch_size, dtype=torch.int32, device=device) + # b_num_accepted_tokens 不再随 model_input 传入:GDN 的 init_mtp_verify_extra_state 会按 + # req_first(全 HOLD,槽恒为 1) gather,warmup/capture 自然得到全 1,等价旧的 torch.ones。 total_token_num = real_batch_size * (mtp_size * (mtp_size + 3) // 2) else: seq_len = 2 diff --git a/lightllm/common/basemodel/mtp_verify_extra_state.py b/lightllm/common/basemodel/mtp_verify_extra_state.py index e4bc5f4f7a..95bfce9388 100644 --- a/lightllm/common/basemodel/mtp_verify_extra_state.py +++ b/lightllm/common/basemodel/mtp_verify_extra_state.py @@ -3,29 +3,14 @@ from lightllm.utils.envs_utils import get_env_start_args -def init_mtp_verify_extra_state(self): - """Shared MTP-verify decode metadata, used by qwen3_5 and qwen3next infer-struct classes (#12). - Call AFTER super().init_some_extra_state(model). `self` is the InferStateInfo instance.""" +def init_mtp_verify_extra_state(self, model): self.b_att_seq_len = self.b_seq_len mtp_step = get_env_start_args().mtp_step self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index - # conv buffer is now ONE widened slot per request (indexed by req_idx), - # dropping the *(S+1) + mtp_index addressing used by the SSM block. self.b_conv_buffer_idx = self.b_req_idx - # MTP verify batch: decode-mode, S+1 expanded, and gated on the - # per-real-request accept tensor that decode_mtp threads in. Gating on - # b_num_accepted_tokens (vs only b_mtp_index, which is set for any decode) - # distinguishes the main-model verify forward from draft/plain decode. - self.is_mtp_verify = ( - (mtp_step > 0) - and (not self.is_prefill) - and (self.b_mtp_index is not None) - and (self.b_num_accepted_tokens is not None) - ) + self.is_mtp_verify = (mtp_step > 0) and (not self.is_prefill) and (self.b_mtp_index is not None) self.b_gdn_verify_cu_seqlens = None self.b_ssm_index_rows = None - # b_num_accepted_tokens is threaded onto the infer_state from ModelInput by - # _create_inferstate (mirrors b_mtp_index) BEFORE this runs; nothing to do here. if self.is_mtp_verify: step = mtp_step + 1 n_real = self.b_req_idx.shape[0] // step @@ -36,12 +21,6 @@ def init_mtp_verify_extra_state(self): base = (req_first * step).view(n_real, 1) self.b_ssm_index_rows = base + torch.arange(step, device=base.device, dtype=base.dtype).view(1, step) assert self.b_ssm_index_rows.shape == (n_real, step) - # The spec conv kernel is per-SEQUENCE (one program per real request), - # indexed by conv_state_indices[idx_seq] with idx_seq in [0, n_real), - # aligned 1:1 with b_gdn_verify_cu_seqlens / b_num_accepted_tokens. The - # default b_conv_buffer_idx = b_req_idx has the expanded length n_real*step, - # which launches n_real*step conv programs and reads num_accepted/ - # query_start_loc out of bounds for idx_seq >= n_real, corrupting the - # committed conv slot. Narrow it to one widened conv slot per request. self.b_conv_buffer_idx = req_first + self.b_num_accepted_tokens = model.req_manager.req_to_accept_len[req_first] return diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index a020605c26..bdd59c65e3 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -148,6 +148,51 @@ def mtp_scatter_next_token_ids( ) +@triton.jit +def _fwd_kernel_scatter_accept_len( + req_to_accept_len, + b_req_mtp_start_loc, + b_req_idx, + mtp_accept_len, +): + cur_index = tl.program_id(0) + req_start_loc = tl.load(b_req_mtp_start_loc + cur_index) + cur_req_idx = tl.load(b_req_idx + req_start_loc) + accept_len = tl.load(mtp_accept_len + cur_index) + tl.store(req_to_accept_len + cur_req_idx, accept_len) + return + + +def scatter_mtp_accept_len( + req_to_accept_len: torch.Tensor, + b_req_mtp_start_loc: torch.Tensor, + b_req_idx: torch.Tensor, + mtp_accept_len: torch.Tensor, +): + """ + 将本步每个真实请求(组首)的 accept 数量写入 GPU 常驻的 req_to_accept_len[req_idx]。 + 融合 `req_to_accept_len[b_req_idx[b_req_mtp_start_loc]] = mtp_accept_len` 的 gather+scatter + 为单次 launch、无中间张量。每个 program 处理一个真实请求。 + Args: + req_to_accept_len: (max_req_num + 1,) + b_req_mtp_start_loc: (num_reqs,) 每组首行在 batch 中的偏移 + b_req_idx: (batch_size,) grouped 布局的 req_idx(组首即该请求的 req_idx) + mtp_accept_len: (num_reqs,) + """ + num_reqs = mtp_accept_len.shape[0] + if num_reqs == 0: + return + grid = (num_reqs,) + _fwd_kernel_scatter_accept_len[grid]( + req_to_accept_len=req_to_accept_len, + b_req_mtp_start_loc=b_req_mtp_start_loc, + b_req_idx=b_req_idx, + mtp_accept_len=mtp_accept_len, + num_warps=1, + num_stages=1, + ) + + def test_mtp_verify(): req_to_next_token_ids = torch.tensor( [[1, 2, -2, -1, -1], [1, 2, 0, -1, -1], [1, 3, 4, 4, 5]], dtype=torch.int32, device="cuda" diff --git a/lightllm/common/linear_att_cache_manager/config_objs.py b/lightllm/common/linear_att_cache_manager/config_objs.py index f533c71dbc..f48b9865a3 100644 --- a/lightllm/common/linear_att_cache_manager/config_objs.py +++ b/lightllm/common/linear_att_cache_manager/config_objs.py @@ -1,7 +1,7 @@ import torch import dataclasses import triton -from lightllm.utils.envs_utils import get_env_start_args, _mtp_added_layer_num +from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.utils.torch_dtype_utils import get_torch_dtype @@ -9,8 +9,13 @@ def get_mtp_draft_full_att_layer_num(args) -> int: - # Delegates to the single source of truth in envs_utils (#9). - return _mtp_added_layer_num(getattr(args, "mtp_mode", None), getattr(args, "mtp_step", 0)) + # mtp_mode -> draft model 增加的 full-att KV 层数(与 envs_utils.get_added_mtp_kv_layer_num 同口径)。 + mtp_mode = getattr(args, "mtp_mode", None) + if mtp_mode == "eagle_with_att": + return 1 + if mtp_mode == "vanilla_with_att": + return getattr(args, "mtp_step", 0) + return 0 @dataclasses.dataclass diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index c01f2d7c0e..164d04fc3b 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -86,6 +86,15 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num + # MTP verify decode 的 per-req accept 数量:GPU 常驻、按 req_idx 索引(含 HOLD 槽)。 + # 取代旧的 req.mtp_accept_len host 属性 —— verify 后在 GPU 上 scatter,下一步在 GDN 的 + # init_mtp_verify_extra_state 里按 req_first gather 成 b_num_accepted_tokens,省掉每步的 + # host 回写 + H2D 重建。HOLD 槽恒为 1,使 padding 组 gather 到 1。仅 mtp_step>0 时分配。 + self.req_to_accept_len = ( + torch.ones((max_request_num + 1,), dtype=torch.int32, device="cuda") + if get_env_start_args().mtp_step > 0 + else None + ) def alloc(self): return self.req_list.alloc() @@ -274,7 +283,8 @@ def init_linear_att_state(self, req: "InferReq"): # #17: zero the FULL (mtp_step + 1)-row SSM block, not just canonical row +0, so a future # first-step verify reading offset>0 after fresh init never hits a never-written row (NaN). self.req_to_ssm_state.buffer[:, ssm_start : ssm_start + (self.mtp_step + 1), ...].fill_(0) - req.mtp_accept_len = 1 + if self.req_to_accept_len is not None: + self.req_to_accept_len[req.req_idx] = 1 return def get_mamba_cache(self, layer_idx_in_all: int): @@ -298,7 +308,8 @@ def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req narrow_w = conv_state.shape[-1] # persisted (narrow) width self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state - req.mtp_accept_len = 1 + if self.req_to_accept_len is not None: + self.req_to_accept_len[req.req_idx] = 1 return def copy_small_page_buffer_to_linear_att_state( @@ -314,5 +325,6 @@ def copy_small_page_buffer_to_linear_att_state( # 同时,非连续对象的拷贝,可能存在效率问题。 self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state - req.mtp_accept_len = 1 + if self.req_to_accept_len is not None: + self.req_to_accept_len[req.req_idx] = 1 return diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py index 2687a4aca7..35bd6f7925 100644 --- a/lightllm/models/qwen3_5/infer_struct.py +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -10,5 +10,5 @@ def init_some_extra_state(self, model): super().init_some_extra_state(model) from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state - init_mtp_verify_extra_state(self) + init_mtp_verify_extra_state(self, model) return diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py index b486bc6040..bb1516673a 100644 --- a/lightllm/models/qwen3next/infer_struct.py +++ b/lightllm/models/qwen3next/infer_struct.py @@ -10,5 +10,5 @@ def init_some_extra_state(self, model): super().init_some_extra_state(model) from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state - init_mtp_verify_extra_state(self) + init_mtp_verify_extra_state(self, model) return diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index fc02095f85..be8c022594 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -389,9 +389,11 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer - b_num_accepted_tokens = torch.tensor( - [req.mtp_accept_len for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu" + # accept 数量改由 GPU 常驻的 req_to_accept_len 按 req_idx gather(不再读 req.mtp_accept_len)。 + req_idxs = torch.tensor( + [req.req_idx for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu" ).cuda(non_blocking=True) + b_num_accepted_tokens = self.req_manager.req_to_accept_len[req_idxs] copy_linear_att_state_to_kv_buffer( b_req_idx=b_req_idx, @@ -417,11 +419,13 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache() ) if req.tail_linear_att_small_page_buffer_id is not None: - assert 1 <= req.mtp_accept_len <= self.args.mtp_step + 1, ( - f"mtp_accept_len={req.mtp_accept_len} out of range " + # 冷路径(prefill 跨小页边界):单标量从 GPU buffer 读回做 Python 切片下标。 + accept_len = int(self.req_manager.req_to_accept_len[req.req_idx].item()) + assert 1 <= accept_len <= self.args.mtp_step + 1, ( + f"mtp_accept_len={accept_len} out of range " f"[1, {self.args.mtp_step + 1}]; would slice past the widened conv slot" ) - canonical_off = req.mtp_accept_len - 1 + canonical_off = accept_len - 1 conv_src_idx = req.req_idx ssm_src_idx = req.req_idx * (self.args.mtp_step + 1) + canonical_off narrow_w = self.req_manager.linear_config.get_persisted_conv_state_shape()[-1] @@ -578,8 +582,6 @@ def __init__( else: self.decode_need_token_num = self._normal_decode_need_token_num - self.mtp_accept_len: int = 1 - if g_infer_context.is_linear_att_mixed_model: self.get_chuncked_input_token_len = self.get_chuncked_input_token_len_for_linear_att self.get_chuncked_input_token_ids = self.get_chuncked_input_token_ids_for_linear_att diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index e7fa58712a..e02ca368f1 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -569,6 +569,7 @@ def _get_classed_reqs( can_alloc_token_num = g_infer_context.get_can_alloc_token_num() for req_obj in ready_reqs: + if req_obj.filter_mark: finished_reqs.append(req_obj) continue @@ -750,8 +751,6 @@ def _update_mtp_accept_ratio( decode_reqs: List[InferReq], mtp_accept_len_cpu: torch.Tensor, ): - # Master-only accept-ratio statistics. Unlike the phase-2 mtp_accept_len commit - # (inlined in decode_mtp) this only feeds metrics, so it may stay in phase 3. if self.is_master_in_dp: for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): req.update_mtp_accepted_token_num(accept_token_num=accept_len - 1) @@ -759,8 +758,6 @@ def _update_mtp_accept_ratio( def _gen_argmax_token_ids(self, model_output: ModelOutput): logits = model_output.logits - # softmax is strictly monotonic, so argmax(softmax(logits)) == argmax(logits); - # skip the softmax to shorten the per-step MTP draft critical chain (need-to-fix #16). draft_next_token_ids_gpu = torch.argmax(logits, dim=-1) return draft_next_token_ids_gpu @@ -774,6 +771,7 @@ def _sample_and_scatter_token( b_prefill_has_output_cpu: torch.Tensor = None, mask_func: Optional[Callable] = None, ): + if mask_func is not None: assert len(run_reqs) == logits.shape[0] mask_func(run_reqs, logits) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index cd1e14be73..65bb96163e 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -20,6 +20,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.common.basemodel.triton_kernel.mtp_utils import ( mtp_scatter_next_token_ids, + scatter_mtp_accept_len, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id @@ -241,14 +242,6 @@ def decode_mtp( """ model_input, run_reqs = prepare_decode_inputs(decode_reqs) - if self.mtp_step > 0: - accept_lens = [req.mtp_accept_len for req in decode_reqs] - model_input.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( - key="b_num_accepted_tokens", - data=accept_lens, - dtype=torch.int32, - ) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) @@ -261,6 +254,9 @@ def decode_mtp( b_req_idx=model_input.b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, model_input.b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, @@ -296,8 +292,6 @@ def decode_mtp( # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() verify_event.synchronize() - for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): - req.mtp_accept_len = int(accept_len) verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1] update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) @@ -350,17 +344,15 @@ def _draft_decode_vanilla( mtp_accept_len: torch.Tensor, b_req_mtp_start_loc: torch.Tensor, ): - # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, - # 避免污染之后仍要用到的 main_model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, - # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 - # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 - draft_model_input = copy.copy(main_model_input) + # share some inference info with the main model + draft_model_input = main_model_input draft_model_output = main_model_output draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) # process the draft model output for draft_model_idx in range(self.mtp_step): + draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP @@ -394,17 +386,15 @@ def _draft_decode_eagle( eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(num_reqs * self.mtp_step) eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) - # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, - # 避免污染之后仍要用到的 main_model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, - # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 - # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 - draft_model_input = copy.copy(main_model_input) + # share some inference info with the main model + draft_model_input = main_model_input draft_model_output = main_model_output draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) # process the draft model output for _step in range(self.mtp_step): + draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 18f60b4934..9c83a5f352 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -21,7 +21,7 @@ from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager -from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids +from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids, scatter_mtp_accept_len from .control_state import DPControlState @@ -264,6 +264,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0) if (req_num0 + req_num1) > 0: + _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=logits, b_req_idx=b_req_idx, @@ -405,6 +406,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] sync_event.record() if req_num > 0: + # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill) @@ -431,22 +433,10 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] return def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): - model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(decode_reqs) + model_input, run_reqs, _ = padded_prepare_decode_inputs(decode_reqs) b_mtp_index_cpu = model_input.b_mtp_index req_num = len(run_reqs) - if self.mtp_step > 0: - # 标记 verify decode 布局:每个 req 一个 accept 数量(padding 出来的 fake req 记为 1)。 - # 不设置 b_num_accepted_tokens 会让主模型的 verify forward 走非 verify 的 GDN/FA3 布局, - # 并命中 hybrid 主模型从未捕获的 cudagraph key (bs, False) -> KeyError。 - # 与 chunked_prefill/impl.py 的 decode_mtp 保持一致。 - accept_lens = [req.mtp_accept_len for req in decode_reqs] + [1] * padded_req_num - model_input.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( - key="b_num_accepted_tokens", - data=accept_lens, - dtype=torch.int32, - ) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None @@ -473,6 +463,9 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): b_req_idx=b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, @@ -507,11 +500,6 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() verify_event.synchronize() - # 写回每个 req 的本步 accept 数量,供下一步 verify 经 b_num_accepted_tokens 传入 - # GDN/linear-att verify kernel(据此提交 conv/ssm 递归状态的正确偏移)。chunked 路径 - # 在 chunked_prefill/impl.py 同样写回;dp 缺失会让状态停留在 accept=1 -> 状态错乱、精度崩塌。 - for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): - req.mtp_accept_len = int(accept_len) verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1] update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) @@ -552,11 +540,8 @@ def _draft_decode_vanilla( req_num: int, ): all_next_token_ids = [] - # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, - # 避免污染之后仍要用到的 model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, - # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 - # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 - draft_model_input = copy.copy(model_input) + # share some inference info with the main model + draft_model_input = model_input draft_model_output = model_output draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") if req_num > 0: @@ -566,6 +551,7 @@ def _draft_decode_vanilla( # process the draft model output for draft_model_idx in range(self.mtp_step): + draft_model_input.input_ids = draft_next_token_ids_gpu draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP @@ -595,11 +581,8 @@ def _draft_decode_eagle( req_num: int, ): all_next_token_ids = [] - # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, - # 避免污染之后仍要用到的 model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, - # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 - # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 - draft_model_input = copy.copy(model_input) + # share some inference info with the main model + draft_model_input = model_input draft_model_output = model_output all_next_token_ids.append(next_token_ids) draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") @@ -615,6 +598,7 @@ def _draft_decode_eagle( # process the draft model output for _step in range(self.mtp_step): + draft_model_input.input_ids = draft_next_token_ids_gpu draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP @@ -699,6 +683,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I draft_model_output0, draft_model_output1 = model_output0, model_output1 for draft_model_idx in range(self.num_mtp_models): + draft_model_input0 = prepare_mtp_prefill_inputs( model_input=draft_model_input0, b_next_token_ids=draft_next_token_ids_gpu0, @@ -750,35 +735,15 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf ( model_input0, run_reqs0, - padded_req_num0, + _, model_input1, run_reqs1, - padded_req_num1, + _, ) = padded_overlap_prepare_decode_inputs(decode_reqs) req_num0, req_num1 = len(run_reqs0), len(run_reqs1) all_next_token_ids = [] b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = model_input1.b_mtp_index - - if self.mtp_step > 0: - # 标记两个 micro-batch 的 verify decode 布局,每个 req 一个 accept 数量 - # (padding 出来的 fake req 记为 1)。run_reqs* 内每个真实 req 占 mtp_step+1 行, - # 取每组首行即可得到逐 req 的列表。不设置会让主模型 verify forward 走非 verify 布局, - # 命中 hybrid 主模型从未捕获的 cudagraph key (bs, False) -> KeyError。 - mtp_size = self.mtp_step + 1 - accept_lens0 = [r.mtp_accept_len for r in run_reqs0[::mtp_size]] + [1] * padded_req_num0 - accept_lens1 = [r.mtp_accept_len for r in run_reqs1[::mtp_size]] + [1] * padded_req_num1 - model_input0.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( - key="b_num_accepted_tokens_0", - data=accept_lens0, - dtype=torch.int32, - ) - model_input1.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list( - key="b_num_accepted_tokens_1", - data=accept_lens1, - dtype=torch.int32, - ) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) logits0 = model_output0.logits @@ -810,6 +775,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_req_idx=b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, @@ -848,11 +816,6 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf if req_num0 + req_num1 > 0: event_pack.notify_post_handle_and_wait_pre_post_handle() verify_event.synchronize() - # 写回每个 req 的本步 accept 数量,供下一步 verify 经 b_num_accepted_tokens 传入 - # GDN/linear-att verify kernel(据此提交 conv/ssm 递归状态的正确偏移)。chunked 路径 - # 在 chunked_prefill/impl.py 同样写回;dp 缺失会让状态停留在 accept=1 -> 状态错乱、精度崩塌。 - for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu): - req.mtp_accept_len = int(accept_len) verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1] update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) @@ -913,12 +876,8 @@ def _draft_decode_vanilla_overlap( ): all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, - # 避免污染之后仍要用到的 model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, - # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 - # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 - draft_model_input0 = copy.copy(model_input0) - draft_model_input1 = copy.copy(model_input1) + # share some inference info with the main model + draft_model_input0, draft_model_input1 = model_input0, model_input1 draft_model_output0, draft_model_output1 = model_output0, model_output1 draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda") @@ -932,6 +891,7 @@ def _draft_decode_vanilla_overlap( # process the draft model output for draft_model_idx in range(self.mtp_step): + draft_model_input0.input_ids = draft_next_token_ids_gpu0 draft_model_input0.mtp_draft_input_hiddens = draft_model_output0.mtp_main_output_hiddens draft_model_input1.input_ids = draft_next_token_ids_gpu1 @@ -974,12 +934,8 @@ def _draft_decode_eagle_overlap( ): all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - # 复用主模型的推理信息。copy.copy 隔离 draft 每步对 input_ids / b_seq_len / mem_indexes 的修改, - # 避免污染之后仍要用到的 model_input(need_free_mem_indexes)。保留 b_num_accepted_tokens, - # 使 draft 与主模型一样走 (mtp_step+1) 分组的 verify decode 布局(与 upstream 一致;纯全注意力 - # draft 在分组布局下与展开成扁平 batch 的逐位置 attention 数值等价)。 - draft_model_input0 = copy.copy(model_input0) - draft_model_input1 = copy.copy(model_input1) + # share some inference info with the main model + draft_model_input0, draft_model_input1 = model_input0, model_input1 draft_model_output0, draft_model_output1 = model_output0, model_output1 draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda") @@ -1004,6 +960,7 @@ def _draft_decode_eagle_overlap( # process the draft model output for _step in range(self.mtp_step): + draft_model_input0.input_ids = draft_next_token_ids_gpu0 draft_model_input0.mtp_draft_input_hiddens = draft_model_output0.mtp_main_output_hiddens draft_model_input1.input_ids = draft_next_token_ids_gpu1 diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index f6508994b4..773320273c 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -226,20 +226,16 @@ def enable_huge_page(): return enable_env_vars("LIGHTLLM_HUGE_PAGE_ENABLE") -def _mtp_added_layer_num(mtp_mode, mtp_step: int) -> int: - # Single source of truth for the mtp_mode -> added KV/full-att layer count (#9). - if mtp_mode == "eagle_with_att": - return 1 - if mtp_mode == "vanilla_with_att": - return mtp_step - return 0 - - @lru_cache(maxsize=None) def get_added_mtp_kv_layer_num() -> int: # mtp 模式下需要在mem manger上扩展draft model使用的layer - args = get_env_start_args() - return _mtp_added_layer_num(args.mtp_mode, args.mtp_step) + added_mtp_layer_num = 0 + if get_env_start_args().mtp_mode == "eagle_with_att": + added_mtp_layer_num += 1 + elif get_env_start_args().mtp_mode == "vanilla_with_att": + added_mtp_layer_num += get_env_start_args().mtp_step + + return added_mtp_layer_num @lru_cache(maxsize=None) diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 2a089a9bf2..ff5ad0127b 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -120,11 +120,7 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 - if is_linear_att_mixed_model(args.model_dir): - # Linear mixed models use one packed byte page; MTP draft full-attn - # slots are accounted in LinearAttCacheConfig.get_cpu_cache_big_page_bytes(). - pass - else: + if not is_linear_att_mixed_model(args.model_dir): cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() cpu_cache_page_num = int( From 7d6fc71c61179ae41e57a5ff2cdd741645ef5aa4 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 16 Jun 2026 15:32:38 +0800 Subject: [PATCH 14/19] revert: drop all test/ and unit_tests/ changes from this branch --- .../benchmark/static_inference/model_infer.py | 2 +- .../static_inference/model_infer_mtp.py | 272 +++++------------- test/benchmark/static_inference/test_model.py | 21 +- test/cpu_cache_kernel/test_speed.py | 2 +- .../test_fp8_decode_verify_narrowed.py | 59 ---- .../basemodel/test_mtp_decode_cuda_graph.py | 174 ----------- .../test_init_linear_att_state_zeros_block.py | 41 --- .../common/test_linear_att_copy_guards.py | 39 --- ...st_linear_att_mtp_cpu_cache_persistence.py | 219 -------------- .../common/test_linear_att_snapshot_split.py | 41 --- .../common/test_mtp_verify_extra_state.py | 36 --- .../test_qwen3next_linear_att_page_helper.py | 112 -------- .../qwen3next/test_causal_conv1d_spec.py | 147 ---------- .../test_conv_prefill_decode_roundtrip.py | 74 ----- .../qwen3next/test_gdn_verify_equivalence.py | 194 ------------- 15 files changed, 71 insertions(+), 1362 deletions(-) delete mode 100644 unit_tests/common/basemodel/test_fp8_decode_verify_narrowed.py delete mode 100644 unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py delete mode 100644 unit_tests/common/test_init_linear_att_state_zeros_block.py delete mode 100644 unit_tests/common/test_linear_att_copy_guards.py delete mode 100644 unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py delete mode 100644 unit_tests/common/test_linear_att_snapshot_split.py delete mode 100644 unit_tests/common/test_mtp_verify_extra_state.py delete mode 100644 unit_tests/common/test_qwen3next_linear_att_page_helper.py delete mode 100644 unit_tests/models/qwen3next/test_causal_conv1d_spec.py delete mode 100644 unit_tests/models/qwen3next/test_conv_prefill_decode_roundtrip.py delete mode 100644 unit_tests/models/qwen3next/test_gdn_verify_equivalence.py diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index b93c5fee55..f2c900af09 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -36,7 +36,7 @@ def test_model_inference(args): "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, "mem_fraction": args.mem_fraction, - "max_req_num": 512, + "max_req_num": 2048, "batch_max_tokens": 1024, "run_mode": "normal", "max_seq_length": args.max_req_total_len, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index f2c21ea261..72f06a919c 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -1,5 +1,4 @@ import os -import copy import torch import numpy as np from multiprocessing import Queue @@ -10,60 +9,42 @@ from lightllm.models import get_model from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from lightllm.server.core.objs.start_args_type import StartArgs -from torch.profiler import profile, ProfilerActivity +from torch.profiler import profile, record_function, ProfilerActivity from lightllm.utils.log_utils import init_logger +from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +import torch.cuda as cuda logger = init_logger(__name__) def init_mtp_model(args: StartArgs, kvargs, main_model): + mtp_step = args.mtp_step draft_models = [] os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - - if args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: - num_mtp_modules = args.mtp_step - elif args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: - num_mtp_modules = 1 - else: - assert False, f"error mtp mode {args.mtp_mode}" - - for i in range(num_mtp_modules): - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir[i]) - model_type = mtp_model_cfg.get("model_type", "") - mtp_model_kvargs = { - "weight_dir": args.mtp_draft_model_dir[i], + mtp_model_kvargs = kvargs + mtp_model_kvargs.update( + { + "weight_dir": args.mtp_draft_model_dir, "max_total_token_num": main_model.mem_manager.size, - "load_way": kvargs["load_way"], - "max_req_num": kvargs.get("max_req_num", 1000), - "max_seq_length": kvargs.get("max_seq_length", 1024 * 5), - "is_token_healing": False, - "return_all_prompt_logics": False, - "disable_chunked_prefill": args.disable_chunked_prefill, - "data_type": kvargs.get("data_type", "float16"), - "graph_max_batch_size": kvargs.get("graph_max_batch_size", 16), - "graph_max_len_in_batch": kvargs.get("graph_max_len_in_batch", 8196), - "disable_cudagraph": kvargs.get("disable_cudagraph", False), - "mem_fraction": kvargs["mem_fraction"], - "batch_max_tokens": kvargs.get("batch_max_tokens", None), - "quant_type": kvargs.get("quant_type", None), - "quant_cfg": kvargs.get("quant_cfg", None), - "run_mode": "normal", - "llm_prefill_att_backend": kvargs.get("llm_prefill_att_backend", args.llm_prefill_att_backend), - "llm_decode_att_backend": kvargs.get("llm_decode_att_backend", args.llm_decode_att_backend), - "vit_att_backend": kvargs.get("vit_att_backend", args.vit_att_backend), - "llm_kv_type": kvargs.get("llm_kv_type", args.llm_kv_type), - "llm_kv_quant_group_size": kvargs.get("llm_kv_quant_group_size", args.llm_kv_quant_group_size), - "main_model": main_model, - "mtp_previous_draft_models": draft_models.copy(), + "disable_chunked_prefill": True, "mtp_mode": args.mtp_mode, + "main_model": main_model, } - - from lightllm.server.router.model_infer.mode_backend.mtp_model_factory import create_mtp_draft_model - - draft_models.append(create_mtp_draft_model(model_type, args.mtp_mode, mtp_model_kvargs)) - - logger.info(f"loaded mtp model class {draft_models[i].__class__}") + ) + for i in range(mtp_step): + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir) + mtp_model_kvargs.update( + { + "weight_dir": args.spec_model_dir, + "max_total_token_num": main_model.mem_manager.size, + "disable_chunked_prefill": True, + "mtp_mode": args.mtp_mode, + "main_model": main_model, + "mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], + } + ) + draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) return draft_models @@ -87,22 +68,13 @@ def test_model_inference_mtp(args): "max_total_token_num": args.max_total_token_num, "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, - "mem_fraction": args.mem_fraction, - # Static bench runs explicit batch sizes (<= a few hundred). The hybrid Qwen3.5 - # GDN req-state cache is sized max_req_num * (mtp_step + 1) at ~34 MB/slot, so the - # old default of 2000 alloc'd ~140 GB and OOM'd under MTP. 512 covers any realistic - # static batch sweep while keeping the GDN cache small. - "max_req_num": 512, + "mem_faction": args.mem_fraction, + "max_req_num": 2000, "batch_max_tokens": 2048, "run_mode": "normal", "max_seq_length": args.max_req_total_len, + "spec_algo": args.spec_algo, "disable_cudagraph": args.disable_cudagraph, - "quant_cfg": args.quant_cfg, - "llm_prefill_att_backend": args.llm_prefill_att_backend, - "llm_decode_att_backend": args.llm_decode_att_backend, - "vit_att_backend": args.vit_att_backend, - "llm_kv_type": args.llm_kv_type, - "llm_kv_quant_group_size": args.llm_kv_quant_group_size, } proc = multiprocessing.Process( target=tppart_model_infer, @@ -141,36 +113,28 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) test_data = test_data.reshape(-1) - test_data = torch.from_numpy(test_data) + test_data = torch.from_numpy(test_data).cuda() b_req_idx = torch.tensor( - [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cpu" + [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) - b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu") + b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") for i in range(batch_size): b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32) - b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len + mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() # Main model Prefill model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_q_seq_len=input_len, - max_kv_seq_len=input_len, - max_cache_len=0, input_ids=test_data, - mem_indexes_cpu=mem_indexes, + mem_indexes=mem_indexes, b_req_idx=b_req_idx, - b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, is_prefill=True, b_ready_cache_len=b_ready_cache_len, - b_prefill_start_loc=b_prefill_start_loc, - prefix_total_token_num=0, multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], ) @@ -203,22 +167,8 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ torch.cuda.synchronize() - # Speculative width = args.mtp_step in BOTH modes (mirrors base_backend: self.mtp_step = - # args.mtp_step). The number of draft MODEL INSTANCES differs: vanilla loads mtp_step - # instances (each forwarded once), eagle loads ONE instance forwarded mtp_step times - # (chunked_prefill/impl.py: draft_models[_step % num_instances]). The verify batch always - # expands to (mtp_step + 1) rows per request. - spec_width = args.mtp_step - num_instances = len(draft_models) - # The draft prefill above produced (1 + num_instances) columns; pad/truncate to - # (spec_width + 1) so the decode verify batch matches the server's expand width. Only the - # SHAPE matters for throughput here (argmax over random inputs); token values do not. - while len(draft_ids) < spec_width + 1: - draft_ids.append(draft_ids[-1]) - draft_ids = draft_ids[: spec_width + 1] decode_input_ids = np.stack(draft_ids, axis=-1).reshape(-1) - decode_input_ids = torch.from_numpy(decode_input_ids) - mtp_step = spec_width + decode_input_ids = torch.from_numpy(decode_input_ids).cuda() # build main decode input: nopad_b_seq_idx = [] @@ -227,155 +177,67 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_max_len_in_batch = 0 for i in range(batch_size): - nopad_b_seq_idx.append(b_req_idx[i].item()) + nopad_b_seq_idx.append(b_req_idx[i]) seq_len = b_seq_len[i].item() nopad_b_seq_len.append(seq_len + 1) nopad_total_token_num += seq_len + 1 - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + 1) + nopad_max_len_in_batch = max(nopad_max_len_in_batch, b_seq_len[i] + 1) - for step in range(mtp_step): - nopad_b_seq_idx.append(b_req_idx[i].item()) + for step in range(len(draft_models)): + nopad_b_seq_idx.append(b_req_idx[i]) nopad_b_seq_len.append(seq_len + step + 2) nopad_total_token_num += seq_len + step + 2 nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + step + 2) - nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cpu") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cpu") - b_mtp_index = torch.arange(mtp_step + 1, dtype=torch.int32).repeat(batch_size) - mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (mtp_step + 1)) + nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() model_input = ModelInput( - batch_size=batch_size * (mtp_step + 1), + batch_size=batch_size * (len(draft_models) + 1), total_token_num=nopad_total_token_num, - max_q_seq_len=1, - max_kv_seq_len=nopad_max_len_in_batch, input_ids=decode_input_ids, - mem_indexes_cpu=mem_indexes, + mem_indexes=mem_indexes, b_req_idx=nopad_b_seq_idx, - b_mtp_index=b_mtp_index, b_seq_len=nopad_b_seq_len, is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (mtp_step + 1))], + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (len(draft_models) + 1))], ) - # MTP verify layout. The main decode is a VERIFY forward over the (mtp_step+1)-expanded - # batch. Setting b_num_accepted_tokens (one entry per real request) flips is_mtp_verify=True - # so the hybrid GDN main model runs the fused spec-decode verify kernel — the production path. - # Without it the main decode silently takes the plain _gdn_decode_kernel on the S+1-expanded - # batch (whose rows share req_idx), colliding on the single widened conv slot and mismeasuring - # cost. accept_len is fixed at 1 (steady-state low-acceptance); the verify-forward COST is - # ~constant in accept_len (it always processes mtp_step+1 rows), so this faithfully measures - # per-step decode cost. Vary accept_len in [1, mtp_step+1] to sweep the acceptance regime. - accept_len = 1 - is_eagle = args.mtp_mode.startswith("eagle") - model_input.b_num_accepted_tokens = torch.full((batch_size,), accept_len, dtype=torch.int32, device="cuda") - if is_eagle: - # EAGLE draft scratch slots (n_real * mtp_step), mirroring _draft_decode_eagle. Allocated - # once and reused across steps (throughput bench overwrites draft KV; no correctness check). - eagle_mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * mtp_step).cuda() - - # Prize-sizing profiler (need-to-fix #22): env-gated, eagle-only, additive. Times the verify - # forward vs the S-step draft chain to decide whether collapsing the chain into a CUDA graph is - # worth it. host_bound_ratio ~1 (or per_step flat across bs) => host/launch-bound => graph wins. - _mtp_profile = os.environ.get("MTP_PROFILE", "0") == "1" - _prof = {"verify_ms": 0.0, "draft_ms": 0.0, "draft_host_ms": 0.0, "n": 0, "per_step_ms": [0.0] * mtp_step} - # Main decode - for i in range(0, output_len, mtp_step + 1): + for i in range(0, output_len, len(draft_models) + 1): torch.cuda.synchronize() step_start_time = time.time() + model_output = main_model.forward( + model_input, + ) + prob_out = torch.softmax(model_output.logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - # --- main VERIFY forward: mtp_step+1 rows/req through the fused GDN verify kernel --- - if _mtp_profile and not warmup: - _ev_v0 = torch.cuda.Event(enable_timing=True) - _ev_v1 = torch.cuda.Event(enable_timing=True) - _ev_v0.record() - model_output = main_model.forward(model_input) - if _mtp_profile and not warmup: - _ev_v1.record() - predict_ids = torch.argmax(model_output.logits, dim=1, keepdim=True) - - if is_eagle: - # EAGLE draft: full (mtp_step+1)-expanded batch, plain decode layout (the Qwen3.5 MTP - # draft is full-attention and takes b_num_accepted_tokens=None). Mirrors chunked_prefill - # _draft_decode_eagle: run the draft model mtp_step times, allocating fresh KV slots and - # shifting mem_indexes one column per step. - draft_model_input = copy.copy(model_input) - draft_model_input.b_num_accepted_tokens = None - draft_model_input.input_ids = predict_ids.reshape(-1) - draft_model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens - - if _mtp_profile and not warmup: - _step_evs = [] - _ev_d0 = torch.cuda.Event(enable_timing=True) - _ev_d1 = torch.cuda.Event(enable_timing=True) - _host_t0 = time.time() - _ev_d0.record() - for _step in range(mtp_step): - draft_model = draft_models[_step % num_instances] - if _mtp_profile and not warmup: - _es = torch.cuda.Event(enable_timing=True) - _ee = torch.cuda.Event(enable_timing=True) - _es.record() - draft_output = draft_model.forward(draft_model_input) - if _mtp_profile and not warmup: - _ee.record() - _step_evs.append((_es, _ee)) - draft_model_input.input_ids = torch.argmax(draft_output.logits, dim=1, keepdim=True).reshape(-1) - draft_model_input.mtp_draft_input_hiddens = draft_output.mtp_main_output_hiddens - draft_model_input.b_seq_len = draft_model_input.b_seq_len + 1 - draft_model_input.max_kv_seq_len += 1 - eagle_mem_indexes_i = eagle_mem_indexes[_step * batch_size : (_step + 1) * batch_size] - draft_model_input.mem_indexes = torch.cat( - [draft_model_input.mem_indexes.view(-1, mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], - dim=1, - ).view(-1) - if _mtp_profile and not warmup: - _ev_d1.record() - _host_t1 = time.time() - else: - # VANILLA draft: full (mtp_step+1)-expanded batch, plain decode layout. Mirrors - # chunked_prefill _draft_decode_vanilla (b_num_accepted_tokens cleared on a copy so the - # MTP draft model does not inherit the main model's verify layout / cudagraph key). - draft_model_input = copy.copy(model_input) - draft_model_input.b_num_accepted_tokens = None - draft_model_input.input_ids = predict_ids.reshape(-1) - draft_model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens - for _step in range(mtp_step): - draft_model = draft_models[_step % num_instances] - draft_output = draft_model.forward(draft_model_input) - draft_model_input.input_ids = torch.argmax(draft_output.logits, dim=1, keepdim=True).reshape(-1) - draft_model_input.mtp_draft_input_hiddens = draft_output.mtp_main_output_hiddens + # draft decode + model_input.input_ids = predict_ids.reshape(-1) + model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens + for draft_model_id in range(len(draft_models)): + draft_model = draft_models[draft_model_id] + model_output = draft_model.forward( + model_input, + ) + prob_out = torch.softmax(model_output.logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + model_input.input_ids = predict_ids.reshape(-1) + model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens + + # accept all draft ids by default. + model_input.input_ids = predict_ids.reshape(-1) + model_input.mtp_draft_input_hiddens = model_output.mtp_main_output_hiddens torch.cuda.synchronize() - if _mtp_profile and not warmup and is_eagle and i >= 3 * (mtp_step + 1): - # skip first 3 macro-steps (lazy cudagraph capture / cache warmup) - _prof["verify_ms"] += _ev_v0.elapsed_time(_ev_v1) - _prof["draft_ms"] += _ev_d0.elapsed_time(_ev_d1) - _prof["draft_host_ms"] += (_host_t1 - _host_t0) * 1000.0 - for _k, (_es, _ee) in enumerate(_step_evs): - _prof["per_step_ms"][_k] += _es.elapsed_time(_ee) - _prof["n"] += 1 if i % 100 == 0 or i == output_len - 1: step_end_time = time.time() if get_current_rank_in_dp() == 0 and not warmup: step_time = step_end_time - step_start_time print(i, " step cost time:", step_time * 1000) - # Peak (all-accepted) throughput: mtp_step+1 candidate tokens per req per step. - print(f"Decode throughput: {batch_size * (mtp_step + 1) * args.dp / step_time} tokens/s") - - if _mtp_profile and is_eagle and _prof["n"] > 0 and get_current_rank_in_dp() == 0 and not warmup: - n = _prof["n"] - ps = ", ".join(f"{v / n:.3f}" for v in _prof["per_step_ms"]) - print(f"[MTP_PROFILE] bs={batch_size} S={mtp_step} steps={n}") - print(f"[MTP_PROFILE] verify_gpu_ms = {_prof['verify_ms'] / n:.3f}") - print(f"[MTP_PROFILE] draft_chain_gpu_ms = {_prof['draft_ms'] / n:.3f}") - print(f"[MTP_PROFILE] draft_chain_host_ms = {_prof['draft_host_ms'] / n:.3f} (host-enqueue, no sync)") - print(f"[MTP_PROFILE] per_draft_step_gpu_ms = [{ps}]") - print( - f"[MTP_PROFILE] host_bound_ratio = " - f"{_prof['draft_host_ms'] / max(_prof['draft_ms'], 1e-9):.3f} (~1 => host-bound => graph wins)" - ) + print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s") main_model.mem_manager.free_all() main_model.req_manager.free_all() diff --git a/test/benchmark/static_inference/test_model.py b/test/benchmark/static_inference/test_model.py index 7992c03743..5b3751bcc3 100644 --- a/test/benchmark/static_inference/test_model.py +++ b/test/benchmark/static_inference/test_model.py @@ -11,29 +11,12 @@ from lightllm.utils.config_utils import get_config_json, get_dtype -def parse_batch_size(value): - parts = [part.strip() for part in value.split(",") if part.strip()] - if not parts: - raise ValueError("batch_size must contain at least one integer") - - batch_sizes = [] - for part in parts: - size = int(part) - if size <= 0: - raise ValueError("batch_size values must be positive integers") - batch_sizes.append(size) - - if len(batch_sizes) == 1: - return batch_sizes[0] - return batch_sizes - - class TestModelInfer(unittest.TestCase): def test_model_infer(self): args = get_env_start_args() if args.data_type is None: args.data_type = get_dtype(args.model_dir) - if args.mtp_mode is not None: + if args.mtp_mode == "deepseekv3": test_model_inference_mtp(args) else: test_model_inference(args) @@ -44,7 +27,7 @@ def test_model_infer(self): import torch parser = make_argument_parser() - parser.add_argument("--batch_size", type=parse_batch_size, default=None, help="batch size, e.g. 8 or 1,2,4,8") + parser.add_argument("--batch_size", type=int, default=None, help="batch size") parser.add_argument("--input_len", type=int, default=64, help="input sequence length") parser.add_argument("--output_len", type=int, default=128, help="output sequence length") parser.add_argument( diff --git a/test/cpu_cache_kernel/test_speed.py b/test/cpu_cache_kernel/test_speed.py index b1c761f4c2..254142050c 100644 --- a/test/cpu_cache_kernel/test_speed.py +++ b/test/cpu_cache_kernel/test_speed.py @@ -104,7 +104,7 @@ buffer_count = triton.cdiv(SEQ_LEN, big_page_token_num) + 2 # matches Qwen3NextMemManager -conv_shape = linear_config.get_persisted_conv_state_shape() +conv_shape = linear_config.get_conv_state_shape() cpu_kv_conv_state = torch.empty( (buffer_count, linear_config.linear_layer_num, *conv_shape), dtype=linear_config.conv_state_dtype, diff --git a/unit_tests/common/basemodel/test_fp8_decode_verify_narrowed.py b/unit_tests/common/basemodel/test_fp8_decode_verify_narrowed.py deleted file mode 100644 index a671550d31..0000000000 --- a/unit_tests/common/basemodel/test_fp8_decode_verify_narrowed.py +++ /dev/null @@ -1,59 +0,0 @@ -import types -import torch -import pytest - -import lightllm.common.basemodel.attention.fa3.fp8 as fp8_mod -from lightllm.common.basemodel.attention.fa3.fp8 import Fp8Fa3DecodeAttState - - -def _make_verify_state(n_real, mtp_size, head_num=2, head_dim=8): - """Build an Fp8Fa3DecodeAttState as init_state would leave it in MTP-verify mode, - bypassing init_state. b_att_seq_len/page_table are NARROW (n_real); infer_state.b_seq_len - is the FULL expanded tensor (n_real*mtp_size) that must NOT be used as cache_seqlens.""" - state = object.__new__(Fp8Fa3DecodeAttState) - batch = n_real * mtp_size - state.b_att_seq_len = torch.full((n_real,), 16, dtype=torch.int32) - state.page_table = torch.zeros((n_real, 16), dtype=torch.int32) - state.cu_seqlens_q = torch.arange(0, (n_real + 1) * mtp_size, mtp_size, dtype=torch.int32) - state.cu_seqlens_k = torch.zeros((n_real + 1,), dtype=torch.int32) - state.decode_max_q_seq_len = mtp_size - state.infer_state = types.SimpleNamespace( - b_seq_len=torch.full((batch,), 16, dtype=torch.int32), - batch_size=batch, - ) - # k/v descale sized per real request (att_batch_size), indexed by layer - state.k_descale = torch.ones((1, n_real, head_num)) - state.v_descale = torch.ones((1, n_real, head_num)) - state.backend = types.SimpleNamespace(_find_layer_index=lambda k, v, att_state: 0) - return state, batch - - -def test_fp8_decode_uses_narrowed_cache_seqlens_and_causal(monkeypatch): - n_real, mtp_size, head_num, head_dim = 3, 4, 2, 8 - state, batch = _make_verify_state(n_real, mtp_size, head_num, head_dim) - - captured = {} - - def fake_flash(**kwargs): - captured.update(kwargs) - q = kwargs["q"] - return torch.zeros((q.shape[0], q.shape[1], q.shape[2])) - - def fake_quant(x, use_per_token_if_dynamic=True): - return x, torch.ones((x.shape[0], 1)) - - monkeypatch.setattr(fp8_mod, "flash_attn_with_kvcache", fake_flash) - monkeypatch.setattr(fp8_mod, "scaled_fp8_quant", fake_quant) - - q = torch.randn((batch, head_num, head_dim)) - k = torch.randn((batch, head_num, head_dim)) - v = torch.randn((batch, head_num, head_dim)) - - state._fp8_decode_att(q=q, k=k, v=v) - - # The KV-side seqlens must be the NARROW per-real-request tensor, matching page_table rows. - assert captured["cache_seqlens"] is state.b_att_seq_len - assert captured["cache_seqlens"].shape[0] == n_real - assert captured["cache_seqlens"].shape[0] == captured["page_table"].shape[0] - # Verify decode must be causal, like the non-fp8 sibling. - assert captured["causal"] is True diff --git a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py b/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py deleted file mode 100644 index 7c3e751022..0000000000 --- a/unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py +++ /dev/null @@ -1,174 +0,0 @@ -from types import SimpleNamespace - -import torch - -from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.basemodel.batch_objs import ModelInput - - -def test_mtp_decode_cuda_graph_warmup_uses_verify_layout(): - from lightllm.common.basemodel.cuda_graph import CudaGraph - - graph = CudaGraph.__new__(CudaGraph) - graph.mtp_step = 2 - graph.graph_max_len_in_batch = 128 - - class FakeMemManager: - HOLD_TOKEN_MEMINDEX = -1 - - def alloc(self, size): - return torch.arange(size, dtype=torch.int32) - - model = SimpleNamespace( - req_manager=SimpleNamespace(HOLD_REQUEST_ID=99), - mem_manager=FakeMemManager(), - _gen_special_model_input=lambda token_num: {"mtp_draft_input_hiddens": None}, - ) - - model_input = graph._build_warmup_decode_model_input(model, batch_size=6, device="cpu") - - assert model_input.batch_size == 6 - assert model_input.b_mtp_index.tolist() == [0, 1, 2, 0, 1, 2] - assert model_input.b_seq_len.tolist() == [2, 3, 4, 2, 3, 4] - assert model_input.b_num_accepted_tokens.tolist() == [1, 1] - assert model_input.total_token_num == 18 - - -def test_mtp_decode_cuda_graph_warmup_builds_normal_layout_for_non_mtp(): - from lightllm.common.basemodel.cuda_graph import CudaGraph - - # A non-MTP model (mtp_step == 0) has a single, ungrouped decode layout. - graph = CudaGraph.__new__(CudaGraph) - graph.mtp_step = 0 - graph.graph_max_len_in_batch = 128 - - class FakeMemManager: - HOLD_TOKEN_MEMINDEX = -1 - - def alloc(self, size): - return torch.arange(size, dtype=torch.int32) - - model = SimpleNamespace( - req_manager=SimpleNamespace(HOLD_REQUEST_ID=99), - mem_manager=FakeMemManager(), - _gen_special_model_input=lambda token_num: {"mtp_draft_input_hiddens": torch.full((token_num, 4), 3.0)}, - ) - - model_input = graph._build_warmup_decode_model_input(model, batch_size=5, device="cpu") - - assert model_input.batch_size == 5 - assert model_input.b_mtp_index.tolist() == [0, 0, 0, 0, 0] - assert model_input.b_seq_len.tolist() == [2, 2, 2, 2, 2] - assert model_input.b_num_accepted_tokens is None - assert model_input.total_token_num == 10 - assert model_input.mtp_draft_input_hiddens.shape == (5, 4) - - -def test_mtp_decode_cuda_graph_key_is_batch_size(): - from lightllm.common.basemodel.cuda_graph import CudaGraph - - # Under MTP there is a single (mtp_step+1)-grouped decode layout, so the graph is keyed by - # batch size alone — no verify/normal distinction in the key. - graph = CudaGraph.__new__(CudaGraph) - graph.mtp_step = 2 - graph.graph = {} - graph.cuda_graph_batch_sizes = [3, 6, 9, 12] - - state = SimpleNamespace(input_ids=torch.ones(6, dtype=torch.int64)) - assert graph._decode_graph_key(state) == 6 - assert graph.find_closest_graph_batch_size(5) == 6 - - graph.graph[6] = "decode graph" - assert graph.need_capture(6) is False - assert graph.need_capture(3) is True - - -def test_mtp_decode_warmup_layout_marks_qwen3next_verify(monkeypatch): - import pytest - - if not torch.cuda.is_available(): - pytest.skip("needs CUDA for gen_decode_params") - - import lightllm.common.basemodel.mtp_verify_extra_state as mtp_verify_extra_state_mod - from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo - - monkeypatch.setattr(mtp_verify_extra_state_mod, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) - - state = Qwen3NextInferStateInfo() - state.is_prefill = False - state.b_req_idx = torch.tensor([5, 5, 5, 6, 6, 6], dtype=torch.int32, device="cuda") - state.b_mtp_index = torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32, device="cuda") - state.b_seq_len = torch.tensor([2, 3, 4, 2, 3, 4], dtype=torch.int32, device="cuda") - state.b_num_accepted_tokens = torch.tensor([1, 2], dtype=torch.int32, device="cuda") - - model = SimpleNamespace( - _cos_cached=torch.zeros((16, 4), dtype=torch.float32, device="cuda"), - _sin_cached=torch.zeros((16, 4), dtype=torch.float32, device="cuda"), - ) - - state.init_some_extra_state(model) - - assert state.is_mtp_verify is True - assert state.b_gdn_verify_cu_seqlens.tolist() == [0, 3, 6] - assert state.b_conv_buffer_idx.tolist() == [5, 6] - assert state.b_ssm_index_rows.tolist() == [[15, 16, 17], [18, 19, 20]] - - -def test_mtp_decode_padding_preserves_verify_groups(monkeypatch): - import lightllm.common.basemodel.basemodel as basemodel_mod - - monkeypatch.setattr(basemodel_mod, "enable_diverse_mode_gqa_decode_fast_kernel", lambda: False) - - model = TpPartBaseModel.__new__(TpPartBaseModel) - model.args = SimpleNamespace(mtp_step=2) - model.req_manager = SimpleNamespace(HOLD_REQUEST_ID=99) - model.mem_manager = SimpleNamespace(HOLD_TOKEN_MEMINDEX=-1) - - model_input = ModelInput( - batch_size=3, - total_token_num=12, - max_q_seq_len=1, - max_kv_seq_len=4, - input_ids=torch.tensor([10, 11, 12], dtype=torch.int32), - mem_indexes=torch.tensor([20, 21, 22], dtype=torch.int32), - b_req_idx=torch.tensor([7, 7, 7], dtype=torch.int32), - b_mtp_index=torch.tensor([0, 1, 2], dtype=torch.int32), - b_seq_len=torch.tensor([2, 3, 4], dtype=torch.int32), - b_num_accepted_tokens=torch.tensor([2], dtype=torch.int32), - is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(3)], - ) - - padded = model._create_padded_decode_model_input(model_input, new_batch_size=6) - - assert padded.batch_size == 6 - assert padded.b_req_idx.tolist() == [7, 7, 7, 99, 99, 99] - assert padded.b_mtp_index.tolist() == [0, 1, 2, 0, 1, 2] - assert padded.b_seq_len.tolist() == [2, 3, 4, 2, 3, 4] - assert padded.b_num_accepted_tokens.tolist() == [2, 1] - assert padded.mem_indexes.tolist() == [20, 21, 22, -1, -1, -1] - assert len(padded.multimodal_params) == 6 - - -def test_qwen3next_hybrid_mtp_keeps_decode_cuda_graph_enabled(monkeypatch): - import lightllm.models.qwen3next.model as qwen3next_model - from lightllm.models.qwen3next.model import Qwen3NextTpPartModel - - monkeypatch.setattr(qwen3next_model, "get_env_start_args", lambda: SimpleNamespace(mtp_step=2)) - - called = {} - - def fake_base_init_cudagraph(self): - called["disable_cudagraph"] = self.disable_cudagraph - self.graph = "captured" - - monkeypatch.setattr(TpPartBaseModel, "_init_cudagraph", fake_base_init_cudagraph) - - model = Qwen3NextTpPartModel.__new__(Qwen3NextTpPartModel) - model.disable_cudagraph = False - - Qwen3NextTpPartModel._init_cudagraph(model) - - assert called["disable_cudagraph"] is False - assert model.disable_cudagraph is False - assert model.graph == "captured" diff --git a/unit_tests/common/test_init_linear_att_state_zeros_block.py b/unit_tests/common/test_init_linear_att_state_zeros_block.py deleted file mode 100644 index b20e489bce..0000000000 --- a/unit_tests/common/test_init_linear_att_state_zeros_block.py +++ /dev/null @@ -1,41 +0,0 @@ -import types -import torch - -# NOTE: importing lightllm.common.req_manager *first* trips a pre-existing circular import -# (req_manager line-8 imports gen_sampling_params -> basemodel -> infer_struct, which re-enters -# the half-initialized req_manager before ReqManager is defined). Importing basemodel first -# fully resolves that chain, after which ReqManagerForMamba imports cleanly. This is an -# import-ordering fix only; it does not alter the method-under-test or the duck-typed call below. -import lightllm.common.basemodel # noqa: F401 (resolves circular import; must precede req_manager) -from lightllm.common.req_manager import ReqManagerForMamba - - -class _Buf: - def __init__(self, t): - self.buffer = t - - -def test_init_zeros_full_ssm_block(): - mtp_step = 3 - layer, n_req = 2, 4 - conv_dim, width = 8, 3 - conv_buf = torch.ones(layer, n_req, conv_dim, width) - ssm_buf = torch.ones(layer, n_req * (mtp_step + 1), 5) - - dummy = types.SimpleNamespace( - mtp_step=mtp_step, - req_to_conv_state=_Buf(conv_buf), - req_to_ssm_state=_Buf(ssm_buf), - ) - req = types.SimpleNamespace(req_idx=2, mtp_accept_len=None) - - ReqManagerForMamba.init_linear_att_state(dummy, req) - - start = 2 * (mtp_step + 1) - block = ssm_buf[:, start : start + (mtp_step + 1), ...] - assert torch.count_nonzero(block) == 0, "all S+1 SSM rows of the block must be zeroed on init" - # other requests' rows must be untouched - assert torch.count_nonzero(ssm_buf[:, :start, ...]) > 0 - # conv slot for this request zeroed; canonical accept-len reset - assert torch.count_nonzero(conv_buf[:, 2, ...]) == 0 - assert req.mtp_accept_len == 1 diff --git a/unit_tests/common/test_linear_att_copy_guards.py b/unit_tests/common/test_linear_att_copy_guards.py deleted file mode 100644 index b6c48a0c86..0000000000 --- a/unit_tests/common/test_linear_att_copy_guards.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest -import torch - -from lightllm.common.basemodel.triton_kernel.linear_att_copy import ( - copy_linear_att_state_to_kv_buffer, -) - - -def _args(gpu_conv, accept_len, mtp_step): - layer_num = gpu_conv.shape[0] - dim_conv = gpu_conv.shape[2] - width_narrow = 3 - return dict( - b_req_idx=torch.tensor([0], dtype=torch.int32), - big_page_buffer_ids=torch.tensor([0], dtype=torch.int32), - gpu_conv_state=gpu_conv, - gpu_ssm_state=torch.zeros(layer_num, 1 * (mtp_step + 1), 8), - cpu_kv_conv_state=torch.zeros(1, layer_num, dim_conv, width_narrow), - cpu_kv_ssm_state=torch.zeros(1, layer_num, 8), - mtp_step=mtp_step, - b_num_accepted_tokens=torch.tensor([accept_len], dtype=torch.int32), - ) - - -def test_rejects_non_contiguous_width_axis(): - mtp_step = 2 - # widened slot allocated 2x, then strided ::2 along the width axis -> stride(3) == 2 - base = torch.zeros(2, 1, 32, (3 + mtp_step) * 2) - gpu_conv = base[:, :, :, ::2] - assert gpu_conv.stride(3) != 1 - with pytest.raises(AssertionError, match="width"): - copy_linear_att_state_to_kv_buffer(**_args(gpu_conv, accept_len=1, mtp_step=mtp_step)) - - -def test_rejects_out_of_range_accept_len(): - mtp_step = 2 - gpu_conv = torch.zeros(2, 1, 32, 3 + mtp_step) # contiguous, passes the #6 guard - with pytest.raises(AssertionError, match="b_num_accepted_tokens"): - copy_linear_att_state_to_kv_buffer(**_args(gpu_conv, accept_len=mtp_step + 2, mtp_step=mtp_step)) diff --git a/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py b/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py deleted file mode 100644 index fb22f0ed1f..0000000000 --- a/unit_tests/common/test_linear_att_mtp_cpu_cache_persistence.py +++ /dev/null @@ -1,219 +0,0 @@ -from types import SimpleNamespace - -import pytest -import torch - - -def _make_start_args(**overrides): - base = dict( - model_dir="/tmp/qwen3_5", - tp=1, - dp=1, - data_type="bfloat16", - linear_att_ssm_data_type="bfloat16", - mtp_mode=None, - mtp_step=0, - linear_att_page_block_num=2, - linear_att_hash_page_size=4, - cpu_cache_token_page_size=8, - ) - base.update(overrides) - return SimpleNamespace(**base) - - -def _make_model_cfg(): - return { - "model_type": "qwen3_5", - "num_hidden_layers": 64, - "num_key_value_heads": 16, - "head_dim": 128, - "linear_num_key_heads": 16, - "linear_num_value_heads": 48, - "linear_key_head_dim": 128, - "linear_value_head_dim": 128, - "linear_conv_kernel_dim": 4, - "full_attention_interval": 4, - } - - -def _patch_linear_config_args(monkeypatch, args): - import lightllm.common.linear_att_cache_manager.config_objs as config_objs - - monkeypatch.setattr(config_objs, "get_env_start_args", lambda: args) - - -def _make_config(draft_full_att_layer_num=0): - from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig - - return LinearAttCacheConfig( - tp_world_size=1, - full_att_all_num_kv_heads=16, - full_att_dtype=torch.bfloat16, - full_att_num_kv_heads=16, - full_att_head_dim=128, - num_linear_k_heads=16, - num_linear_v_heads=48, - head_linear_k_dim=128, - head_linear_v_dim=128, - conv_kernel_size=4, - linear_layer_num=48, - conv_state_dtype=torch.bfloat16, - ssm_state_dtype=torch.bfloat16, - full_attention_interval=4, - all_layer_num=64, - draft_full_att_layer_num=draft_full_att_layer_num, - ) - - -def test_load_from_args_includes_mtp_draft_full_att_layers(monkeypatch): - from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig - from transformers.configuration_utils import PretrainedConfig - - args = _make_start_args(mtp_mode="vanilla_with_att", mtp_step=3) - _patch_linear_config_args(monkeypatch, args) - monkeypatch.setattr(PretrainedConfig, "get_config_dict", lambda _model_path: (_make_model_cfg(), None)) - - cfg = LinearAttCacheConfig.load_from_args() - - assert cfg.get_main_full_att_layer_num() == 16 - assert cfg.draft_full_att_layer_num == 3 - assert cfg.get_persisted_full_att_layer_num() == 19 - - -def test_cpu_cache_full_att_bytes_include_mtp_draft_layers(monkeypatch): - args = _make_start_args() - _patch_linear_config_args(monkeypatch, args) - main_only = _make_config(draft_full_att_layer_num=0) - with_draft = _make_config(draft_full_att_layer_num=2) - - bytes_per_full_att_layer = ( - args.cpu_cache_token_page_size - * 2 - * main_only.full_att_all_num_kv_heads - * main_only.full_att_head_dim - * main_only.full_att_dtype.itemsize - ) - - assert main_only.get_main_full_att_layer_num() == 16 - assert with_draft.get_persisted_full_att_layer_num() == 18 - assert with_draft.get_cpu_cache_full_att_bytes() == ( - main_only.get_cpu_cache_full_att_bytes() + 2 * bytes_per_full_att_layer - ) - - -def test_linear_operator_persisted_full_att_slice_includes_draft_slots(): - from lightllm.common.kv_cache_mem_manager.operator.linear_att import LinearAttMemOperator - - class MtpMemManager: - main_full_att_layer_num = 16 - draft_full_att_layers = 2 - kv_buffer = torch.empty((18, 1)) - - class MainOnlyMemManager: - main_full_att_layer_num = 16 - kv_buffer = torch.empty((18, 1)) - - class PlainMemManager: - kv_buffer = torch.empty((7, 1)) - - assert LinearAttMemOperator._get_persisted_full_att_layer_num(MtpMemManager()) == 18 - assert LinearAttMemOperator._get_persisted_full_att_layer_num(MainOnlyMemManager()) == 16 - assert LinearAttMemOperator._get_persisted_full_att_layer_num(PlainMemManager()) == 7 - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") -def test_linear_cpu_cache_roundtrips_mtp_draft_full_att_slot(monkeypatch): - from lightllm.common.basemodel.triton_kernel.linear_att_cpu_cache_copy import ( - copy_cpu_cache_to_kv_buffer, - copy_kv_buffer_to_cpu_cache, - ) - from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig - - args = _make_start_args( - linear_att_page_block_num=1, - linear_att_hash_page_size=2, - cpu_cache_token_page_size=2, - ) - _patch_linear_config_args(monkeypatch, args) - cfg = LinearAttCacheConfig( - tp_world_size=1, - full_att_all_num_kv_heads=2, - full_att_dtype=torch.float32, - full_att_num_kv_heads=2, - full_att_head_dim=8, - num_linear_k_heads=1, - num_linear_v_heads=1, - head_linear_k_dim=8, - head_linear_v_dim=8, - conv_kernel_size=2, - linear_layer_num=1, - conv_state_dtype=torch.float32, - ssm_state_dtype=torch.float32, - full_attention_interval=2, - all_layer_num=2, - draft_full_att_layer_num=1, - ) - - gpu_kv = torch.arange(2 * 2 * 4 * 8, dtype=torch.float32, device="cuda").reshape(2, 2, 4, 8) - cpu_cache_tensor = torch.zeros( - (1, 1, 1, 1, cfg.get_cpu_cache_big_page_bytes()), - dtype=torch.uint8, - device="cuda", - ) - conv_state = torch.zeros( - (1, cfg.linear_layer_num, cfg.get_conv_dim(), cfg.conv_kernel_size - 1), - dtype=torch.float32, - device="cuda", - ) - ssm_state = torch.zeros( - ( - 1, - cfg.linear_layer_num, - cfg.num_linear_v_heads, - cfg.head_linear_k_dim, - cfg.head_linear_v_dim, - ), - dtype=torch.float32, - device="cuda", - ) - mem_indexes = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - page_indexes = torch.tensor([0], dtype=torch.int32, device="cuda") - page_readies = torch.tensor([False], dtype=torch.bool, device="cuda") - big_page_buffer_ids = torch.tensor([0], dtype=torch.int64, device="cuda") - - copy_kv_buffer_to_cpu_cache( - mem_indexes=mem_indexes, - page_indexes=page_indexes, - page_readies=page_readies, - big_page_buffer_ids=big_page_buffer_ids, - gpu_kv_full_att_state=gpu_kv, - cpu_kv_conv_state=conv_state, - cpu_kv_ssm_state=ssm_state, - cpu_cache_tensor=cpu_cache_tensor, - tp_rank=0, - tp_world_size=1, - big_page_token_num=args.cpu_cache_token_page_size, - linear_config=cfg, - grid_num=1, - ) - - restored_gpu_kv = torch.full_like(gpu_kv, fill_value=-1) - restored_conv = torch.empty_like(conv_state) - restored_ssm = torch.empty_like(ssm_state) - copy_cpu_cache_to_kv_buffer( - mem_indexes=mem_indexes, - big_page_buffer_ids=big_page_buffer_ids, - page_indexes=page_indexes, - gpu_full_att_kv_state=restored_gpu_kv, - cpu_kv_conv_state=restored_conv, - cpu_kv_ssm_state=restored_ssm, - cpu_cache_tensor=cpu_cache_tensor, - tp_rank=0, - tp_world_size=1, - big_page_token_num=args.cpu_cache_token_page_size, - linear_config=cfg, - grid_num=1, - ) - torch.cuda.synchronize() - - torch.testing.assert_close(restored_gpu_kv, gpu_kv) diff --git a/unit_tests/common/test_linear_att_snapshot_split.py b/unit_tests/common/test_linear_att_snapshot_split.py deleted file mode 100644 index 2ce2833bcf..0000000000 --- a/unit_tests/common/test_linear_att_snapshot_split.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -import torch - -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") - - -@pytest.mark.parametrize("S", [1, 2, 3]) -@pytest.mark.parametrize("accept_len", [1, 2]) -def test_snapshot_reads_committed_conv_and_ssm(S, accept_len): - from lightllm.common.basemodel.triton_kernel.linear_att_copy import ( - copy_linear_att_state_to_kv_buffer, - ) - - layer_num, dim_conv = 2, 32 - width_narrow = 3 - gpu_conv = torch.zeros(layer_num, 1, dim_conv, width_narrow + S, device="cuda") - off = accept_len - 1 - marker_conv = torch.arange(dim_conv * width_narrow, device="cuda").float().reshape(dim_conv, width_narrow) - gpu_conv[:, 0, :, off : off + width_narrow] = marker_conv - - hv, k, v = 4, 8, 8 - gpu_ssm = torch.zeros(layer_num, 1 * (S + 1), hv, k, v, device="cuda") - marker_ssm = torch.arange(hv * k * v, device="cuda").float().reshape(hv, k, v) - gpu_ssm[:, off, ...] = marker_ssm # block slot 0*(S+1)+off - - cpu_conv = torch.zeros(1, layer_num, dim_conv, width_narrow, device="cuda") - cpu_ssm = torch.zeros(1, layer_num, hv, k, v, device="cuda") - - copy_linear_att_state_to_kv_buffer( - b_req_idx=torch.tensor([0], dtype=torch.int32, device="cuda"), - big_page_buffer_ids=torch.tensor([0], dtype=torch.int32, device="cuda"), - gpu_conv_state=gpu_conv, - gpu_ssm_state=gpu_ssm, - cpu_kv_conv_state=cpu_conv, - cpu_kv_ssm_state=cpu_ssm, - mtp_step=S, - b_num_accepted_tokens=torch.tensor([accept_len], dtype=torch.int32, device="cuda"), - ) - - torch.testing.assert_close(cpu_conv[0], marker_conv.expand(layer_num, dim_conv, width_narrow)) - torch.testing.assert_close(cpu_ssm[0], marker_ssm.expand(layer_num, hv, k, v)) diff --git a/unit_tests/common/test_mtp_verify_extra_state.py b/unit_tests/common/test_mtp_verify_extra_state.py deleted file mode 100644 index 7252af0736..0000000000 --- a/unit_tests/common/test_mtp_verify_extra_state.py +++ /dev/null @@ -1,36 +0,0 @@ -import types -import torch - -import lightllm.common.basemodel.mtp_verify_extra_state as mod - - -def _state(n_real, mtp_step, is_prefill=False, with_accept=True): - step = mtp_step + 1 - s = types.SimpleNamespace() - s.b_seq_len = torch.arange(1, n_real * step + 1, dtype=torch.int32) - s.b_req_idx = torch.arange(n_real, dtype=torch.int32).repeat_interleave(step) - s.b_mtp_index = torch.arange(step, dtype=torch.int32).repeat(n_real) - s.is_prefill = is_prefill - s.b_num_accepted_tokens = torch.ones(n_real, dtype=torch.int32) if with_accept else None - return s - - -def test_verify_branch_sets_index_rows(monkeypatch): - monkeypatch.setattr(mod, "get_env_start_args", lambda: types.SimpleNamespace(mtp_step=2)) - n_real, mtp_step = 3, 2 - step = mtp_step + 1 - s = _state(n_real, mtp_step) - mod.init_mtp_verify_extra_state(s) - assert s.is_mtp_verify is True - assert s.b_ssm_index_rows.shape == (n_real, step) - assert s.b_gdn_verify_cu_seqlens.tolist() == [0, 3, 6, 9] - assert s.b_conv_buffer_idx.tolist() == [0, 1, 2] # one widened conv slot per req - - -def test_non_verify_branch_no_index_rows(monkeypatch): - monkeypatch.setattr(mod, "get_env_start_args", lambda: types.SimpleNamespace(mtp_step=2)) - s = _state(3, 2, with_accept=False) - mod.init_mtp_verify_extra_state(s) - assert s.is_mtp_verify is False - assert s.b_ssm_index_rows is None - assert s.b_gdn_verify_cu_seqlens is None diff --git a/unit_tests/common/test_qwen3next_linear_att_page_helper.py b/unit_tests/common/test_qwen3next_linear_att_page_helper.py deleted file mode 100644 index e4c2e71c76..0000000000 --- a/unit_tests/common/test_qwen3next_linear_att_page_helper.py +++ /dev/null @@ -1,112 +0,0 @@ -from types import SimpleNamespace - -import torch - - -class _Buf: - def __init__(self, tensor): - self.buffer = tensor - - -def _make_config(): - return SimpleNamespace( - tp_world_size=1, - linear_layer_num=1, - conv_kernel_size=4, - global_linear_k_heads=1, - global_linear_v_heads=1, - num_linear_k_heads=1, - num_linear_v_heads=1, - head_linear_k_dim=2, - head_linear_v_dim=3, - ) - - -def _make_mem(mtp_step=2, req_slots=4): - config = _make_config() - conv_dim = ( - 2 * config.num_linear_k_heads * config.head_linear_k_dim - + config.num_linear_v_heads * config.head_linear_v_dim - ) - narrow_w = config.conv_kernel_size - 1 - conv = torch.full( - (config.linear_layer_num, req_slots, conv_dim, narrow_w + mtp_step), - -9.0, - dtype=torch.float32, - ) - ssm = torch.full( - ( - config.linear_layer_num, - req_slots * (mtp_step + 1), - config.num_linear_v_heads, - config.head_linear_k_dim, - config.head_linear_v_dim, - ), - -11.0, - dtype=torch.float32, - ) - return SimpleNamespace( - linear_config=config, - req_to_conv_state=_Buf(conv), - req_to_ssm_state=_Buf(ssm), - kv_move_buffer=torch.zeros((1, 4096), dtype=torch.uint8), - ) - - -def test_page_helper_writes_req_conv_slot_and_narrow_width(monkeypatch): - import lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager as qwen3next_mem_manager - from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextLinearAttPageHelper - - mtp_step = 2 - req_idx = 2 - monkeypatch.setattr(qwen3next_mem_manager, "get_env_start_args", lambda: SimpleNamespace(mtp_step=mtp_step)) - - mem = _make_mem(mtp_step=mtp_step) - helper = Qwen3NextLinearAttPageHelper(mem) - mem.kv_move_buffer = torch.zeros((1, helper.state_nbytes), dtype=torch.uint8) - - narrow_w = helper.conv_shape[-1] - marker_conv = torch.arange( - helper.conv_shape[0] * helper.conv_shape[1] * narrow_w, - dtype=torch.float32, - ).view(helper.conv_shape) - marker_ssm = torch.arange( - helper.ssm_shape[0] * helper.ssm_shape[1] * helper.ssm_shape[2] * helper.ssm_shape[3], - dtype=torch.float32, - ).view(helper.ssm_shape) - - mem.req_to_conv_state.buffer[:, req_idx, :, :narrow_w] = marker_conv - mem.req_to_conv_state.buffer[:, req_idx, :, narrow_w:] = 999.0 - mem.req_to_ssm_state.buffer[:, req_idx * (mtp_step + 1), ...] = marker_ssm - - helper.write_req_to_page(page_index=0, req_idx=req_idx, dp_mems=[mem]) - - conv_page, ssm_page = helper.view_page_to_linear_att_state(page_index=0) - torch.testing.assert_close(conv_page, marker_conv) - torch.testing.assert_close(ssm_page, marker_ssm) - - -def test_page_helper_restores_narrow_conv_to_req_slot(monkeypatch): - import lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager as qwen3next_mem_manager - from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextLinearAttPageHelper - - mtp_step = 2 - req_idx = 2 - monkeypatch.setattr(qwen3next_mem_manager, "get_env_start_args", lambda: SimpleNamespace(mtp_step=mtp_step)) - - mem = _make_mem(mtp_step=mtp_step) - helper = Qwen3NextLinearAttPageHelper(mem) - mem.kv_move_buffer = torch.zeros((1, helper.state_nbytes), dtype=torch.uint8) - conv_page, ssm_page = helper.view_page_to_linear_att_state(page_index=0) - - marker_conv = torch.arange(conv_page.numel(), dtype=torch.float32).view_as(conv_page) - marker_ssm = torch.arange(ssm_page.numel(), dtype=torch.float32).view_as(ssm_page) - conv_page.copy_(marker_conv) - ssm_page.copy_(marker_ssm) - - helper.read_page_to_req(page_index=0, req_idx=req_idx, dp_mems=[mem]) - - narrow_w = helper.conv_shape[-1] - torch.testing.assert_close(mem.req_to_conv_state.buffer[:, req_idx, :, :narrow_w], marker_conv) - assert torch.all(mem.req_to_conv_state.buffer[:, req_idx, :, narrow_w:] == -9.0) - torch.testing.assert_close(mem.req_to_ssm_state.buffer[:, req_idx * (mtp_step + 1), ...], marker_ssm) diff --git a/unit_tests/models/qwen3next/test_causal_conv1d_spec.py b/unit_tests/models/qwen3next/test_causal_conv1d_spec.py deleted file mode 100644 index e99497ec33..0000000000 --- a/unit_tests/models/qwen3next/test_causal_conv1d_spec.py +++ /dev/null @@ -1,147 +0,0 @@ -import pytest -import torch - -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") - - -def _eager_conv_update(x_seq, conv_state, weight, bias, activation): - # x_seq: (dim, seqlen) tokens to roll in, conv_state: (dim, width-1) history - dim, width = weight.shape - state = conv_state.clone() # (dim, width-1) - outs = [] - for t in range(x_seq.shape[1]): - window = torch.cat([state, x_seq[:, t : t + 1]], dim=1) # (dim, width) - y = (window * weight).sum(dim=1) # depthwise conv - if bias is not None: - y = y + bias - if activation in ("silu", "swish"): - y = torch.nn.functional.silu(y) - outs.append(y) - state = window[:, 1:] # slide - return torch.stack(outs, dim=1), state - - -@pytest.mark.parametrize("S", [0, 1, 2, 3]) -def test_spec_conv_matches_eager_after_partial_accept(S): - from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update - - torch.manual_seed(0) - dim, width = 64, 4 - seqlen = S + 1 - state_len = (width - 1) + S - device = "cuda" - dtype = torch.float32 - - weight = torch.randn(dim, width, device=device, dtype=dtype) - bias = torch.randn(dim, device=device, dtype=dtype) - - conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) - committed_hist = torch.randn(dim, width - 1, device=device, dtype=dtype) - conv_state[0, :, : width - 1] = committed_hist - - x = torch.randn(seqlen, dim, device=device, dtype=dtype) # candidate tokens - - out = causal_conv1d_update( - x.clone(), - conv_state, - weight, - bias=bias, - activation="silu", - conv_state_indices=torch.zeros(1, dtype=torch.int32, device=device), - num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), # fresh: read offset 0 - query_start_loc=torch.tensor([0, seqlen], dtype=torch.int32, device=device), - ) - - ref_out, _ = _eager_conv_update(x.t(), committed_hist, weight, bias, "silu") - torch.testing.assert_close(out.t(), ref_out, rtol=1e-3, atol=1e-3) - - -@pytest.mark.parametrize("S", [1, 2, 3]) -def test_spec_conv_reads_from_partial_accept_offset(S): - # Exercise the nonzero read offset: num_accepted_tokens=2 -> read offset 1. - # The widened slot front-loads a STALE token then the real committed history; - # the kernel must read history starting at (num_accepted_tokens-1)==1, i.e. - # conv_state[:, 1:width], NOT the stale token at index 0. - from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update - - torch.manual_seed(0) - dim, width = 64, 4 - seqlen = S + 1 - state_len = (width - 1) + S - device = "cuda" - dtype = torch.float32 - - weight = torch.randn(dim, width, device=device, dtype=dtype) - bias = torch.randn(dim, device=device, dtype=dtype) - - conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) - # tokens [0 .. width-1] hold [stale, h1, h2, ...]: a stale front token then history - seed = torch.randn(dim, width, device=device, dtype=dtype) - conv_state[0, :, :width] = seed - stale_front = conv_state[0, :, :width].clone() # snapshot of the seeded window - - x = torch.randn(seqlen, dim, device=device, dtype=dtype) # candidate tokens - - out = causal_conv1d_update( - x.clone(), - conv_state, - weight, - bias=bias, - activation="silu", - conv_state_indices=torch.zeros(1, dtype=torch.int32, device=device), - num_accepted_tokens=2 * torch.ones(1, dtype=torch.int32, device=device), # read offset 1 - query_start_loc=torch.tensor([0, seqlen], dtype=torch.int32, device=device), - ) - - # Eager reference starts from the offset-1 window: committed history excluding - # the stale front token == conv_state[:, 1:width]. - committed_hist = stale_front[:, 1:width] - ref_out, _ = _eager_conv_update(x.t(), committed_hist, weight, bias, "silu") - torch.testing.assert_close(out.t(), ref_out, rtol=1e-3, atol=1e-3) - - -def test_spec_conv_varlen_update_is_cuda_graph_capturable(): - from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update - - torch.manual_seed(0) - dim, width, S = 64, 4, 1 - seqlen = S + 1 - state_len = (width - 1) + S - device = "cuda" - dtype = torch.float32 - - weight = torch.randn(dim, width, device=device, dtype=dtype) - bias = torch.randn(dim, device=device, dtype=dtype) - conv_state = torch.zeros(1, dim, state_len, device=device, dtype=dtype) - x = torch.randn(seqlen, dim, device=device, dtype=dtype) - conv_state_indices = torch.zeros(1, dtype=torch.int32, device=device) - num_accepted_tokens = torch.ones(1, dtype=torch.int32, device=device) - query_start_loc = torch.tensor([0, seqlen], dtype=torch.int32, device=device) - - # Compile/warm the Triton kernel before capture; the regression is the wrapper's - # host sync on query_start_loc during capture, not first-use compilation. - causal_conv1d_update( - x.clone(), - conv_state, - weight, - bias=bias, - activation="silu", - conv_state_indices=conv_state_indices, - num_accepted_tokens=num_accepted_tokens, - query_start_loc=query_start_loc, - ) - torch.cuda.synchronize() - - graph = torch.cuda.CUDAGraph() - static_x = x.clone() - with torch.cuda.graph(graph): - causal_conv1d_update( - static_x, - conv_state, - weight, - bias=bias, - activation="silu", - conv_state_indices=conv_state_indices, - num_accepted_tokens=num_accepted_tokens, - query_start_loc=query_start_loc, - ) diff --git a/unit_tests/models/qwen3next/test_conv_prefill_decode_roundtrip.py b/unit_tests/models/qwen3next/test_conv_prefill_decode_roundtrip.py deleted file mode 100644 index 2fca8bfc57..0000000000 --- a/unit_tests/models/qwen3next/test_conv_prefill_decode_roundtrip.py +++ /dev/null @@ -1,74 +0,0 @@ -import pytest -import torch - -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") - - -def _eager_conv_update(x_seq, conv_state, weight, bias, activation): - # x_seq: (dim, seqlen) tokens to roll in, conv_state: (dim, width-1) history - state = conv_state.clone() - outs = [] - for t in range(x_seq.shape[1]): - window = torch.cat([state, x_seq[:, t : t + 1]], dim=1) # (dim, width) - y = (window * weight).sum(dim=1) - if bias is not None: - y = y + bias - if activation in ("silu", "swish"): - y = torch.nn.functional.silu(y) - outs.append(y) - state = window[:, 1:] - return torch.stack(outs, dim=1), state - - -@pytest.mark.parametrize("S", [1, 2, 3]) -def test_prefill_writes_first_columns_then_decode_reads_them(S): - from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn - from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import causal_conv1d_update - - torch.manual_seed(0) - dim, width = 64, 4 - prefill_len = 7 - state_len = (width - 1) + S # widened slot - device, dtype = "cuda", torch.float32 - - weight = torch.randn(dim, width, device=device, dtype=dtype) - bias = torch.randn(dim, device=device, dtype=dtype) - - # ---- PREFILL: populate one widened conv slot from a fresh (no initial state) sequence ---- - conv_states = torch.zeros(1, dim, state_len, device=device, dtype=dtype) - x_prefill = torch.randn(dim, prefill_len, device=device, dtype=dtype) # (dim, total_tokens) - causal_conv1d_fn( - x_prefill.clone(), - weight, - bias=bias, - query_start_loc=torch.tensor([0, prefill_len], dtype=torch.int32, device=device), - cache_indices=torch.zeros(1, dtype=torch.int32, device=device), - has_initial_state=torch.zeros(1, dtype=torch.bool, device=device), - conv_states=conv_states, - activation="silu", - ) - - # Contract (a): committed state lands in the FIRST width-1 columns; widened tail untouched. - committed_hist = conv_states[0, :, : width - 1].clone() - expected_hist = x_prefill[:, -(width - 1) :] # trailing window for a fresh causal conv - torch.testing.assert_close(committed_hist, expected_hist, rtol=1e-3, atol=1e-3) - if state_len > width - 1: - assert torch.count_nonzero(conv_states[0, :, width - 1 :]) == 0, "widened tail must be untouched by prefill" - - # ---- FIRST DECODE: verify reads at offset accept_len-1 == 0 -> columns [0:width-1] ---- - seqlen = S + 1 - x_decode = torch.randn(seqlen, dim, device=device, dtype=dtype) - out = causal_conv1d_update( - x_decode.clone(), - conv_states, - weight, - bias=bias, - activation="silu", - conv_state_indices=torch.zeros(1, dtype=torch.int32, device=device), - num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), # offset 0 - query_start_loc=torch.tensor([0, seqlen], dtype=torch.int32, device=device), - ) - - # Contract (b): decode output must match an eager conv seeded from the prefill-written history. - ref_out, _ = _eager_conv_update(x_decode.t(), committed_hist, weight, bias, "silu") - torch.testing.assert_close(out.t(), ref_out, rtol=1e-3, atol=1e-3) diff --git a/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py b/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py deleted file mode 100644 index 7481607d54..0000000000 --- a/unit_tests/models/qwen3next/test_gdn_verify_equivalence.py +++ /dev/null @@ -1,194 +0,0 @@ -import pytest -import torch - -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") - - -@pytest.mark.parametrize("S", [1, 2, 3]) -def test_gdn_verify_state_equals_sequential_decode(S): - from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( - fused_recurrent_gated_delta_rule, - ) - - torch.manual_seed(0) - HV, K, V = 4, 16, 16 - T = S + 1 - device = "cuda" - - def rand_qkv(t): - q = torch.randn(1, t, HV, K, device=device) - k = torch.nn.functional.normalize(torch.randn(1, t, HV, K, device=device), dim=-1) - v = torch.randn(1, t, HV, V, device=device) - g = torch.nn.functional.logsigmoid(torch.rand(1, t, HV, device=device)) - beta = torch.rand(1, t, HV, device=device).sigmoid() - return q, k, v, g, beta - - q, k, v, g, beta = rand_qkv(T) - - ref_state = torch.zeros(1, HV, K, V, device=device) - for t in range(T): - _, ref_state = fused_recurrent_gated_delta_rule( - q=q[:, t : t + 1], - k=k[:, t : t + 1], - v=v[:, t : t + 1], - g=g[:, t : t + 1], - beta=beta[:, t : t + 1], - initial_state=ref_state, - inplace_final_state=False, - ) - - block = torch.zeros(T, HV, K, V, device=device) - ssm_idx = torch.arange(T, device=device).view(1, T) - fused_recurrent_gated_delta_rule( - q=q, - k=k, - v=v, - g=g, - beta=beta, - initial_state=block, - inplace_final_state=True, - cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), - ssm_state_indices=ssm_idx, - ssm_state_write_indices=ssm_idx, - num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), - ) - torch.testing.assert_close(block[T - 1], ref_state[0], rtol=2e-2, atol=2e-2) - - -@pytest.mark.parametrize("S", [1, 2, 3]) -def test_gdn_verify_output_equals_sequential_decode_fused(S): - """H1: the LIVE verify combination - varlen + FUSED gating (A_log/dt_bias/a_raw/b_raw) - + spec-decode - must produce per-position OUTPUT o[t] identical to running the proven - T=1 decode recurrence sequentially. The original test only checked the final SSM state - with EXPLICIT g/beta; it never verified o[t] nor the fused-gating path that - _gdn_verify_kernel actually uses.""" - from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( - fused_recurrent_gated_delta_rule, - ) - - torch.manual_seed(0) - HV, K, V = 4, 16, 16 - H = HV - T = S + 1 - device = "cuda" - - q = torch.randn(1, T, H, K, device=device) - k = torch.nn.functional.normalize(torch.randn(1, T, H, K, device=device), dim=-1) - v = torch.randn(1, T, HV, V, device=device) - # Raw gating inputs (pre-activation), exactly as the model feeds the fused path. - a_raw = torch.randn(T, HV, device=device) - b_raw = torch.randn(T, HV, device=device) - A_log = torch.randn(HV, device=device) - dt_bias = torch.randn(HV, device=device) - - # Reference: sequential T=1 decode through the proven non-varlen fused path. - ref_state = torch.zeros(1, HV, K, V, device=device) - ref_o = torch.zeros(T, HV, V, device=device) - for t in range(T): - o_t, ref_state = fused_recurrent_gated_delta_rule( - q=q[:, t : t + 1], - k=k[:, t : t + 1], - v=v[:, t : t + 1], - initial_state=ref_state, - inplace_final_state=False, - use_qk_l2norm_in_kernel=True, - A_log=A_log, - dt_bias=dt_bias, - a_raw=a_raw[t : t + 1], - b_raw=b_raw[t : t + 1], - ) - ref_o[t] = o_t[0, 0] - - # Verify path: single varlen call with fused gating + spec-decode indices, - # mirroring _gdn_verify_kernel for a single request, num_accepted=1. - block = torch.zeros(T, HV, K, V, device=device) - ssm_idx = torch.arange(T, device=device).view(1, T) - o, _ = fused_recurrent_gated_delta_rule( - q=q, - k=k, - v=v, - initial_state=block, - inplace_final_state=True, - cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), - ssm_state_indices=ssm_idx, - ssm_state_write_indices=ssm_idx, - num_accepted_tokens=torch.ones(1, dtype=torch.int32, device=device), - use_qk_l2norm_in_kernel=True, - A_log=A_log, - dt_bias=dt_bias, - a_raw=a_raw, - b_raw=b_raw, - ) - o = o.view(T, HV, V) - torch.testing.assert_close(o, ref_o, rtol=2e-2, atol=2e-2) - torch.testing.assert_close(block[T - 1], ref_state[0], rtol=2e-2, atol=2e-2) - - -@pytest.mark.parametrize("num_accepted", [1, 2]) -def test_gdn_verify_reads_committed_slot_by_num_accepted(num_accepted): - """The verify kernel must read the per-request initial state from the SSM block - slot at offset (num_accepted-1) -- i.e. the state committed after the previous - step's last accepted token. This is the read path exercised by the FIRST decode - after an accept-`num_accepted` step. A decoy is written into the OTHER block slot - to prove the kernel reads the correct one and ignores the rest of the (S+1) block.""" - from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( - fused_recurrent_gated_delta_rule, - ) - - torch.manual_seed(0) - HV, K, V = 4, 16, 16 - S = 1 - T = S + 1 - device = "cuda" - - q = torch.randn(1, T, HV, K, device=device) - k = torch.nn.functional.normalize(torch.randn(1, T, HV, K, device=device), dim=-1) - v = torch.randn(1, T, HV, V, device=device) - a_raw = torch.randn(T, HV, device=device) - b_raw = torch.randn(T, HV, device=device) - A_log = torch.randn(HV, device=device) - dt_bias = torch.randn(HV, device=device) - - # (S+1) block: the committed slot is (num_accepted-1); the others hold decoys - # that MUST NOT be read. - block = torch.randn(T, HV, K, V, device=device) * 5.0 - committed = torch.randn(1, HV, K, V, device=device) - block[num_accepted - 1] = committed[0] - - ref_state = committed.clone() - ref_o = torch.zeros(T, HV, V, device=device) - for t in range(T): - o_t, ref_state = fused_recurrent_gated_delta_rule( - q=q[:, t : t + 1], - k=k[:, t : t + 1], - v=v[:, t : t + 1], - initial_state=ref_state, - inplace_final_state=False, - use_qk_l2norm_in_kernel=True, - A_log=A_log, - dt_bias=dt_bias, - a_raw=a_raw[t : t + 1], - b_raw=b_raw[t : t + 1], - ) - ref_o[t] = o_t[0, 0] - - blk = block.clone() - ssm_idx = torch.arange(T, device=device).view(1, T) - o, _ = fused_recurrent_gated_delta_rule( - q=q, - k=k, - v=v, - initial_state=blk, - inplace_final_state=True, - cu_seqlens=torch.tensor([0, T], dtype=torch.long, device=device), - ssm_state_indices=ssm_idx, - ssm_state_write_indices=ssm_idx, - num_accepted_tokens=torch.tensor([num_accepted], dtype=torch.int32, device=device), - use_qk_l2norm_in_kernel=True, - A_log=A_log, - dt_bias=dt_bias, - a_raw=a_raw, - b_raw=b_raw, - ) - o = o.view(T, HV, V) - torch.testing.assert_close(o, ref_o, rtol=2e-2, atol=2e-2) From e70909736170912be2b95f113b2ff3fda1a65911 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 16 Jun 2026 15:40:07 +0800 Subject: [PATCH 15/19] style: black-format fp8.py k/v_descale lines (pre-commit) --- lightllm/common/basemodel/attention/fa3/fp8.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index c1861aad2c..99fd864bf2 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -45,9 +45,12 @@ def init_state(self): torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len ) # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) def prefill_att( self, @@ -123,8 +126,12 @@ def init_state(self): head_num = mem_manager.head_num # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) return From bca81f9551d56e681863357379ca41a660f4b960 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 16 Jun 2026 15:43:53 +0800 Subject: [PATCH 16/19] clean code --- lightllm/common/req_manager.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 164d04fc3b..6a67993b4e 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -86,10 +86,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num - # MTP verify decode 的 per-req accept 数量:GPU 常驻、按 req_idx 索引(含 HOLD 槽)。 - # 取代旧的 req.mtp_accept_len host 属性 —— verify 后在 GPU 上 scatter,下一步在 GDN 的 - # init_mtp_verify_extra_state 里按 req_first gather 成 b_num_accepted_tokens,省掉每步的 - # host 回写 + H2D 重建。HOLD 槽恒为 1,使 padding 组 gather 到 1。仅 mtp_step>0 时分配。 + self.req_to_accept_len = ( torch.ones((max_request_num + 1,), dtype=torch.int32, device="cuda") if get_env_start_args().mtp_step > 0 From c0293e92e49b324a161d3799ea1c69a4661fa332 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 17 Jun 2026 14:20:45 +0800 Subject: [PATCH 17/19] fix code --- lightllm/common/req_manager.py | 9 ++++----- test/benchmark/static_inference/model_infer.py | 2 +- test/benchmark/static_inference/model_infer_mtp.py | 2 +- test/benchmark/static_inference/test_model.py | 6 ++++++ 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 6a67993b4e..9be22b23bb 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -87,11 +87,10 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num - self.req_to_accept_len = ( - torch.ones((max_request_num + 1,), dtype=torch.int32, device="cuda") - if get_env_start_args().mtp_step > 0 - else None - ) + # Always resident: infer_batch.copy_linear_att_state_to_cache_buffer reads this + # unconditionally in the linear-att cache-copy path (even for mtp_step=0, where + # accept_len=1 is the correct non-widened value). MTP overwrites the live slots. + self.req_to_accept_len = torch.ones((max_request_num + 1,), dtype=torch.int32, device="cuda") def alloc(self): return self.req_list.alloc() diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index f2c900af09..c640a1152c 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -36,7 +36,7 @@ def test_model_inference(args): "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, "mem_fraction": args.mem_fraction, - "max_req_num": 2048, + "max_req_num": args.static_max_req_num, "batch_max_tokens": 1024, "run_mode": "normal", "max_seq_length": args.max_req_total_len, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 72f06a919c..7689fc97df 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -69,7 +69,7 @@ def test_model_inference_mtp(args): "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, "mem_faction": args.mem_fraction, - "max_req_num": 2000, + "max_req_num": args.static_max_req_num, "batch_max_tokens": 2048, "run_mode": "normal", "max_seq_length": args.max_req_total_len, diff --git a/test/benchmark/static_inference/test_model.py b/test/benchmark/static_inference/test_model.py index 5b3751bcc3..78b841bbeb 100644 --- a/test/benchmark/static_inference/test_model.py +++ b/test/benchmark/static_inference/test_model.py @@ -30,6 +30,12 @@ def test_model_infer(self): parser.add_argument("--batch_size", type=int, default=None, help="batch size") parser.add_argument("--input_len", type=int, default=64, help="input sequence length") parser.add_argument("--output_len", type=int, default=128, help="output sequence length") + parser.add_argument( + "--static_max_req_num", + type=int, + default=2048, + help="max_req_num used by the standalone static benchmark harness", + ) parser.add_argument( "--profile", action="store_true", From 1f69b05cd491c118edf4cdf69470e2ecbe313028 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 17 Jun 2026 16:56:19 +0800 Subject: [PATCH 18/19] Remove extra model request padding --- lightllm/server/router/manager.py | 2 +- test/router/test_model_kvargs.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 test/router/test_model_kvargs.py diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index dfb8866601..c0560419ff 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -145,7 +145,7 @@ async def wait_to_model_ready(self): "weight_dir": self.model_weightdir, "load_way": self.load_way, "max_total_token_num": self.max_total_token_num, - "max_req_num": self.args.running_max_req_size + 8, + "max_req_num": self.args.running_max_req_size, "max_seq_length": self.args.max_req_total_len + 8, # 留一点余量 "nccl_host": self.args.nccl_host, "nccl_port": self.args.nccl_port, diff --git a/test/router/test_model_kvargs.py b/test/router/test_model_kvargs.py new file mode 100644 index 0000000000..6f8f9d2e36 --- /dev/null +++ b/test/router/test_model_kvargs.py @@ -0,0 +1,17 @@ +import ast +from pathlib import Path + + +def test_model_kvargs_uses_running_max_req_size_without_extra_padding(): + source = Path("lightllm/server/router/manager.py").read_text() + module = ast.parse(source) + + for node in ast.walk(module): + if not isinstance(node, ast.Dict): + continue + for key, value in zip(node.keys, node.values): + if isinstance(key, ast.Constant) and key.value == "max_req_num": + assert ast.unparse(value) == "self.args.running_max_req_size" + return + + raise AssertionError("max_req_num kvarg was not found") From 4bd83cef87095690b03e5ac8f6c845cd762b0bdc Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 17 Jun 2026 17:02:14 +0800 Subject: [PATCH 19/19] Log linear attention state buffer memory --- lightllm/common/req_manager.py | 19 +++++++++++++++++++ test/common/test_req_manager.py | 20 ++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 test/common/test_req_manager.py diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 9be22b23bb..38a6d37727 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -23,6 +23,12 @@ REQ_NEXT_TOKEN_IDS_WIDTH = 8 +def _format_nbytes(nbytes: int) -> str: + mib = nbytes / (1024**2) + gib = nbytes / (1024**3) + return f"{mib:.2f} MiB ({gib:.2f} GiB)" + + def assert_mtp_step_within_next_token_ids_width(mtp_step: int) -> None: assert mtp_step <= REQ_NEXT_TOKEN_IDS_WIDTH - 1, ( f"mtp_step={mtp_step} exceeds {REQ_NEXT_TOKEN_IDS_WIDTH - 1}; " @@ -270,6 +276,19 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con layer_num=self.linear_config.linear_layer_num, device="cuda", ) + conv_buffer = self.req_to_conv_state.buffer + ssm_buffer = self.req_to_ssm_state.buffer + conv_nbytes = conv_buffer.numel() * conv_buffer.element_size() + ssm_nbytes = ssm_buffer.numel() * ssm_buffer.element_size() + logger.info( + "linear att gpu state buffers: " + f"max_request_num={max_request_num}, hold_request_id={self.HOLD_REQUEST_ID}, mtp_step={self.mtp_step}, " + f"conv_state shape={tuple(conv_buffer.shape)}, dtype={conv_buffer.dtype}, " + f"nbytes={conv_nbytes}, memory={_format_nbytes(conv_nbytes)}; " + f"ssm_state shape={tuple(ssm_buffer.shape)}, dtype={ssm_buffer.dtype}, " + f"nbytes={ssm_nbytes}, memory={_format_nbytes(ssm_nbytes)}; " + f"total memory={_format_nbytes(conv_nbytes + ssm_nbytes)}" + ) return def init_linear_att_state(self, req: "InferReq"): diff --git a/test/common/test_req_manager.py b/test/common/test_req_manager.py new file mode 100644 index 0000000000..7ab9062d7c --- /dev/null +++ b/test/common/test_req_manager.py @@ -0,0 +1,20 @@ +import ast +from pathlib import Path + + +def test_linear_att_state_buffer_log_reports_shape_and_memory(): + source = Path("lightllm/common/req_manager.py").read_text() + module = ast.parse(source) + + class_node = next( + node for node in module.body if isinstance(node, ast.ClassDef) and node.name == "ReqManagerForMamba" + ) + init_node = next(node for node in class_node.body if isinstance(node, ast.FunctionDef) and node.name == "__init__") + init_source = ast.unparse(init_node) + + assert "logger.info" in init_source + assert "conv_state shape=" in init_source + assert "ssm_state shape=" in init_source + assert "total memory=" in init_source + assert "_format_nbytes(conv_nbytes)" in init_source + assert "_format_nbytes(ssm_nbytes)" in init_source