diff --git a/.gitignore b/.gitignore index 9b69e2eb4c..67a0db0b4c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,7 @@ dist .vscode tmp/ requirements-musa.txt -logs/ \ No newline at end of file +logs/ + +/benchmark/ +artifacts/ diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 594e81a9b4..08d15294bd 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -100,7 +100,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list) -def get_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend: +def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type backend_str = args.llm_decode_att_backend[index] @@ -120,7 +120,7 @@ def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list) -def get_mla_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend: +def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type backend_str = args.llm_decode_att_backend[index] diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..48deabaf4b 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -3,12 +3,16 @@ from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.sgl_utils import flash_attn_with_kvcache, get_scheduler_metadata from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor +_DECODE_MAX_NUM_SPLITS = 32 +_DECODE_PACK_GQA = True + + class Fa3AttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) @@ -119,6 +123,7 @@ class Fa3DecodeAttState(BaseDecodeAttState): cu_seqlens_k: torch.Tensor = None page_table: torch.Tensor = None b_att_seq_len: torch.Tensor = None + scheduler_metadata: torch.Tensor = None # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 decode_max_q_seq_len: int = None @@ -179,8 +184,33 @@ def init_state(self): ) self.b_att_seq_len = self.infer_state.b_seq_len self.decode_max_q_seq_len = 1 + self._init_scheduler_metadata() return + def _init_scheduler_metadata(self): + if get_scheduler_metadata is None: + self.scheduler_metadata = None + return + + model = self.backend.model + self.scheduler_metadata = get_scheduler_metadata( + batch_size=self.b_att_seq_len.shape[0], + max_seqlen_q=self.decode_max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + num_heads=model.config["num_attention_heads"] // model.tp_world_size_, + num_heads_k=model.tp_k_head_num_, + headdim=model.head_dim_, + cache_seqlens=self.b_att_seq_len, + qkv_dtype=model.data_type, + headdim_v=model.head_dim_, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + page_size=1, + causal=True, + num_splits=_DECODE_MAX_NUM_SPLITS, + pack_gqa=_DECODE_PACK_GQA, + ) + def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): super().copy_for_decode_cuda_graph(new_state) @@ -235,6 +265,9 @@ def _normal_decode_att( causal=True, window_size=window_size, softcap=0.0, + scheduler_metadata=self.scheduler_metadata, + num_splits=_DECODE_MAX_NUM_SPLITS, + pack_gqa=_DECODE_PACK_GQA, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False, diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index acbb1315fe..99fd864bf2 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -45,9 +45,12 @@ def init_state(self): torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len ) # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) def prefill_att( self, @@ -116,20 +119,19 @@ def init_state(self): super().init_state() self.backend: Fp8Fa3AttBackend = self.backend - args_mtp_step = get_env_start_args().mtp_step - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) - assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 - - device = self.infer_state.input_ids.device - batch_size = att_batch_size + batch_size = self.b_att_seq_len.shape[0] mem_manager = self.backend.model.mem_manager offline_scales: torch.Tensor = mem_manager.scales head_num = mem_manager.head_num # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) - self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + ) return @@ -180,11 +182,11 @@ def _fp8_decode_att( k_cache=cache_k, v_cache=cache_v, page_table=self.page_table, - cache_seqlens=self.infer_state.b_seq_len, + cache_seqlens=self.b_att_seq_len, cu_seqlens_q=self.cu_seqlens_q, cu_seqlens_k_new=self.cu_seqlens_k, max_seqlen_q=self.decode_max_q_seq_len, - causal=False, + causal=True, window_size=(-1, -1), softcap=0.0, q_descale=q_scale.view(self.infer_state.batch_size, k_head_num), diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2e..b5145ac932 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -6,6 +6,82 @@ from .env_utils import set_flashinfer_envs +def _fast_plan_tensor_core_decode( + decode_wrapper, + indptr, + indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + indptr_host, + kv_lens_arr_host, + max_kv_len, +): + batch_size = len(last_page_len) + if batch_size != decode_wrapper._fixed_batch_size: + raise ValueError( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} " + "mismatches the batch size set during initialization {}".format( + batch_size, decode_wrapper._fixed_batch_size + ) + ) + if len(indices) > len(decode_wrapper._paged_kv_indices_buf): + raise ValueError("The size of indices should be less than or equal to the allocated buffer") + + qo_indptr_host = getattr(decode_wrapper, "_lightllm_qo_indptr_host", None) + if qo_indptr_host is None or len(qo_indptr_host) != batch_size + 1: + from flashinfer.decode import _get_range_buf + + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + decode_wrapper._lightllm_qo_indptr_host = qo_indptr_host + + if indptr_host is None: + indptr_host = indptr.to("cpu") + if kv_lens_arr_host is None: + from flashinfer.decode import get_seq_lens + + last_page_len_host = last_page_len.to("cpu") + kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) + if max_kv_len is None: + max_kv_len = max(kv_lens_arr_host).item() + + decode_wrapper._batch_size = batch_size + decode_wrapper._num_qo_heads = num_qo_heads + decode_wrapper._num_kv_heads = num_kv_heads + decode_wrapper._block_tables = None + decode_wrapper._max_kv_len = max_kv_len + + args = [ + decode_wrapper._float_workspace_buffer, + decode_wrapper._int_workspace_buffer, + decode_wrapper._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_host, + kv_lens_arr_host, + batch_size, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + decode_wrapper.is_cuda_graph_enabled, + head_dim, + head_dim, + False, + -1, + ] + if decode_wrapper._backend == "fa2": + args.extend([-1, False, 0]) + decode_wrapper._plan_info = decode_wrapper._cached_module.plan(*args) + decode_wrapper._pos_encoding_mode = "NONE" + decode_wrapper._window_left = -1 + decode_wrapper._logits_soft_cap = 0.0 + decode_wrapper._sm_scale = None + decode_wrapper._rope_scale = None + decode_wrapper._rope_theta = None + + class FlashInferAttBackend(BaseAttBackend): def __init__(self, model): set_flashinfer_envs() @@ -25,6 +101,10 @@ def __init__(self, model): model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() ), ] + self.kv_starts_host_buffer = [ + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + torch.empty((model.graph_max_batch_size + 1,), dtype=torch.int32, device="cpu"), + ] self.q_data_type = model.data_type self.kv_data_type = model.data_type @@ -124,11 +204,11 @@ class FlashInferDecodeAttState(BaseDecodeAttState): kv_last_page_len_buffer: torch.Tensor = None kv_indices: torch.Tensor = None kv_starts: torch.Tensor = None + kv_starts_host: torch.Tensor = None + kv_seq_lens_host: torch.Tensor = None decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: FlashInferAttBackend = self.backend device = self.infer_state.input_ids.device model = self.backend.model @@ -154,8 +234,21 @@ def init_state(self): self.infer_state.b_kv_start_loc, self.infer_state.max_kv_seq_len, self.kv_indices, + zero_output=False, ) self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.infer_state.b_seq_len_cpu is not None: + self.kv_seq_lens_host = self.infer_state.b_seq_len_cpu + self.kv_starts_host = self.backend.kv_starts_host_buffer[self.infer_state.microbatch_index][ + : self.infer_state.batch_size + 1 + ] + self.kv_starts_host[0] = 0 + torch.cumsum(self.infer_state.b_seq_len_cpu, dim=0, out=self.kv_starts_host[1:]) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.backend.workspace_buffer, @@ -182,7 +275,8 @@ def init_state(self): def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): super().copy_for_decode_cuda_graph(new_state) - self.decode_wrapper.plan( + _fast_plan_tensor_core_decode( + self.decode_wrapper, new_state.kv_starts, new_state.kv_indices, new_state.kv_last_page_len_buffer, @@ -190,9 +284,9 @@ def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): new_state.backend.tp_kv_head_num, new_state.backend.head_dim, 1, - q_data_type=new_state.backend.q_data_type, - kv_data_type=new_state.backend.kv_data_type, - non_blocking=True, + new_state.kv_starts_host, + new_state.kv_seq_lens_host, + new_state.infer_state.max_kv_seq_len, ) def decode_att( diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 84b44dc45a..8689839db4 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -116,8 +116,6 @@ class MlaFlashInferDecodeAttState(BaseDecodeAttState): decode_wrapper: object = None def init_state(self): - import flashinfer - self.backend: MlaFlashInferAttBackend = self.backend model = self.backend.model device = self.infer_state.input_ids.device @@ -144,7 +142,13 @@ def init_state(self): self.infer_state.b_kv_start_loc, self.infer_state.max_kv_seq_len, self.kv_indices, + zero_output=False, ) + if self.infer_state.skip_decode_att_wrapper_init: + return + + import flashinfer + assert self.decode_wrapper is None self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a2..e6db58475b 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -4,6 +4,7 @@ import gc import copy import json +import math import torch import torch.nn.functional as F import triton @@ -314,7 +315,9 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] infer_state.b_req_idx = model_input.b_req_idx infer_state.b_seq_len = model_input.b_seq_len + infer_state.b_seq_len_cpu = model_input.b_seq_len_cpu infer_state.b_mtp_index = model_input.b_mtp_index + infer_state.b_num_accepted_tokens = model_input.b_num_accepted_tokens if model_input.is_prefill: if model_input.b_ready_cache_len is not None: infer_state.b_ready_cache_len = model_input.b_ready_cache_len @@ -352,6 +355,16 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) return infer_state + def _get_decode_padding_unit(self, model_input: ModelInput) -> int: + padding_unit = self.tp_world_size_ if self.args.enable_tpsp_mix_mode else 1 + if (not model_input.is_prefill) and self.args.mtp_step > 0: + padding_unit = math.lcm(padding_unit, self.args.mtp_step + 1) + return padding_unit + + def _get_decode_infer_batch_size(self, model_input: ModelInput) -> int: + padding_unit = self._get_decode_padding_unit(model_input) + return triton.cdiv(model_input.batch_size, padding_unit) * padding_unit + def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): if model_input.batch_size == new_batch_size: return model_input @@ -361,8 +374,27 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s padded_batch_size = new_batch_size - model_input.batch_size new_model_input = copy.copy(model_input) new_model_input.batch_size = new_batch_size - new_model_input.total_token_num += padded_batch_size * 2 - new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) + + is_mtp_grouped_decode = (not model_input.is_prefill) and self.args.mtp_step > 0 + if is_mtp_grouped_decode: + mtp_size = self.args.mtp_step + 1 + assert padded_batch_size % mtp_size == 0 + padded_req_num = padded_batch_size // mtp_size + new_model_input.total_token_num += padded_req_num * (mtp_size * (mtp_size + 3) // 2) + new_model_input.max_kv_seq_len = max(mtp_size + 1, model_input.max_kv_seq_len) + pad_seq_len = torch.arange( + 2, mtp_size + 2, dtype=new_model_input.b_seq_len.dtype, device=new_model_input.b_seq_len.device + ).repeat(padded_req_num) + new_model_input.b_seq_len = torch.cat((new_model_input.b_seq_len, pad_seq_len), dim=0) + # b_num_accepted_tokens 不再随 model_input 流转/补齐:它在 GDN 的 init_mtp_verify_extra_state + # 里按 req_first 从 req_to_accept_len gather,padding 组 req_first=HOLD(槽恒为 1)自然得 1。 + else: + new_model_input.total_token_num += padded_batch_size * 2 + new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) + new_model_input.b_seq_len = F.pad( + new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2 + ) + new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) new_model_input.b_req_idx = F.pad( new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID @@ -370,7 +402,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input.b_mtp_index = F.pad( new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 ) - new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) + if new_model_input.b_seq_len_cpu is not None: + new_model_input.b_seq_len_cpu = F.pad( + new_model_input.b_seq_len_cpu, (0, padded_batch_size), mode="constant", value=2 + ) new_model_input.mem_indexes = F.pad( new_model_input.mem_indexes, (0, padded_batch_size), @@ -549,10 +584,7 @@ def _decode( ) origin_batch_size = model_input.batch_size - if self.args.enable_tpsp_mix_mode: - infer_batch_size = triton.cdiv(model_input.batch_size, self.tp_world_size_) * self.tp_world_size_ - else: - infer_batch_size = model_input.batch_size + infer_batch_size = self._get_decode_infer_batch_size(model_input) if self.graph is not None and self.graph.can_run( batch_size=infer_batch_size, max_len_in_batch=model_input.max_kv_seq_len @@ -562,6 +594,8 @@ def _decode( model_input=model_input, new_batch_size=infer_batch_size ) infer_state = self._create_inferstate(model_input) + need_capture = self.graph.need_capture(infer_batch_size) + infer_state.skip_decode_att_wrapper_init = not need_capture copy_kv_index_to_req( self.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -571,7 +605,7 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() - if self.graph.need_capture(infer_batch_size): + if need_capture: infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: @@ -804,7 +838,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode origin_batch_size = model_input0.batch_size max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) - infer_batch_size = triton.cdiv(origin_batch_size, self.tp_world_size_) * self.tp_world_size_ + infer_batch_size = self._get_decode_infer_batch_size(model_input0) if self.graph is not None and self.graph.can_run(infer_batch_size, max_len_in_batch): infer_batch_size = self.graph.find_closest_graph_batch_size(infer_batch_size) @@ -1037,6 +1071,9 @@ def autotune_layers(self): # 控制autotune的层数,用于适配不同模型 return self.config.get("first_k_dense_replace", 0) + 1 + def _autotune_extra_warmup(self): + return + @final @torch.no_grad() @post_empty_cache @@ -1106,6 +1143,11 @@ def _autotune_warmup(self): self.mem_manager.free_all() gc.collect() torch.cuda.empty_cache() + try: + self._autotune_extra_warmup() + except Exception as e: + logger.warning(f"extra autotune warmup failed: {str(e)}") + logger.exception(str(e)) self.layers_num = layer_num_bak torch.distributed.barrier() Autotuner.end_autotune_warmup() @@ -1171,12 +1213,7 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} - is_mtp_draft_model = ( - "Deepseek3MTPModel" in str(self.__class__) - or "Qwen3MOEMTPModel" in str(self.__class__) - or "MistralMTPModel" in str(self.__class__) - or "Glm4MoeLiteMTPModel" in str(self.__class__) - ) + is_mtp_draft_model = getattr(self, "is_mtp_draft_model", False) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 1795ff9a82..b96d0ae512 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -42,6 +42,7 @@ class ModelInput: multimodal_params: list = None # cpu 变量 mem_indexes_cpu: torch.Tensor = None + b_seq_len_cpu: torch.Tensor = None # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理 # 的一些变量 b_prefill_has_output_cpu: List[bool] = None # 标记进行prefill的请求是否具有输出 @@ -53,6 +54,8 @@ class ModelInput: # 的 draft 模型的输入 mtp_draft_input_hiddens: Optional[torch.Tensor] = None + b_num_accepted_tokens: Optional[torch.Tensor] = None + def to_cuda(self): if self.input_ids is not None: self.input_ids = self.input_ids.cuda(non_blocking=True) @@ -64,8 +67,12 @@ def to_cuda(self): assert self.is_prefill self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) + if not self.b_seq_len.is_cuda: + self.b_seq_len_cpu = self.b_seq_len self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) + if self.b_num_accepted_tokens is not None: + self.b_num_accepted_tokens = self.b_num_accepted_tokens.cuda(non_blocking=True) if self.b_ready_cache_len is not None: self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True) if self.b_prefill_start_loc is not None: diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..001e6299e8 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -2,6 +2,7 @@ import torch import copy import bisect +import math import triton from typing import Optional from lightllm.utils.log_utils import init_logger @@ -27,33 +28,43 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap - # gen cuda graph batch_sizes - # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] - # and [graph_split_batch_size + graph_grow_step_size, - # if the mtp_step is not 0, then the batch_sizes will be multiply of (mtp_step + 1) + # With MTP enabled, both the main-model verify forward and the draft (MTP) forward run over + # the (mtp_step+1)-expanded decode layout, so every decode batch size is a multiple of + # (mtp_step+1) and there is a single decode layout — the graph is keyed by batch size alone. + batch_size_multiple = self.mtp_step + 1 if self.mtp_step > 0 else 1 + self.cuda_graph_batch_sizes = self._build_cuda_graph_batch_sizes(batch_size_multiple=batch_size_multiple) + logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") - graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1) - graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1) + def _build_cuda_graph_batch_sizes(self, batch_size_multiple: int): + graph_split_batch_size = self.args.graph_split_batch_size * batch_size_multiple + graph_grow_step_size = self.args.graph_grow_step_size * batch_size_multiple - batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)] - for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): + batch_sizes = [i * batch_size_multiple for i in range(1, self.args.graph_split_batch_size + 1)] + for _batch_size in range( + graph_split_batch_size + graph_grow_step_size, + self.max_batch_size, + graph_grow_step_size, + ): batch_sizes.append(_batch_size) - batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size])) - batch_sizes.append(max_batch_size) + batch_sizes = list(set([e for e in batch_sizes if e < self.max_batch_size])) + batch_sizes.append(self.max_batch_size) batch_sizes.sort() if self.args.enable_tpsp_mix_mode: - batch_sizes = [triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in batch_sizes] + padding_unit = math.lcm(self.tp_world_size, batch_size_multiple) + batch_sizes = [triton.cdiv(e, padding_unit) * padding_unit for e in batch_sizes] batch_sizes = list(set(batch_sizes)) batch_sizes.sort() - self.cuda_graph_batch_sizes = batch_sizes assert batch_sizes[-1] == self.max_batch_size - logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") + return batch_sizes def can_run(self, batch_size, max_len_in_batch): return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch + def _decode_graph_key(self, infer_state: InferStateInfo): + return infer_state.input_ids.shape[0] + def need_capture(self, batch_size): find_batch_size = self.find_closest_graph_batch_size(batch_size) if find_batch_size is not None: @@ -69,6 +80,54 @@ def find_closest_graph_batch_size(self, batch_size): else: return None + def _build_warmup_decode_model_input( + self, + model, + batch_size: int, + device: str = "cuda", + ) -> ModelInput: + mtp_size = self.mtp_step + 1 + input_ids = torch.ones(batch_size, dtype=torch.int32, device=device) + mem_indexes = model.mem_manager.alloc(batch_size).to(device) + b_req_idx = torch.full( + (batch_size,), + fill_value=model.req_manager.HOLD_REQUEST_ID, + dtype=torch.int32, + device=device, + ) + + b_num_accepted_tokens = None + if self.mtp_step > 0: + assert batch_size % mtp_size == 0, "MTP decode CUDA graph batch size must be a multiple of mtp_step + 1" + real_batch_size = batch_size // mtp_size + b_mtp_index = torch.arange(mtp_size, dtype=torch.int32, device=device).repeat(real_batch_size) + b_seq_len = torch.arange(2, mtp_size + 2, dtype=torch.int32, device=device).repeat(real_batch_size) + # b_num_accepted_tokens 不再随 model_input 传入:GDN 的 init_mtp_verify_extra_state 会按 + # req_first(全 HOLD,槽恒为 1) gather,warmup/capture 自然得到全 1,等价旧的 torch.ones。 + total_token_num = real_batch_size * (mtp_size * (mtp_size + 3) // 2) + else: + seq_len = 2 + total_token_num = batch_size * seq_len + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device=device) + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device=device) + b_seq_len.fill_(seq_len) + + return ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_q_seq_len=1, + max_kv_seq_len=self.graph_max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_mtp_index=b_mtp_index, + b_num_accepted_tokens=b_num_accepted_tokens, + is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], + **model._gen_special_model_input(batch_size), + ) + def _capture_decode(self, decode_func, infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids @@ -96,7 +155,11 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): with torch.cuda.graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) - self.graph[batch_size] = (graph_obj, infer_state, model_output) + self.graph[self._decode_graph_key(infer_state)] = ( + graph_obj, + infer_state, + model_output, + ) graph_obj.replay() return model_output @@ -130,7 +193,7 @@ def _capture_decode_overlap( with torch.cuda.graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) - self.graph[batch_size] = ( + self.graph[self._decode_graph_key(infer_state)] = ( graph_obj, infer_state, infer_state1, @@ -157,8 +220,7 @@ def capture_decode( return self._capture_decode(decode_func, infer_state) def _replay(self, infer_state: InferStateInfo): - batch_size = infer_state.input_ids.shape[0] - graph_obj, graph_infer_state, graph_output = self.graph[batch_size] + graph_obj, graph_infer_state, graph_output = self.graph[self._decode_graph_key(infer_state)] graph_infer_state.copy_for_cuda_graph(infer_state) graph_obj.replay() return graph_output @@ -168,14 +230,13 @@ def _replay_overlap( infer_state: InferStateInfo, infer_state1: InferStateInfo, ): - batch_size = infer_state.input_ids.shape[0] ( graph_obj, graph_infer_state, graph_infer_state1, graph_model_output, graph_model_output1, - ) = self.graph[batch_size] + ) = self.graph[self._decode_graph_key(infer_state)] graph_infer_state.copy_for_cuda_graph(infer_state) graph_infer_state1.copy_for_cuda_graph(infer_state1) graph_obj.replay() @@ -198,38 +259,9 @@ def warmup(self, model): # decode cuda graph init for batch_size in self.cuda_graph_batch_sizes[::-1]: - seq_len = 2 - total_token_num = batch_size * seq_len - max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") - b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - - model_input = ModelInput( - batch_size=batch_size, - total_token_num=total_token_num, - max_q_seq_len=1, - max_kv_seq_len=max_len_in_batch, - input_ids=input_ids, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - b_mtp_index=b_mtp_index, - is_prefill=False, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], - **model._gen_special_model_input(batch_size), - ) + model_input = self._build_warmup_decode_model_input(model, batch_size) model_output: ModelOutput = model.forward(model_input) del model_output - del input_ids - del mem_indexes - del b_req_idx - del b_seq_len model.mem_manager.free_all() model.req_manager.free_all() @@ -256,32 +288,7 @@ def warmup_overlap(self, model): decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph - seq_len = 2 - total_token_num = batch_size * seq_len - max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") - b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - - micro_batch = ModelInput( - is_prefill=False, - batch_size=batch_size, - total_token_num=total_token_num, - max_q_seq_len=1, - max_kv_seq_len=max_len_in_batch, - input_ids=input_ids, - b_mtp_index=b_mtp_index, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], - **model._gen_special_model_input(batch_size), - ) + micro_batch = self._build_warmup_decode_model_input(model, batch_size) decode_batches.append(micro_batch) del micro_batch diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c835..954d74c885 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -39,7 +39,10 @@ def __init__(self): self.b_mtp_index: torch.Tensor = None + self.b_num_accepted_tokens: torch.Tensor = None + self.b_seq_len: torch.Tensor = None + self.b_seq_len_cpu: torch.Tensor = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None # prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度 @@ -56,6 +59,7 @@ def __init__(self): self.return_all_prompt_logics: bool = False self.multimodal_params: dict = None self.is_cuda_graph: bool = False # 标记是否是cuda graph的捕获推理 + self.skip_decode_att_wrapper_init: bool = False self.dist_group: CustomProcessGroup = None # 在microbatch overlap的运行模式下,用于标记当前 microbatch 的 index 序号 diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index f0cc129c09..061658bc07 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -53,6 +53,18 @@ def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tens def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") + def _add_residual_ffn_norm(self, input_embdings, residual, infer_state: InferStateInfo, layer_weight): + add_rmsnorm = getattr(layer_weight.ffn_norm_weight_, "add_rmsnorm", None) + if add_rmsnorm is None: + input_embdings.add_(residual.view(-1, self.embed_dim_)) + return self._ffn_norm(input_embdings, infer_state, layer_weight) + return add_rmsnorm( + input=input_embdings, + residual=residual.view(-1, self.embed_dim_), + eps=self.eps_, + alloc_func=self.alloc_tensor, + ) + def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) @@ -89,10 +101,9 @@ def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, l def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): input1 = self._att_norm(input_embdings, infer_state, layer_weight) o = self.token_attention_forward(input1, infer_state, layer_weight) - input_embdings.add_(o.view(-1, self.embed_dim_)) + input1 = self._add_residual_ffn_norm(input_embdings, o, infer_state, layer_weight) o = None - input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..d9a77b39a5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -134,6 +134,8 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -150,6 +152,8 @@ def experts( num_expert_group=num_expert_group, is_prefill=is_prefill, per_expert_scale=self.per_expert_scale, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index dd6f9a6880..b54b03ee05 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -63,5 +63,7 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 4d4614c007..bc0e86d7eb 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -76,6 +76,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): output = fused_experts( hidden_states=input_tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 0094b09b1c..417d001c72 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -30,6 +30,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..fdda2b2139 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -94,6 +94,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -111,6 +113,8 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return input_tensor @@ -129,6 +133,8 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -150,5 +156,7 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return output diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index ee9d1923c3..33b59ac5a9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -2,7 +2,7 @@ from typing import Optional, Dict from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size -from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import add_rmsnorm_forward, rmsnorm_forward from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward @@ -71,6 +71,21 @@ def __call__( ) -> torch.Tensor: return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def add_rmsnorm( + self, + input: torch.Tensor, + residual: torch.Tensor, + eps: float, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + input.ndim in [2, 3] and residual.ndim in [2, 3] and self.weight.ndim == 1 + ), f"input.ndim: {input.ndim}, residual.ndim: {residual.ndim}, weight.ndim: {self.weight.ndim}" + if out is None: + out = alloc_func(input.shape, dtype=input.dtype, device=input.device) + return add_rmsnorm_forward(x=input, residual=residual, weight=self.weight, eps=eps, out=out) + class GatedRMSNormWeight(RMSNormWeight): def _triton_forward( diff --git a/lightllm/common/basemodel/mtp_verify_extra_state.py b/lightllm/common/basemodel/mtp_verify_extra_state.py new file mode 100644 index 0000000000..95bfce9388 --- /dev/null +++ b/lightllm/common/basemodel/mtp_verify_extra_state.py @@ -0,0 +1,26 @@ +import torch + +from lightllm.utils.envs_utils import get_env_start_args + + +def init_mtp_verify_extra_state(self, model): + self.b_att_seq_len = self.b_seq_len + mtp_step = get_env_start_args().mtp_step + self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + self.b_conv_buffer_idx = self.b_req_idx + self.is_mtp_verify = (mtp_step > 0) and (not self.is_prefill) and (self.b_mtp_index is not None) + self.b_gdn_verify_cu_seqlens = None + self.b_ssm_index_rows = None + if self.is_mtp_verify: + step = mtp_step + 1 + n_real = self.b_req_idx.shape[0] // step + self.b_gdn_verify_cu_seqlens = torch.arange( + 0, (n_real + 1) * step, step, dtype=torch.int32, device=self.b_req_idx.device + ) + req_first = self.b_req_idx.view(n_real, step)[:, 0] + base = (req_first * step).view(n_real, 1) + self.b_ssm_index_rows = base + torch.arange(step, device=base.device, dtype=base.dtype).view(1, step) + assert self.b_ssm_index_rows.shape == (n_real, step) + self.b_conv_buffer_idx = req_first + self.b_num_accepted_tokens = model.req_manager.req_to_accept_len[req_first] + return diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 76acea25a7..95dcba9836 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -221,10 +221,17 @@ def moe_align_fused_kernel( expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] token_num, + expert_num: tl.constexpr, topk_num: tl.constexpr, BLOCK_SIZE: tl.constexpr, + ZERO_EXPERT_TOKEN_NUM: tl.constexpr, + BLOCK_EXPERT: tl.constexpr, ): token_block = tl.program_id(0) + if ZERO_EXPERT_TOKEN_NUM: + expert_offs = tl.arange(0, BLOCK_EXPERT) + tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) + offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offs < token_num * topk_num @@ -282,6 +289,8 @@ def moe_align_fused( run_config = {} BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) num_warps = run_config.get("num_warps", 4) + expert_num = expert_token_num.shape[0] + zero_expert_token_num = token_num * topk_num <= BLOCK_SIZE grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) moe_align_fused_kernel[grid]( @@ -291,8 +300,11 @@ def moe_align_fused( expert_to_weight, expert_token_num, token_num, + expert_num, topk_num, BLOCK_SIZE=BLOCK_SIZE, + ZERO_EXPERT_TOKEN_NUM=zero_expert_token_num, + BLOCK_EXPERT=triton.next_power_of_2(expert_num), num_warps=num_warps, ) return expert_to_token_index, expert_to_weight, expert_token_num @@ -911,6 +923,8 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -957,7 +971,12 @@ def fused_experts_impl( expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") - expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda") + expert_token_count_in_align_kernel = topk_num * tokens_in_chunk <= 128 + expert_to_token_num = ( + torch.empty((E,), dtype=torch.int32, device="cuda") + if expert_token_count_in_align_kernel + else torch.zeros((E,), dtype=torch.int32, device="cuda") + ) moe_align_fused( expert_to_token_index=expert_to_tokens, expert_to_weight=expert_to_weights, @@ -1011,8 +1030,15 @@ def fused_experts_impl( bias=w2_bias, ) + has_shared_gate = shared_expert_out is not None moe_sum_reduce( - intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx] + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + shared=None if not has_shared_gate else shared_expert_out[begin_chunk_idx:end_chunk_idx], + gate=None if not has_shared_gate else shared_expert_gate[begin_chunk_idx:end_chunk_idx], + run_config=( + None if not has_shared_gate else {"BLOCK_M": 1, "BLOCK_DIM": 128, "NUM_STAGE": 1, "num_warps": 2} + ), ) return out_hidden_states @@ -1035,6 +1061,8 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: fused_experts_impl( hidden_states, @@ -1054,6 +1082,8 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) @@ -1075,6 +1105,8 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: pass @@ -1105,6 +1137,8 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: return fused_experts_impl( hidden_states, @@ -1124,6 +1158,8 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) @@ -1145,6 +1181,8 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: return torch.empty_like(hidden_states) @@ -1176,6 +1214,8 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): if inplace: torch.ops.lightllm.inplace_fused_experts_impl( @@ -1195,6 +1235,8 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return hidden_states else: @@ -1215,4 +1257,6 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index e16351eec8..4f95cca7c6 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -14,12 +14,20 @@ def _moe_sum_reduce_kernel( output_ptr, output_stride_0, output_stride_1, + shared_ptr, + shared_stride_0, + shared_stride_1, + gate_ptr, + gate_stride_0, + gate_stride_1, token_num: int, topk_num: int, hidden_dim: int, BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr, NUM_STAGE: tl.constexpr, + HAS_SHARED_GATE: tl.constexpr, + GATE_DIM: tl.constexpr, ): input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) @@ -42,12 +50,38 @@ def _moe_sum_reduce_kernel( for i in tl.range(0, topk_num, num_stages=NUM_STAGE): tmp = tl.load(input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0) accumulator += tmp + if HAS_SHARED_GATE: + shared = tl.load( + shared_ptr + token_index * shared_stride_0 + offs_dim * shared_stride_1, + mask=offs_dim < dim_end, + other=0.0, + ).to(tl.float32) + if GATE_DIM == 1: + gate = tl.load(gate_ptr + token_index * gate_stride_0).to(tl.float32) + tl.zeros( + (BLOCK_DIM,), dtype=tl.float32 + ) + else: + gate = tl.load( + gate_ptr + token_index * gate_stride_0 + offs_dim * gate_stride_1, + mask=offs_dim < dim_end, + other=0.0, + ).to(tl.float32) + gate = 1.0 / (1.0 + tl.exp(-gate)) + accumulator += shared * gate store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim tl.store(store_t_ptr, accumulator.to(input_ptr.dtype.element_ty), mask=offs_dim < dim_end) -def _get_moe_sum_reduce_static_key(input: torch.Tensor, output: torch.Tensor): - return {"topk_num": input.shape[1], "hidden_dim": input.shape[2], "out_dtype": str(output.dtype)} +def _get_moe_sum_reduce_static_key( + input: torch.Tensor, output: torch.Tensor, shared: torch.Tensor = None, gate: torch.Tensor = None +): + return { + "topk_num": input.shape[1], + "hidden_dim": input.shape[2], + "out_dtype": str(output.dtype), + "has_shared_gate": shared is not None, + "gate_dim": 0 if gate is None else gate.shape[-1], + } def _get_moe_sum_reduce_configs(): @@ -67,12 +101,20 @@ def _get_moe_sum_reduce_configs(): run_key_func=lambda input: input.shape[0], mutates_args=["output"], ) -def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = None): +def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, shared=None, gate=None, run_config: Dict = None): assert input.is_contiguous() assert output.is_contiguous() token_num, topk_num, hidden_dim = input.shape assert output.shape[0] == token_num and output.shape[1] == hidden_dim + has_shared_gate = shared is not None + if has_shared_gate: + assert gate is not None + shared = shared.view(token_num, hidden_dim) + gate = gate.view(token_num, gate.shape[-1]) + assert shared.is_contiguous() + assert gate.is_contiguous() + assert gate.shape[1] in (1, hidden_dim) if not run_config: run_config = { @@ -97,12 +139,20 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = *input.stride(), output, *output.stride(), + shared if has_shared_gate else output, + shared.stride(0) if has_shared_gate else 0, + shared.stride(1) if has_shared_gate else 0, + gate if has_shared_gate else output, + gate.stride(0) if has_shared_gate else 0, + gate.stride(1) if has_shared_gate else 0, token_num=token_num, topk_num=topk_num, hidden_dim=hidden_dim, BLOCK_M=BLOCK_M, BLOCK_DIM=BLOCK_DIM, NUM_STAGE=NUM_STAGE, + HAS_SHARED_GATE=has_shared_gate, + GATE_DIM=gate.shape[1] if has_shared_gate else 0, num_warps=num_warps, ) return diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py index d9f631cbd0..fd4c16043d 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_copy.py @@ -5,12 +5,13 @@ @triton.jit def _copy_linear_att_state_to_kv_buffer( - gpu_conv_ptr, # [linear_layer_num, size_num, xdim] + gpu_conv_ptr, # [linear_layer_num, size_num, conv_dim * gpu_widened_width] (uint8 tail) gpu_ssm_ptr, # [linear_layer_num, size_num, xxdim] - cpu_kv_conv_ptr, # [size, linear_layer_num, xdim] + cpu_kv_conv_ptr, # [size, linear_layer_num, conv_dim * width_narrow] (uint8 tail) cpu_kv_ssm_ptr, # [size, linear_layer_num, xxdim] b_req_idx, # [batch_size,] big_page_buffer_ids, # [batch_size,] + num_accepted_tokens_ptr, # [batch_size,] gpu_conv_stride_l, gpu_conv_stride_s, gpu_conv_stride_d, @@ -24,7 +25,9 @@ def _copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_stride_l, cpu_kv_ssm_stride_d, mtp_step, - gpu_conv_tail_dim, + conv_dim, # number of conv rows (the d dimension) + gpu_conv_row_bytes, # widened per-row byte length: gpu_widened_width * itemsize + conv_narrow_row_bytes, # narrow per-row byte length: width_narrow * itemsize gpu_ssm_tail_dim, BLOCK: tl.constexpr, ): @@ -40,28 +43,26 @@ def _copy_linear_att_state_to_kv_buffer( return cur_req_idx = tl.load(b_req_idx + cur_batch).to(tl.int64) - cur_state_req_idx = (cur_req_idx * (mtp_step + 1)).to(tl.int64) + accept_len = tl.load(num_accepted_tokens_ptr + cur_batch).to(tl.int64) + canonical_off = accept_len - 1 - for i in range(tl.cdiv(gpu_conv_tail_dim, BLOCK)): - gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) - mask = gpu_start_off < gpu_conv_tail_dim - conv_data = tl.load( - gpu_conv_ptr + cur_layer * gpu_conv_stride_l + cur_state_req_idx * gpu_conv_stride_s + gpu_start_off, - mask=mask, - ) - dest_conv_ptr = ( - cpu_kv_conv_ptr - + big_page_buffer_idx * cpu_kv_conv_stride_s - + cur_layer * cpu_kv_conv_stride_l - + gpu_start_off - ) - tl.store(dest_conv_ptr, conv_data, mask=mask) + conv_src_slot = cur_req_idx + conv_off_bytes = canonical_off * gpu_conv_stride_d + gpu_conv_base = gpu_conv_ptr + cur_layer * gpu_conv_stride_l + conv_src_slot * gpu_conv_stride_s + conv_off_bytes + cpu_conv_base = cpu_kv_conv_ptr + big_page_buffer_idx * cpu_kv_conv_stride_s + cur_layer * cpu_kv_conv_stride_l + for d in range(conv_dim): + for i in range(tl.cdiv(conv_narrow_row_bytes, BLOCK)): + off = i * BLOCK + tl.arange(0, BLOCK) + mask = off < conv_narrow_row_bytes + conv_data = tl.load(gpu_conv_base + d * gpu_conv_row_bytes + off, mask=mask) + tl.store(cpu_conv_base + d * cpu_kv_conv_stride_d + off, conv_data, mask=mask) + ssm_src_slot = (cur_req_idx * (mtp_step + 1) + canonical_off).to(tl.int64) for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)): gpu_start_off = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_off < gpu_ssm_tail_dim ssm_data = tl.load( - gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + cur_state_req_idx * gpu_ssm_stride_s + gpu_start_off, + gpu_ssm_ptr + cur_layer * gpu_ssm_stride_l + ssm_src_slot * gpu_ssm_stride_s + gpu_start_off, mask=mask, ) dest_ssm_ptr = ( @@ -75,32 +76,51 @@ def _copy_linear_att_state_to_kv_buffer( def copy_linear_att_state_to_kv_buffer( b_req_idx: torch.Tensor, big_page_buffer_ids: torch.Tensor, - gpu_conv_state: torch.Tensor, # [linear_layer_num, s, ...] - gpu_ssm_state: torch.Tensor, # [linear_layer_num, s, ...] - cpu_kv_conv_state: torch.Tensor, # [s, linear_layer_num, ...] - cpu_kv_ssm_state: torch.Tensor, # [s, linear_layer_num, ...] + gpu_conv_state: torch.Tensor, # [linear_layer_num, s_widened, conv_dim, gpu_widened_width] + gpu_ssm_state: torch.Tensor, # [linear_layer_num, s_block, ...] + cpu_kv_conv_state: torch.Tensor, # [size, linear_layer_num, conv_dim, width_narrow] + cpu_kv_ssm_state: torch.Tensor, # [size, linear_layer_num, ...] mtp_step: int, + b_num_accepted_tokens: torch.Tensor, # [batch_size,] per-req post-accept count (>=1) ): assert len(b_req_idx) == big_page_buffer_ids.shape[0] + assert len(b_req_idx) == b_num_accepted_tokens.shape[0] BLOCK = 4096 - gpu_conv_state = gpu_conv_state.view(gpu_conv_state.shape[0], gpu_conv_state.shape[1], -1).view(dtype=torch.uint8) - gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8) - cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], -1).view( - dtype=torch.uint8 + + assert gpu_conv_state.dim() >= 4, "gpu_conv_state must be [layer, s, conv_dim, widened_width]" + assert cpu_kv_conv_state.dim() >= 4, "cpu_kv_conv_state must be [size, layer, conv_dim, width_narrow]" + # #6: the byte snapshot hardcodes gpu_conv_stride_d=conv_itemsize, which is only valid when the + # widened-width axis is element-contiguous (stride 1). Fail fast instead of snapshotting wrong bytes. + assert gpu_conv_state.stride(3) == 1, ( + "gpu_conv_state widened-width axis must be element-contiguous (stride 1); " + "gpu_conv_stride_d=conv_itemsize assumes it" + ) + # #18: canonical_off = accept_len - 1 indexes into the widened slot; bound it to [0, mtp_step] + # (accept_len in [1, mtp_step+1]) so a stale/oversized accept-count can't slice past the slot. + assert int(b_num_accepted_tokens.min()) >= 1 and int(b_num_accepted_tokens.max()) <= mtp_step + 1, ( + f"b_num_accepted_tokens out of range [1, {mtp_step + 1}]: " + f"min={int(b_num_accepted_tokens.min())} max={int(b_num_accepted_tokens.max())}" ) + conv_itemsize = gpu_conv_state.element_size() + gpu_conv_state = gpu_conv_state.view( + gpu_conv_state.shape[0], gpu_conv_state.shape[1], gpu_conv_state.shape[2], -1 + ).view(dtype=torch.uint8) + cpu_kv_conv_state = cpu_kv_conv_state.view( + cpu_kv_conv_state.shape[0], cpu_kv_conv_state.shape[1], cpu_kv_conv_state.shape[2], -1 + ).view(dtype=torch.uint8) + + gpu_ssm_state = gpu_ssm_state.view(gpu_ssm_state.shape[0], gpu_ssm_state.shape[1], -1).view(dtype=torch.uint8) cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], cpu_kv_ssm_state.shape[1], -1).view( dtype=torch.uint8 ) - assert gpu_conv_state.shape[-1] == cpu_kv_conv_state.shape[-1] + + assert gpu_conv_state.shape[2] == cpu_kv_conv_state.shape[2], "conv_dim mismatch between gpu and cpu conv buffers" assert gpu_ssm_state.shape[-1] == cpu_kv_ssm_state.shape[-1] - assert ( - gpu_conv_state.stride(-1) - == gpu_ssm_state.stride(-1) - == cpu_kv_conv_state.stride(-1) - == cpu_kv_ssm_state.stride(-1) - ) - gpu_conv_tail_dim = gpu_conv_state.shape[-1] + conv_dim = gpu_conv_state.shape[2] + gpu_conv_row_bytes = gpu_conv_state.shape[-1] # widened per-row byte length + conv_narrow_row_bytes = cpu_kv_conv_state.shape[-1] # narrow per-row byte length + assert conv_narrow_row_bytes <= gpu_conv_row_bytes gpu_ssm_tail_dim = gpu_ssm_state.shape[-1] layer_num = gpu_conv_state.shape[0] @@ -114,9 +134,10 @@ def copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_ptr=cpu_kv_ssm_state, b_req_idx=b_req_idx, big_page_buffer_ids=big_page_buffer_ids, + num_accepted_tokens_ptr=b_num_accepted_tokens, gpu_conv_stride_l=gpu_conv_state.stride(0), gpu_conv_stride_s=gpu_conv_state.stride(1), - gpu_conv_stride_d=gpu_conv_state.stride(2), + gpu_conv_stride_d=conv_itemsize, gpu_ssm_stride_l=gpu_ssm_state.stride(0), gpu_ssm_stride_s=gpu_ssm_state.stride(1), gpu_ssm_stride_d=gpu_ssm_state.stride(2), @@ -127,7 +148,9 @@ def copy_linear_att_state_to_kv_buffer( cpu_kv_ssm_stride_l=cpu_kv_ssm_state.stride(1), cpu_kv_ssm_stride_d=cpu_kv_ssm_state.stride(2), mtp_step=mtp_step, - gpu_conv_tail_dim=gpu_conv_tail_dim, + conv_dim=conv_dim, + gpu_conv_row_bytes=gpu_conv_row_bytes, + conv_narrow_row_bytes=conv_narrow_row_bytes, gpu_ssm_tail_dim=gpu_ssm_tail_dim, BLOCK=BLOCK, ) diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py index 37b27cadb2..1251dddc33 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py @@ -193,11 +193,7 @@ def copy_kv_buffer_to_cpu_cache( cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] full_att_layer_num = gpu_kv_full_att_state.shape[-2] - assert ( - full_att_layer_num - == (linear_config.all_layer_num // linear_config.full_attention_interval) - == (linear_config.all_layer_num - linear_config.linear_layer_num) - ) + assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num() assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] @@ -428,6 +424,7 @@ def copy_cpu_cache_to_kv_buffer( cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1] full_att_layer_num = gpu_full_att_kv_state.shape[-2] + assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num() assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index 2d70a68c05..bdd59c65e3 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -149,35 +149,48 @@ def mtp_scatter_next_token_ids( @triton.jit -def _fwd_kernel_gen_b_req_mtp_start_loc( - b_mtp_index, +def _fwd_kernel_scatter_accept_len( + req_to_accept_len, b_req_mtp_start_loc, - num_reqs: tl.constexpr, - batch_size: tl.constexpr, - BLOCK_SIZE: tl.constexpr, + b_req_idx, + mtp_accept_len, ): - offset = tl.arange(0, BLOCK_SIZE) - cur_mtp_index = tl.load(b_mtp_index + offset, mask=offset < batch_size, other=-1) - non_zero_mask = tl.where(cur_mtp_index == 0, 1, 0) # 1 0 1 0 0 - output_offset = tl.cumsum(non_zero_mask) - 1 - tl.store(b_req_mtp_start_loc + output_offset, offset, mask=non_zero_mask == 1) + cur_index = tl.program_id(0) + req_start_loc = tl.load(b_req_mtp_start_loc + cur_index) + cur_req_idx = tl.load(b_req_idx + req_start_loc) + accept_len = tl.load(mtp_accept_len + cur_index) + tl.store(req_to_accept_len + cur_req_idx, accept_len) return -def gen_b_req_mtp_start_loc(b_mtp_index: torch.Tensor, num_reqs: int): - b_req_mtp_start_loc = torch.empty((num_reqs,), dtype=torch.int32, device=b_mtp_index.device) - BLOCK_SIZE = triton.next_power_of_2(b_mtp_index.shape[0]) - batch_size = b_mtp_index.shape[0] - grid = (1,) - _fwd_kernel_gen_b_req_mtp_start_loc[grid]( - b_mtp_index=b_mtp_index, +def scatter_mtp_accept_len( + req_to_accept_len: torch.Tensor, + b_req_mtp_start_loc: torch.Tensor, + b_req_idx: torch.Tensor, + mtp_accept_len: torch.Tensor, +): + """ + 将本步每个真实请求(组首)的 accept 数量写入 GPU 常驻的 req_to_accept_len[req_idx]。 + 融合 `req_to_accept_len[b_req_idx[b_req_mtp_start_loc]] = mtp_accept_len` 的 gather+scatter + 为单次 launch、无中间张量。每个 program 处理一个真实请求。 + Args: + req_to_accept_len: (max_req_num + 1,) + b_req_mtp_start_loc: (num_reqs,) 每组首行在 batch 中的偏移 + b_req_idx: (batch_size,) grouped 布局的 req_idx(组首即该请求的 req_idx) + mtp_accept_len: (num_reqs,) + """ + num_reqs = mtp_accept_len.shape[0] + if num_reqs == 0: + return + grid = (num_reqs,) + _fwd_kernel_scatter_accept_len[grid]( + req_to_accept_len=req_to_accept_len, b_req_mtp_start_loc=b_req_mtp_start_loc, - num_reqs=num_reqs, - batch_size=batch_size, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=8, + b_req_idx=b_req_idx, + mtp_accept_len=mtp_accept_len, + num_warps=1, + num_stages=1, ) - return b_req_mtp_start_loc def test_mtp_verify(): @@ -201,13 +214,5 @@ def test_mtp_verify(): print(accepted_index) -def test_gen_b_req_mtp_start_loc(): - b_mtp_index = torch.tensor([0, 1, 0, 1, 2], dtype=torch.int32, device="cuda") - gt_output = torch.where(b_mtp_index == 0)[0] - b_req_mtp_start_loc = gen_b_req_mtp_start_loc(b_mtp_index, 2) - print(b_req_mtp_start_loc, gt_output) - - if __name__ == "__main__": test_mtp_verify() - # test_gen_b_req_mtp_start_loc() diff --git a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py index 89db5e00cb..c62c5eb5d2 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py @@ -16,7 +16,6 @@ def gated_rmsnorm_forward_kernel( W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch (required, not optional) - Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_z_row, @@ -33,7 +32,6 @@ def gated_rmsnorm_forward_kernel( X += row * stride_x_row + group * N Y += row * stride_y_row + group * N Z += row * stride_z_row + group * N - Rstd += group * M W += group * N if HAS_BIAS: B += group * N @@ -47,7 +45,6 @@ def gated_rmsnorm_forward_kernel( xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) @@ -128,9 +125,6 @@ def gated_rmsnorm_forward( else: out = torch.empty_like(x) assert out.stride(-1) == 1 - # For RMS norm, we still need rstd for the kernel - rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) - # Default heuristic when autotune is disabled or no config provided if not run_config: # Less than 64KB per feature: enqueue fused kernel @@ -160,7 +154,6 @@ def gated_rmsnorm_forward( weight, bias, z, - rstd, x.stride(0), out.stride(0), z.stride(0), diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index 8dc8558922..79f57ba051 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -3,6 +3,7 @@ import triton import triton.language as tl import os +from lightllm.common.triton_utils.autotuner import autotune rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8")) @@ -48,6 +49,51 @@ def _rms_norm_fwd_fused( tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) +@triton.jit +def _add_rms_norm_fwd_fused( + X, + R, + Y, + W, + x_stride0, + x_stride1, + r_stride0, + r_stride1, + y_stride0, + y_stride1, + N, + eps, + HAS_WEIGHT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + X += row * x_stride0 + R += row * r_stride0 + Y += row * y_stride0 + + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32) + x = x + r + tl.store(X + cols * x_stride1, x.to(X.dtype.element_ty), mask=mask) + _var += x * x + + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + y = x * rstd + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + y *= w + tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) + + def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None): # allocate output y = torch.empty_like(x) if out is None else out @@ -60,7 +106,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() + MAX_FUSED_SIZE = 65536 // x_arg.element_size() BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) # print("BLOCK_SIZE:", BLOCK_SIZE) if N > BLOCK_SIZE: @@ -86,6 +132,77 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) return y +def _get_add_rmsnorm_configs(): + return [{"num_warps": nw} for nw in [4, 8, 16]] + + +def _get_add_rmsnorm_static_key(x_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor): + return { + "x_dtype": str(x_arg.dtype), + "out_dtype": str(y_arg.dtype), + "weight_dtype": "none" if weight is None else str(weight.dtype), + "N": x_arg.shape[1], + "has_weight": weight is not None, + } + + +@autotune( + kernel_name="add_rmsnorm_forward:v1", + configs_gen_func=_get_add_rmsnorm_configs, + static_key_func=_get_add_rmsnorm_static_key, + run_key_func=lambda x_arg: x_arg.shape[0], + mutates_args=["x_arg", "y_arg"], +) +def _add_rmsnorm_forward( + x_arg: torch.Tensor, + residual_arg: torch.Tensor, + y_arg: torch.Tensor, + weight: torch.Tensor, + eps: float, + run_config: dict = None, +): + M, N = x_arg.shape + MAX_FUSED_SIZE = 65536 // x_arg.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + if BLOCK_SIZE > 16384: + BLOCK_SIZE = 16384 + if not run_config: + run_config = {"num_warps": rmsnorm_num_warps} + _add_rms_norm_fwd_fused[(M,)]( + x_arg, + residual_arg, + y_arg, + weight, + x_arg.stride(0), + x_arg.stride(1), + residual_arg.stride(0), + residual_arg.stride(1), + y_arg.stride(0), + y_arg.stride(1), + N, + eps, + HAS_WEIGHT=weight is not None, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=run_config["num_warps"], + ) + return y_arg + + +def add_rmsnorm_forward(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float, out=None): + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, x.shape[-1]) + residual_arg = residual.view(-1, x.shape[-1]) + y_arg = y.view(-1, x.shape[-1]) + assert x_arg.shape == residual_arg.shape == y_arg.shape + if weight is not None: + assert x_arg.shape[-1] == weight.shape[0] + assert y.data_ptr() == y_arg.data_ptr() + _add_rmsnorm_forward(x_arg, residual_arg, y_arg, weight, eps) + return y + + def torch_rms_norm(x, weight, eps): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819e..57a1c4d0a3 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -34,10 +34,11 @@ def _fwd_kernel_repack_kv_index( @torch.no_grad() -def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): +def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index, zero_output: bool = True): batch_size = req_index.shape[0] - # flashinfer requires out_kv_index to be zeroed before use - out_kv_index.zero_() + # Some flashinfer callers need zero-filled padding outside the valid indptr range. + if zero_output: + out_kv_index.zero_() BLOCK = 64 grid = ( batch_size, diff --git a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py index 109e813220..e3ae9493c7 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py +++ b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py @@ -24,6 +24,16 @@ def __init__(self, mem_manager): super().__init__(mem_manager) self.linear_config = LinearAttCacheConfig.load_from_args() + @staticmethod + def _get_persisted_full_att_layer_num(mem_manager) -> int: + persisted_full_att = getattr(mem_manager, "persisted_full_att_layer_num", None) + if persisted_full_att is None: + main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0]) + draft_full_att = getattr(mem_manager, "draft_full_att_layers", 0) + persisted_full_att = main_full_att + draft_full_att + assert 0 < persisted_full_att <= mem_manager.kv_buffer.shape[0] + return int(persisted_full_att) + def load_cpu_cache_to_gpu( self, mem_indexes: torch.Tensor, @@ -76,11 +86,14 @@ def load_cpu_cache_to_gpu( copy_cpu_cache_to_kv_buffer, ) + # Restore the persisted full-attn slice: main slots followed by MTP draft slots. + persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager) + copy_cpu_cache_to_kv_buffer( mem_indexes=mem_indexes, big_page_buffer_ids=big_page_buffer_ids_gpu, page_indexes=page_indexes, - gpu_full_att_kv_state=mem_manager.kv_buffer, + gpu_full_att_kv_state=mem_manager.kv_buffer[:persisted_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, @@ -169,12 +182,15 @@ def offload_gpu_kv_to_cpu_cache( copy_kv_buffer_to_cpu_cache, ) + # Persist the full-attn slice used for prefix reuse: main slots followed by MTP draft slots. + persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager) + copy_kv_buffer_to_cpu_cache( mem_indexes=mem_indexes, page_indexes=page_indexes, page_readies=page_readies, big_page_buffer_ids=big_page_buffer_ids_gpu, - gpu_kv_full_att_state=mem_manager.kv_buffer, + gpu_kv_full_att_state=mem_manager.kv_buffer[:persisted_full_att], cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer, cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor, diff --git a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py index c7ce9d96ba..566ce5ea3f 100644 --- a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py @@ -208,9 +208,9 @@ def write_req_to_page( dp_mems: List["Qwen3NextMemManager"], ): conv_page, ssm_page = self.view_page_to_linear_att_state(page_index) - req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1) + conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx) for tp_index, mem in enumerate(dp_mems): - self._write_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page) + self._write_one_rank(mem, tp_index, conv_req_idx, ssm_req_idx, conv_page, ssm_page) return def read_page_to_req( @@ -220,21 +220,27 @@ def read_page_to_req( dp_mems: List["Qwen3NextMemManager"], ): conv_page, ssm_page = self.view_page_to_linear_att_state(page_index) - req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1) + conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx) for tp_index, mem in enumerate(dp_mems): - self._read_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page) + self._read_one_rank(mem, tp_index, conv_req_idx, ssm_req_idx, conv_page, ssm_page) return + def _get_req_state_indexes(self, req_idx: int): + mtp_size = get_env_start_args().mtp_step + 1 + # Conv is one widened slot per request; SSM keeps the historical S+1 block layout. + return req_idx, req_idx * mtp_size + def _write_one_rank( self, mem: "Qwen3NextMemManager", tp_index: int, - req_buffer_idx: int, + conv_req_idx: int, + ssm_req_idx: int, conv_page: torch.Tensor, ssm_page: torch.Tensor, ): - conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...] - ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...] + conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]] + ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...] self._copy_conv_state_to_page(conv_state, conv_page, mem, tp_index) self._copy_ssm_state_to_page(ssm_state, ssm_page, mem, tp_index) return @@ -408,12 +414,13 @@ def _read_one_rank( self, mem: "Qwen3NextMemManager", tp_index: int, - req_buffer_idx: int, + conv_req_idx: int, + ssm_req_idx: int, conv_page: torch.Tensor, ssm_page: torch.Tensor, ): - conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...] - ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...] + conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]] + ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...] self._copy_page_to_conv_state(conv_page, conv_state, mem, tp_index) self._copy_page_to_ssm_state(ssm_page, ssm_state, mem, tp_index) return diff --git a/lightllm/common/linear_att_cache_manager/config_objs.py b/lightllm/common/linear_att_cache_manager/config_objs.py index bc39067069..f48b9865a3 100644 --- a/lightllm/common/linear_att_cache_manager/config_objs.py +++ b/lightllm/common/linear_att_cache_manager/config_objs.py @@ -8,6 +8,16 @@ logger = init_logger(__name__) +def get_mtp_draft_full_att_layer_num(args) -> int: + # mtp_mode -> draft model 增加的 full-att KV 层数(与 envs_utils.get_added_mtp_kv_layer_num 同口径)。 + mtp_mode = getattr(args, "mtp_mode", None) + if mtp_mode == "eagle_with_att": + return 1 + if mtp_mode == "vanilla_with_att": + return getattr(args, "mtp_step", 0) + return 0 + + @dataclasses.dataclass class LinearAttCacheConfig: tp_world_size: int @@ -30,6 +40,7 @@ class LinearAttCacheConfig: ssm_state_dtype: torch.dtype full_attention_interval: int all_layer_num: int # 包括 linear att 和 full att 的层加起来的层数 + draft_full_att_layer_num: int = 0 def get_conv_dim(self): # 第一项对应q的参数,第二项对应k的参数,第三项对应v的参数 @@ -41,9 +52,25 @@ def get_conv_dim(self): + self.head_linear_v_dim * self.num_linear_v_heads ) - def get_conv_state_shape(self): + def get_main_full_att_layer_num(self): + main_full_att_layer_num = self.all_layer_num - self.linear_layer_num + assert main_full_att_layer_num == self.all_layer_num // self.full_attention_interval + return main_full_att_layer_num + + def get_persisted_full_att_layer_num(self): + return self.get_main_full_att_layer_num() + self.draft_full_att_layer_num + + def get_persisted_conv_state_shape(self): + # NARROW shape used for the CPU/disk persisted page and ALL byte math. + # Persisted state is always the committed (narrow) sliding window. return (self.get_conv_dim(), self.conv_kernel_size - 1) + def get_gpu_conv_state_shape(self, mtp_step: int): + # WIDENED working shape for the GPU buffer: holds the tentatively + # rolled-in S speculative tokens before acceptance. width-1 + S, where + # S = mtp_step (a verify step has seqlen=S+1 -> width-1+(seqlen-1)). + return (self.get_conv_dim(), (self.conv_kernel_size - 1) + mtp_step) + def get_ssm_state_shape(self): return (self.num_linear_v_heads, self.head_linear_k_dim, self.head_linear_v_dim) @@ -66,7 +93,7 @@ def get_cpu_cache_full_att_bytes(self): ) assert big_page_token_num == get_env_start_args().cpu_cache_token_page_size full_att_bytes = 2 * self.full_att_all_num_kv_heads * self.full_att_head_dim * self.full_att_dtype.itemsize - a = full_att_bytes * (self.all_layer_num - self.linear_layer_num) * big_page_token_num + a = full_att_bytes * self.get_persisted_full_att_layer_num() * big_page_token_num return a def get_cpu_cache_conv_bytes(self): @@ -113,4 +140,5 @@ def load_from_args() -> "LinearAttCacheConfig": ssm_state_dtype=get_torch_dtype(args.linear_att_ssm_data_type), full_attention_interval=llm_config["full_attention_interval"], all_layer_num=n_layer, + draft_full_att_layer_num=get_mtp_draft_full_att_layer_num(args), ) diff --git a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py index 30dc4d937c..2ab4313e37 100644 --- a/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py +++ b/lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py @@ -24,7 +24,7 @@ def __init__( self.conv_state_cache = LayerCache( size=self.size, dtype=self.linear_config.conv_state_dtype, - shape=self.linear_config.get_conv_state_shape(), + shape=self.linear_config.get_persisted_conv_state_shape(), layer_num=self.linear_config.linear_layer_num, device="cpu", size_first=True, diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..38a6d37727 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -19,6 +19,24 @@ logger = init_logger(__name__) +# Width of req_to_next_token_ids: holds the seed token + up to (WIDTH - 1) MTP draft tokens. +REQ_NEXT_TOKEN_IDS_WIDTH = 8 + + +def _format_nbytes(nbytes: int) -> str: + mib = nbytes / (1024**2) + gib = nbytes / (1024**3) + return f"{mib:.2f} MiB ({gib:.2f} GiB)" + + +def assert_mtp_step_within_next_token_ids_width(mtp_step: int) -> None: + assert mtp_step <= REQ_NEXT_TOKEN_IDS_WIDTH - 1, ( + f"mtp_step={mtp_step} exceeds {REQ_NEXT_TOKEN_IDS_WIDTH - 1}; " + f"req_to_next_token_ids width is {REQ_NEXT_TOKEN_IDS_WIDTH} " + "(widening it is an explicit follow-up, spec §9)" + ) + + class _ReqNode: def __init__(self, index): self.index = index @@ -75,6 +93,11 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num + # Always resident: infer_batch.copy_linear_att_state_to_cache_buffer reads this + # unconditionally in the linear-att cache-copy path (even for mtp_step=0, where + # accept_len=1 is the correct non-widened value). MTP overwrites the live slots. + self.req_to_accept_len = torch.ones((max_request_num + 1,), dtype=torch.int32, device="cuda") + def alloc(self): return self.req_list.alloc() @@ -117,7 +140,7 @@ def __init__(self, max_request_num): self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_next_token_ids = torch.zeros( - (max_request_num + 1, 8), + (max_request_num + 1, REQ_NEXT_TOKEN_IDS_WIDTH), dtype=torch.int64, device="cuda", ) @@ -236,15 +259,13 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con self.big_page_token_num = ( get_env_start_args().linear_att_page_block_num * get_env_start_args().linear_att_hash_page_size ) - assert ( - self.mtp_step == 0 - ), "currently only support mtp_step 0 for simplicity, more mtp_step support will be added in the future" + assert_mtp_step_within_next_token_ids_width(self.mtp_step) self.linear_config = linear_config self.req_to_conv_state = LayerCache( - size=(max_request_num + 1) * (self.mtp_step + 1), + size=(max_request_num + 1), dtype=self.linear_config.conv_state_dtype, - shape=self.linear_config.get_conv_state_shape(), + shape=self.linear_config.get_gpu_conv_state_shape(mtp_step=self.mtp_step), layer_num=self.linear_config.linear_layer_num, device="cuda", ) @@ -255,14 +276,30 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con layer_num=self.linear_config.linear_layer_num, device="cuda", ) + conv_buffer = self.req_to_conv_state.buffer + ssm_buffer = self.req_to_ssm_state.buffer + conv_nbytes = conv_buffer.numel() * conv_buffer.element_size() + ssm_nbytes = ssm_buffer.numel() * ssm_buffer.element_size() + logger.info( + "linear att gpu state buffers: " + f"max_request_num={max_request_num}, hold_request_id={self.HOLD_REQUEST_ID}, mtp_step={self.mtp_step}, " + f"conv_state shape={tuple(conv_buffer.shape)}, dtype={conv_buffer.dtype}, " + f"nbytes={conv_nbytes}, memory={_format_nbytes(conv_nbytes)}; " + f"ssm_state shape={tuple(ssm_buffer.shape)}, dtype={ssm_buffer.dtype}, " + f"nbytes={ssm_nbytes}, memory={_format_nbytes(ssm_nbytes)}; " + f"total memory={_format_nbytes(conv_nbytes + ssm_nbytes)}" + ) return def init_linear_att_state(self, req: "InferReq"): - index = req.req_idx * (self.mtp_step + 1) - conv_state = self.req_to_conv_state.buffer[:, index, ...] - ssm_state = self.req_to_ssm_state.buffer[:, index, ...] - conv_state.fill_(0) - ssm_state.fill_(0) + conv_index = req.req_idx + ssm_start = req.req_idx * (self.mtp_step + 1) + self.req_to_conv_state.buffer[:, conv_index, ...].fill_(0) + # #17: zero the FULL (mtp_step + 1)-row SSM block, not just canonical row +0, so a future + # first-step verify reading offset>0 after fresh init never hits a never-written row (NaN). + self.req_to_ssm_state.buffer[:, ssm_start : ssm_start + (self.mtp_step + 1), ...].fill_(0) + if self.req_to_accept_len is not None: + self.req_to_accept_len[req.req_idx] = 1 return def get_mamba_cache(self, layer_idx_in_all: int): @@ -281,10 +318,13 @@ def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers conv_state, ssm_state = big_page_buffers.get_state_cache(buffer_idx=big_page_buffer_idx) - dest_req_idx = req.req_idx * (self.mtp_step + 1) - - self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state - self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + conv_dest = req.req_idx + ssm_dest = req.req_idx * (self.mtp_step + 1) + narrow_w = conv_state.shape[-1] # persisted (narrow) width + self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state + self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state + if self.req_to_accept_len is not None: + self.req_to_accept_len[req.req_idx] = 1 return def copy_small_page_buffer_to_linear_att_state( @@ -293,9 +333,13 @@ def copy_small_page_buffer_to_linear_att_state( conv_state, ssm_state = linear_att_small_page_buffers.get_state_cache( buffer_idx=req.shared_kv_node.small_page_buffer_idx ) - dest_req_idx = req.req_idx * (self.mtp_step + 1) + conv_dest = req.req_idx + ssm_dest = req.req_idx * (self.mtp_step + 1) + narrow_w = conv_state.shape[-1] # TODO 下面这个从 cpu cache 拷贝数据的 gpu的操作,是否是阻塞的操作。 # 同时,非连续对象的拷贝,可能存在效率问题。 - self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state - self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state + self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state + self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state + if self.req_to_accept_len is not None: + self.req_to_accept_len[req.req_idx] = 1 return diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f15badde25..31ff5b8b65 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -100,6 +100,24 @@ def all_reduce(self, input_: torch.Tensor) -> None: return return dist.all_reduce(input_, group=self.device_group) + def all_reduce_residual_rmsnorm( + self, + input_: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + alloc_func=torch.empty, + ): + if self.flashinfer_reduce is None: + return None + return self.flashinfer_reduce.all_reduce_residual_rmsnorm( + input_, + residual=residual, + norm_weight=norm_weight, + eps=eps, + alloc_func=alloc_func, + ) + def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, async_op: bool = False) -> None: return dist.all_gather_into_tensor(output_, input_, group=self.device_group, async_op=async_op) @@ -235,6 +253,27 @@ def all_reduce( return dist.all_reduce(input_, op, group, async_op) +def all_reduce_residual_rmsnorm( + input_: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None, + alloc_func=torch.empty, +): + if _is_single_group(group=group): + return None + if isinstance(group, CustomProcessGroup): + return group.all_reduce_residual_rmsnorm( + input_, + residual=residual, + norm_weight=norm_weight, + eps=eps, + alloc_func=alloc_func, + ) + return None + + def all_gather_into_tensor( output_: torch.Tensor, input_: torch.Tensor, diff --git a/lightllm/distributed/flashinfer_all_reduce.py b/lightllm/distributed/flashinfer_all_reduce.py index 27856d9ac7..d469c3bc66 100644 --- a/lightllm/distributed/flashinfer_all_reduce.py +++ b/lightllm/distributed/flashinfer_all_reduce.py @@ -132,4 +132,39 @@ def all_reduce(self, inp: torch.Tensor) -> torch.Tensor: input=inp, workspace=self._workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce, + launch_with_pdl=True, ) + + def all_reduce_residual_rmsnorm( + self, + inp: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + alloc_func=torch.empty, + ): + if ( + residual.shape != inp.shape + or residual.dtype != inp.dtype + or not residual.is_cuda + or norm_weight.dtype != inp.dtype + or norm_weight.shape[0] != inp.shape[-1] + ): + return None + if not self.should_use(inp): + return None + + residual_out = alloc_func(inp.shape, dtype=inp.dtype, device=inp.device) + norm_out = alloc_func(inp.shape, dtype=inp.dtype, device=inp.device) + flashinfer_comm.allreduce_fusion( + input=inp, + workspace=self._workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=True, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=norm_weight, + rms_eps=eps, + ) + return residual_out, norm_out diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py index d23475c1cf..35bd6f7925 100644 --- a/lightllm/models/qwen3_5/infer_struct.py +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -1,8 +1,4 @@ -import torch -from typing import List - from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args class Qwen35InferStateInfo(Qwen2VLInferStateInfo): @@ -12,8 +8,7 @@ def __init__(self): def init_some_extra_state(self, model): super().init_some_extra_state(model) - self.b_att_seq_len = self.b_seq_len - mtp_step = get_env_start_args().mtp_step + from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state - self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + init_mtp_verify_extra_state(self, model) return diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index afbd02a482..d9ac369960 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -28,14 +28,24 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - o_gate = layer_weight._o_gate_proj.mm(input) - # In-place sigmoid for gate - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], diff --git a/lightllm/models/qwen3_5_moe_mtp/__init__.py b/lightllm/models/qwen3_5_moe_mtp/__init__.py new file mode 100644 index 0000000000..c8885f8869 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + +__all__ = ["Qwen3_5MoeMTPModel"] diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py new file mode 100644 index 0000000000..dcad1087d4 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/__init__.py @@ -0,0 +1,5 @@ +from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import ( + Qwen3_5MoeMTPTransformerLayerWeight, +) + +__all__ = ["Qwen3_5MoeMTPTransformerLayerWeight"] diff --git a/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..b2700aa0bd --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,154 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ( + COLMMWeight, + FusedMoeWeight, + ROWMMWeight, + QKVROWNMMWeight, +) +from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import ( + Qwen35MOETransformerLayerWeight, +) +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3_5MoeMTPTransformerLayerWeight(Qwen35MOETransformerLayerWeight): + # MTP draft-model weights live under the `mtp.layers.*` checkpoint namespace; the + # main-model attention/norm names (`model.layers.*`) are retargeted to it, while the + # MoE expert / shared-expert names are built directly with the mtp prefix below. + + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + _ATTN_NORM_NAME_ATTRS = ( + "_q_weight_name", + "_q_norm_name", + "_q_bias_name", + "_k_weight_name", + "_k_norm_name", + "_k_bias_name", + "_v_weight_name", + "_v_bias_name", + "_kv_weight_name", + "_kv_bias_name", + "_o_weight_name", + "_o_bias_name", + "_att_norm_weight_name", + "_att_norm_bias_name", + "_ffn_norm_weight_name", + "_ffn_norm_bias_name", + ) + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _retarget_attn_norm_names(self): + for attr in self._ATTN_NORM_NAME_ATTRS: + setattr(self, attr, self._retarget(getattr(self, attr))) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + + def _init_weight_names(self): + super()._init_weight_names() + self._retarget_attn_norm_names() + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) + self._init_gated_ffn() + + def _init_gated_ffn(self): + hidden_size = self.network_config_["hidden_size"] + if "shared_expert_intermediate_size" not in self.network_config_: + return + + prefix = f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert" + inter_size = self.network_config_["shared_expert_intermediate_size"] + if get_env_start_args().enable_ep_moe: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + tp_rank=0, + tp_world_size=1, + ) + else: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + ) + + self.ffn_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/qwen3_5_moe_mtp/model.py b/lightllm/models/qwen3_5_moe_mtp/model.py new file mode 100644 index 0000000000..022864f6b3 --- /dev/null +++ b/lightllm/models/qwen3_5_moe_mtp/model.py @@ -0,0 +1,8 @@ +from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel +from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import ( + Qwen3_5MoeMTPTransformerLayerWeight, +) + + +class Qwen3_5MoeMTPModel(Qwen3_5MTPModel): + transformer_weight_class = Qwen3_5MoeMTPTransformerLayerWeight diff --git a/lightllm/models/qwen3_5_mtp/__init__.py b/lightllm/models/qwen3_5_mtp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/__init__.py b/lightllm/models/qwen3_5_mtp/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..906a0ab62c --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,40 @@ +import torch + +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + + +class Qwen3_5MTPPreLayerInfer(Qwen3VLMultimodalPreLayerInfer): + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_fuse( + self, + input_embdings: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3_5MTPPreAndPostLayerWeight, + ) -> torch.Tensor: + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == tgt_embdings.shape[0] + ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" + + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + return layer_weight.eh_proj_weight_.mm(cat_embdings) + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_5_mtp/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..25c56a0d7e --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,45 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + NoTpGEMMANormWeight, + ROWMMWeight, +) +from lightllm.common.quantization import Quantcfg + + +class Qwen3_5MTPPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): + super().__init__(data_type, network_config) + self.quant_cfg: Quantcfg = quant_cfg + hidden_size = network_config["hidden_size"] + + self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], + weight_names="mtp.fc.weight", + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"), + tp_rank=0, + tp_world_size=1, + ) + self.enorm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_embedding.weight", + data_type=self.data_type_, + ) + self.hnorm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.pre_fc_norm_hidden.weight", + data_type=self.data_type_, + ) + self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, + weight_name="mtp.norm.weight", + data_type=self.data_type_, + ) + + # Shared with the main Qwen3.5 model, injected by the model class (not loaded here). + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None + return diff --git a/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..4c76b4bce0 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,80 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, QKVROWNMMWeight +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35TransformerLayerWeight, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3_5MTPTransformerLayerWeight(Qwen35TransformerLayerWeight): + # MTP draft-model weights live under the `mtp.layers.*` checkpoint namespace, so every + # main-model layer name (`model.layers.*`) is retargeted to it at load time. + + _MAIN_PREFIX = "model.layers." + _MTP_PREFIX = "mtp.layers." + + _ATTN_NORM_NAME_ATTRS = ( + "_q_weight_name", + "_q_norm_name", + "_q_bias_name", + "_k_weight_name", + "_k_norm_name", + "_k_bias_name", + "_v_weight_name", + "_v_bias_name", + "_kv_weight_name", + "_kv_bias_name", + "_o_weight_name", + "_o_bias_name", + "_att_norm_weight_name", + "_att_norm_bias_name", + "_ffn_norm_weight_name", + "_ffn_norm_bias_name", + ) + + def _retarget(self, name): + if name is None: + return None + return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1) + + def _retarget_attn_norm_names(self): + for attr in self._ATTN_NORM_NAME_ATTRS: + setattr(self, attr, self._retarget(getattr(self, attr))) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + + def _init_weight_names(self): + super()._init_weight_names() + # Retarget all main-model layer key names to the mtp.* namespace. + self._retarget_attn_norm_names() + # MLP (dense) projection names retargeted by Qwen35TransformerLayerWeight. + self._gate_weight_name = self._retarget(self._gate_weight_name) + self._gate_bias_name = self._retarget(self._gate_bias_name) + self._up_weight_name = self._retarget(self._up_weight_name) + self._up_bias_name = self._retarget(self._up_bias_name) + self._gate_up_weight_name = self._retarget(self._gate_up_weight_name) + self._gate_up_bias_name = self._retarget(self._gate_up_bias_name) + self._down_weight_name = self._retarget(self._down_weight_name) + self._down_bias_name = self._retarget(self._down_bias_name) diff --git a/lightllm/models/qwen3_5_mtp/model.py b/lightllm/models/qwen3_5_mtp/model.py new file mode 100644 index 0000000000..b98524a997 --- /dev/null +++ b/lightllm/models/qwen3_5_mtp/model.py @@ -0,0 +1,109 @@ +from typing import List + +from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import Qwen35TransformerLayerInfer +from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight +from lightllm.models.qwen3_5_mtp.layer_weights.transformer_layer_weight import Qwen3_5MTPTransformerLayerWeight +from lightllm.models.qwen3_5_mtp.layer_infer.pre_layer_infer import Qwen3_5MTPPreLayerInfer +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3_5MTPModel(Qwen3_5TpPartModel): + + pre_and_post_weight_class = Qwen3_5MTPPreAndPostLayerWeight + pre_layer_infer_class = Qwen3_5MTPPreLayerInfer + transformer_weight_class = Qwen3_5MTPTransformerLayerWeight + transformer_layer_infer_class = Qwen35TransformerLayerInfer + + # MTP draft model: reuses the main model's req/mem managers and rope caches, and is + # marked so the decode CUDA-graph / padding paths detect it (is_mtp_draft_model). + is_mtp_draft_model = True + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + + def _init_config(self): + super()._init_config() + self.config["full_attention_interval"] = 1 + self.config["num_hidden_layers"] = 1 + self.config["n_layer"] = 1 + return + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = 1 + return + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(0, self.config["n_layer"]) + ] + # Shared with the main Qwen3.5 model (mtp_use_dedicated_embeddings: false). + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + return + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + # Build the single draft layer with layer_num == 0 so that, with + # full_attention_interval == 1, it takes the full-attention (mrope) path. + super()._init_infer_layer(start_layer_index=0) + self._assign_draft_kv_slot() + return + + def _assign_draft_kv_slot(self): + mem_manager = self.main_model.mem_manager + main_full_att = getattr(mem_manager, "main_full_att_layer_num", None) + interval = self.main_model.config["full_attention_interval"] + if main_full_att is None: + # Non-hybrid / unexpected mem_manager: nothing to remap. + return + + draft_idx = len(self.mtp_previous_draft_models) + draft_full_att_layers = getattr(mem_manager, "draft_full_att_layers", None) + if draft_full_att_layers is not None: + assert draft_idx < draft_full_att_layers, ( + f"draft_idx {draft_idx} out of range for draft_full_att_layers " + f"{draft_full_att_layers}; mem_manager not sized for this many MTP draft blocks" + ) + draft_kv_slot = main_full_att + draft_idx + layer_infer = self.layers_infer[0] + layer_infer.layer_num_ = draft_kv_slot * interval + logger.info( + f"Qwen3.5 MTP draft layer assigned dedicated full-attn KV slot {draft_kv_slot} " + f"(layer_num_={layer_infer.layer_num_}, interval={interval}, main_full_att={main_full_att})" + ) + return diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py index 0006a682f1..bb1516673a 100644 --- a/lightllm/models/qwen3next/infer_struct.py +++ b/lightllm/models/qwen3next/infer_struct.py @@ -1,6 +1,4 @@ -import torch from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args class Qwen3NextInferStateInfo(LlamaInferStateInfo): @@ -10,7 +8,7 @@ def __init__(self): def init_some_extra_state(self, model): super().init_some_extra_state(model) - self.b_att_seq_len = self.b_seq_len - mtp_step = get_env_start_args().mtp_step - self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index + from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state + + init_mtp_verify_extra_state(self, model) return diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index bb48bfe49c..cc4a653abd 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -10,8 +10,10 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager from typing import Tuple -from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs +from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import add_shared_expert_gate_, sigmoid_mul_ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -114,15 +116,14 @@ def _compute_shared_expert( ): input = input.view(-1, self.embed_dim_) shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input).sigmoid_() - shared_expert_out.mul_(gate) - return shared_expert_out + gate = layer_weight.ffn_gate.mm(input) + return shared_expert_out, gate def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape @@ -135,15 +136,16 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + shared_expert_out=shared_expert_out, + shared_expert_gate=gate, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) - hidden_states.add_(shared_expert_out) return hidden_states def _moe_ffn_edp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input token_num, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -158,7 +160,7 @@ def _moe_ffn_edp( is_prefill=infer_state.is_prefill, ) ep_output = ep_output.view(token_num, hidden_dim) - ep_output.add_(shared_expert_out) + add_shared_expert_gate_(ep_output, shared_expert_out, gate) return ep_output def _get_qkv( @@ -169,13 +171,25 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ * 2 + + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1, ) - o_gate = layer_weight._o_gate_proj.mm(input) - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -199,15 +213,24 @@ def _get_o( input, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, + ) -> torch.Tensor: + o_tensor = self._get_o_local(input=input, infer_state=infer_state, layer_weight=layer_weight) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) + return o_tensor + + def _get_o_local( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, ) -> torch.Tensor: """Output projection with gating (in-place multiply to save one allocation).""" if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - input.mul_(infer_state.gate_value) + sigmoid_mul_(input, infer_state.gate_value) infer_state.gate_value = None o_tensor = layer_weight.o_proj.mm(input) - o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor # ==================== GDN Helper Methods ==================== @@ -254,11 +277,24 @@ def gdn_forward( if is_prefill: core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight) + elif getattr(infer_state, "is_mtp_verify", False): + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) + conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) + core_attn_out = self._gdn_verify_kernel( + mixed_qkv, + conv_states, + ssm_states, + a, + b, + infer_state, + layer_weight, + ) else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) - core_attn_out = self._gdn_decode_kernel( + core_attn_out, z = self._gdn_decode_kernel( mixed_qkv, + z, conv_states, ssm_states, a, @@ -374,7 +410,7 @@ def _gdn_prefill_kernel( layer_weight.linear_conv1d.mm_param.weight, bias=layer_weight.linear_conv1d.bias, query_start_loc=infer_state.b1_cu_q_seq_len, - cache_indices=infer_state.b_buffer_idx, + cache_indices=infer_state.b_conv_buffer_idx, has_initial_state=infer_state.b_ready_cache_len > 0, conv_states=conv_states, activation=self.activation, @@ -406,6 +442,7 @@ def _gdn_prefill_kernel( def _gdn_decode_kernel( self, mixed_qkv: torch.Tensor, + z: torch.Tensor, conv_states: torch.Tensor, ssm_states: torch.Tensor, a: torch.Tensor, @@ -413,25 +450,79 @@ def _gdn_decode_kernel( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - mixed_qkv = causal_conv1d_update( + # Recurrent processing with fused gating. Decode uses a specialized + # conv+pack kernel to avoid materializing the post-conv qkv tensor + # before immediately splitting it into q/k/v. + query, key, value, z, a, b = conv_pack_gdn_decode_inputs( + mixed_qkv, + z, + a, + b, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + layer_weight.linear_conv1d.bias, + infer_state.b_conv_buffer_idx, + self.activation, + self.tp_num_k_heads, + self.head_k_dim, + self.tp_num_v_heads, + self.head_v_dim, + ) + core_attn_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + initial_state=ssm_states, + inplace_final_state=True, + ssm_state_indices=infer_state.b_buffer_idx, + use_qk_l2norm_in_kernel=True, + A_log=layer_weight.linear_A_log.weight, + dt_bias=layer_weight.linear_dt_bias.weight, + a_raw=a, + b_raw=b, + ) + return core_attn_out, z + + def _gdn_verify_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ): + from lightllm.models.qwen3next.triton_kernel.causal_conv1d_spec import ( + causal_conv1d_update as causal_conv1d_update_spec, + ) + + mixed_qkv = causal_conv1d_update_spec( mixed_qkv, conv_states, layer_weight.linear_conv1d.mm_param.weight, bias=layer_weight.linear_conv1d.bias, activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + conv_state_indices=infer_state.b_conv_buffer_idx, + num_accepted_tokens=infer_state.b_num_accepted_tokens, + query_start_loc=infer_state.b_gdn_verify_cu_seqlens, ) - # Recurrent processing with fused gating - # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally - query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=False) + assert infer_state.b_ssm_index_rows.dim() == 2, "SSM index rows must be 2D [N, S+1]" + # #8b: b_num_accepted_tokens >= 1 is guaranteed upstream (init sets accept_len=1; the + # offload/snapshot guards bound it to [1, mtp_step+1]). The old per-layer per-step .all() + # D2H sync stalled the GPU on the eager decode hot path; it is redundant here. core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, v=value, initial_state=ssm_states, inplace_final_state=True, - ssm_state_indices=infer_state.b_buffer_idx, + cu_seqlens=infer_state.b_gdn_verify_cu_seqlens.to(torch.long), + ssm_state_indices=infer_state.b_ssm_index_rows, + ssm_state_write_indices=infer_state.b_ssm_index_rows, + num_accepted_tokens=infer_state.b_num_accepted_tokens, use_qk_l2norm_in_kernel=True, A_log=layer_weight.linear_A_log.weight, dt_bias=layer_weight.linear_dt_bias.weight, diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 0d415ca0e8..51b702039b 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -11,6 +11,83 @@ QKVROWNMMWeight, QKGEMMANormWeight, ) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import get_row_slice_mixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class QKVGatedROWNMMWeight(MMWeightTpl): + def __init__( + self, + in_dim, + q_head_num, + kv_head_num, + head_dim, + weight_names, + data_type, + bias_names=None, + quant_method=None, + tp_rank=None, + tp_world_size=None, + ): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.q_repeat_times = 1 + self.kv_repeat_times = 1 + assert ( + q_head_num % self.tp_world_size_ == 0 + ), f"q_head_num must be divisible by tp_world_size_, found {q_head_num} % {self.tp_world_size_}" + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or vice versa, " + f"found {kv_head_num} % {self.tp_world_size_}" + ) + q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + super().__init__( + in_dim=in_dim, + out_dims=[q_hidden_size, kv_hidden_size, kv_hidden_size, q_hidden_size], + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.q_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.q_repeat_times, + ) + self.kv_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.kv_repeat_times, + ) + + def _get_param_slicer(self, sub_child_index): + if sub_child_index == 0 or sub_child_index == 3: + return self.q_param_slicer + return self.kv_param_slicer + + def load_hf_weights(self, weights): + super().load_hf_weights(weights) + if self.bias_names is not None: + for sub_child_index, bias_name in enumerate(self.bias_names): + if bias_name is None: + self.bias_list[sub_child_index].zero_() + self.bias_list[sub_child_index].load_ok = True + + def _get_tp_padded_head_num(self, head_num): + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + if self.tp_world_size_ % head_num == 0: + self.kv_repeat_times = self.tp_world_size_ // head_num + return self.kv_repeat_times * head_num // self.tp_world_size_ + raise ValueError( + f"head_num must be divisible by tp_world_size_ or vice versa, found {head_num} % {self.tp_world_size_}" + ) class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -23,25 +100,39 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def _init_qkv(self): in_dim = self.n_embed q_out_dim = self.q_head_num_ * self.head_dim - self.qkv_proj = QKVROWNMMWeight( - in_dim=in_dim, - q_head_num=self.q_head_num_, - kv_head_num=self.k_head_num_, - head_dim=self.head_dim, - weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], - data_type=self.data_type_, - bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("qkv_proj"), - ) self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" - self._o_gate_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=[self._o_gate_weight_name], - data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("o_gate_proj"), - ) + qkv_quant = self.get_quant_method("qkv_proj") + gate_quant = self.get_quant_method("o_gate_proj") + if qkv_quant.method_name == "none" and gate_quant.method_name == "none": + self.qkvo_gate_proj = QKVGatedROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name, self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name, None], + quant_method=qkv_quant, + ) + else: + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=qkv_quant, + ) + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=gate_quant, + ) def _init_weight(self): if self.is_linear_attention_layer: diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 9b5e9b7a50..f64386a50d 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -16,7 +16,12 @@ from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.req_manager import ReqManagerForMamba -from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig +from lightllm.common.linear_att_cache_manager.config_objs import ( + LinearAttCacheConfig, + get_mtp_draft_full_att_layer_num, +) +from lightllm.common.basemodel.batch_objs import ModelOutput +from lightllm.distributed import all_reduce, all_reduce_residual_rmsnorm logger = init_logger(__name__) @@ -51,6 +56,29 @@ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch def autotune_layers(self): return self.config["full_attention_interval"] + def _autotune_extra_warmup(self): + if not self.trans_layers_weight: + return + + norm_weight = self.trans_layers_weight[0].ffn_norm_weight_ + add_rmsnorm = getattr(norm_weight, "add_rmsnorm", None) + if add_rmsnorm is None: + return + + hidden_dim = norm_weight.weight.shape[0] + max_batch_size = min(self.graph_max_batch_size, self.batch_max_tokens) + warmup_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + warmup_batch_sizes = [bs for bs in warmup_batch_sizes if bs <= max_batch_size] + if max_batch_size not in warmup_batch_sizes: + warmup_batch_sizes.append(max_batch_size) + + for batch_size in sorted(set(warmup_batch_sizes)): + x = torch.zeros((batch_size, hidden_dim), dtype=self.data_type, device="cuda") + residual = torch.zeros_like(x) + out = torch.empty_like(x) + add_rmsnorm(input=x, residual=residual, eps=self.layers_infer[0].eps_, out=out) + return + def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) @@ -59,6 +87,7 @@ def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + draft_full_att_layers = get_mtp_draft_full_att_layer_num(start_args) self.linear_config = LinearAttCacheConfig( tp_world_size=self.tp_world_size_, full_att_all_num_kv_heads=self.config["num_key_value_heads"], @@ -78,17 +107,24 @@ def _init_mem_manager(self): ssm_state_dtype=ssm_dtype_dict[start_args.linear_att_ssm_data_type], full_attention_interval=self.config["full_attention_interval"], all_layer_num=self.config["n_layer"], + draft_full_att_layer_num=draft_full_att_layers, ) + main_full_att = self.linear_config.get_main_full_att_layer_num() + persisted_full_att = self.linear_config.get_persisted_full_att_layer_num() + self.mem_manager = Qwen3NextMemManager( size=self.max_total_token_num, dtype=self.data_type, num_kv_heads=self.num_kv_heads, head_dim=self.config["head_dim"], - full_att_layer_num=self.linear_config.all_layer_num - self.linear_config.linear_layer_num, + full_att_layer_num=persisted_full_att, linear_config=self.linear_config, mem_fraction=self.mem_fraction, ) + self.mem_manager.main_full_att_layer_num = main_full_att + self.mem_manager.draft_full_att_layers = draft_full_att_layers + self.mem_manager.persisted_full_att_layer_num = persisted_full_att def _init_req_manager(self): create_max_seq_len = 0 @@ -102,3 +138,80 @@ def _init_req_manager(self): self.max_req_num, create_max_seq_len, None, linear_config=LinearAttCacheConfig.load_from_args() ) return + + def _token_forward(self, infer_state: Qwen3NextInferStateInfo): + input_ids = infer_state.input_ids + input_embs = self.pre_infer.token_forward(input_ids, infer_state, self.pre_post_weight) + input_embs = self.pre_infer._tpsp_sp_split(input=input_embs, infer_state=infer_state) + + next_att_normed = None + for i in range(self.layers_num): + layer: Qwen3NextTransformerLayerInfer = self.layers_infer[i] + layer_weight: Qwen3NextTransformerLayerWeight = self.trans_layers_weight[i] + + if next_att_normed is None: + input1 = layer._att_norm(input_embs, infer_state, layer_weight) + else: + input1 = next_att_normed + next_att_normed = None + + if layer.is_linear_attention_layer: + o = layer.token_attention_forward(input1, infer_state, layer_weight) + input1 = layer._add_residual_ffn_norm(input_embs, o, infer_state, layer_weight) + o = None + else: + q, cache_kv = layer._get_qkv(input1, infer_state, layer_weight) + layer._post_cache_kv(cache_kv, infer_state, layer_weight) + o = layer._token_attention_kernel(q, infer_state, layer_weight) + q = None + o = layer._get_o_local(o, infer_state, layer_weight) + fused = None + if layer.tp_world_size_ > 1: + fused = all_reduce_residual_rmsnorm( + o, + residual=input_embs.view(-1, layer.embed_dim_), + norm_weight=layer_weight.ffn_norm_weight_.weight, + eps=layer.eps_, + group=infer_state.dist_group, + alloc_func=layer.alloc_tensor, + ) + if fused is None: + if layer.tp_world_size_ > 1: + all_reduce(o, group=infer_state.dist_group) + input1 = layer._add_residual_ffn_norm(input_embs, o, infer_state, layer_weight) + else: + input_embs, input1 = fused + o = None + + ffn_out = layer._ffn(input1, infer_state, layer_weight) + ffn_out = ffn_out.view(-1, layer.embed_dim_) + + if i + 1 < self.layers_num: + next_layer: Qwen3NextTransformerLayerInfer = self.layers_infer[i + 1] + next_layer_weight: Qwen3NextTransformerLayerWeight = self.trans_layers_weight[i + 1] + add_rmsnorm = getattr(next_layer_weight.att_norm_weight_, "add_rmsnorm", None) + if add_rmsnorm is not None: + next_att_normed = add_rmsnorm( + input=input_embs, + residual=ffn_out, + eps=next_layer.eps_, + alloc_func=next_layer.alloc_tensor, + ) + continue + + input_embs.add_(ffn_out) + + last_input_embs = self.post_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + predict_logits: torch.Tensor = self.post_infer.token_forward( + last_input_embs, infer_state=infer_state, layer_weight=self.pre_post_weight + ) + + model_output = ModelOutput(logits=predict_logits.contiguous()) + if self.is_mtp_mode: + input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + model_output.mtp_main_output_hiddens = input_embs.contiguous() + + if infer_state.is_cuda_graph: + model_output.to_no_ref_tensor() + + return model_output diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py new file mode 100644 index 0000000000..2f0e22fa3f --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d_spec.py @@ -0,0 +1,468 @@ +# Vendored from vLLM v0.14.1 +# source: vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# commit: d7de043d55d1dd629554467e23874097e1c48993 +# Adapted for LightLLM: imports point at standard triton; the vLLM-specific +# block-table params (block_idx_last_scheduled_token, initial_state_idx, +# null_block_id) are dropped — LightLLM uses contiguous per-request slots. +# Supports spec-decode: writes per-position conv state to a single widened slot +# per request and reads from offset (num_accepted_tokens-1). +# +# Upstream copyright notice: +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Tri Dao. +# Adapted from +# https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +from typing import Optional + +import torch +import triton +import triton.language as tl + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + # LightLLM uses contiguous per-request slots, so the cache block for both + # the initial-state read and the final write is always conv_state_indices[idx_seq]. + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_states_input_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init).to( + tl.int64 + ) + + if USE_PAD_SLOT: # noqa + if conv_states_input_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_VARLEN: + query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) + # revise state_len and seqlen + state_len = state_len - (seqlen - (query_end_index - query_start_index)) + seqlen = query_end_index - query_start_index + x_offset = query_start_index * stride_x_token + o_offset = query_start_index * stride_o_token + else: + query_start_index = idx_seq * seqlen + query_end_index = query_start_index + seqlen + x_offset = idx_seq * stride_x_seq + o_offset = idx_seq * stride_o_seq + + if query_start_index == query_end_index: + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 6: + conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] + col4 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # With speculative decoding, the conv_state updates works in a sliding + # window manner, at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + # Write the updated state back. In LightLLM the read and write slots are the + # same contiguous per-request slot (current_last_index == conv_state_init == 0), + # so this resolves to the same conv_state_indices[idx_seq] used for the read. + conv_states_offset = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index).to( + tl.int64 + ) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[ + None, : + ] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok + )[ + :, None + ] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 5: + w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor + w_col4 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 6: + w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor + w_col5 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + elif KERNEL_WIDTH == 5: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x + elif KERNEL_WIDTH == 6: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, +): + """Spec-decode capable conv1d update. When num_accepted_tokens/query_start_loc + are None it must behave like a single-token decode update. x may be (batch, dim) + single-token or (num_tokens, dim) flattened varlen with query_start_loc grouping + each request's S+1 candidates. conv_state is (num_slots, dim, state_len) with + state_len = (width-1)+S widened. Read offset = num_accepted_tokens-1; writes to + the same slot. + + Args: + x: input tensor of shape ``(batch, dim)`` (single-token decode), + ``(batch, dim, seqlen)`` (single/multi token), or ``(num_tokens, dim)`` + flattened varlen grouped by ``query_start_loc``. + conv_state: ``(num_slots, dim, state_len)`` with ``state_len >= width - 1``. + For spec decode the slot is widened to ``(width - 1) + S`` where ``S`` is + the number of speculative tokens (so ``seqlen == S + 1``). + weight: depthwise filter of shape ``(dim, width)``. + bias: optional ``(dim,)`` bias. + activation: ``None``, ``"silu"`` or ``"swish"``. + cache_seqlens: accepted for call-compatibility with the non-spec wrapper; + unused here. + conv_state_indices: ``(batch,)`` int32 mapping each request to its conv_state + slot. Required when ``query_start_loc`` is given. + num_accepted_tokens: ``(batch,)`` int32. When not None the conv_state read + offset for each request is ``num_accepted_tokens - 1`` (sliding window + spec-decode update). + query_start_loc: ``(batch + 1,)`` int32 varlen cumulative token offsets; when + None the call is a plain single-/multi-token decode update. + pad_slot_id: slot id that marks padded entries to skip. + + Returns: + Output tensor with the same shape as ``x`` (the kernel overwrites ``x`` in + place), one conv output per input token. + """ + if activation is not None: + assert activation in ["silu", "swish"] + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + if query_start_loc is None: + batch, dim, seqlen = x.shape + else: + assert conv_state_indices is not None + batch = conv_state_indices.size(0) + dim = x.size(1) + # The MTP verify layout is uniform (mtp_step+1) tokens per request, so seqlen is + # structurally x.size(0) // batch. Compute it without a D2H sync on query_start_loc on + # BOTH the capture and eager paths (#8a) — the eager .item() ran once per GDN layer per + # decode step. .item() is also illegal during CUDA-graph capture. + assert x.size(0) % batch == 0, "varlen conv update expects a uniform per-request length" + seqlen = x.size(0) // batch + _, width = weight.shape + # conv_state: (num_slots, dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + # adopt the strategy in vLLM that overwrites 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + if query_start_loc is None: + # X (batch, dim, seqlen) + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + # X (num_tokens, dim) + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = out.stride() + stride_o_seq = 0 + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + conv_state_indices, + num_accepted_tokens, + query_start_loc, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_VARLEN=query_start_loc is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + if unsqueeze: + out = out.squeeze(-1) + return out.to(original_x_dtype) diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py new file mode 100644 index 0000000000..a025e35c64 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -0,0 +1,284 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + BLOCK_QKV: tl.constexpr, + BLOCK_GATE: tl.constexpr, +): + row = tl.program_id(0) + qkv_offsets = tl.arange(0, BLOCK_QKV) + + q_mask = qkv_offsets < q_dim + q_vals = tl.load(mixed_qkv + row * stride_m_b + qkv_offsets * stride_m_d, mask=q_mask, other=0.0) + tl.store(q_out + row * q_dim + qkv_offsets, q_vals, mask=q_mask) + + k_mask = qkv_offsets < k_dim + k_vals = tl.load( + mixed_qkv + row * stride_m_b + (q_dim + qkv_offsets) * stride_m_d, + mask=k_mask, + other=0.0, + ) + tl.store(k_out + row * k_dim + qkv_offsets, k_vals, mask=k_mask) + + v_mask = qkv_offsets < v_dim + v_vals = tl.load( + mixed_qkv + row * stride_m_b + (q_dim + k_dim + qkv_offsets) * stride_m_d, + mask=v_mask, + other=0.0, + ) + tl.store(v_out + row * v_dim + qkv_offsets, v_vals, mask=v_mask) + + z_vals = tl.load(z_raw + row * stride_z_b + qkv_offsets, mask=v_mask, other=0.0) + tl.store(z_out + row * v_dim + qkv_offsets, z_vals, mask=v_mask) + + gate_offsets = tl.arange(0, BLOCK_GATE) + gate_mask = gate_offsets < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + gate_offsets * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + gate_offsets * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + gate_offsets, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + gate_offsets, b_vals, mask=gate_mask) + + +@torch.no_grad() +def pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_qkv = triton.next_power_of_2(max(q_dim, k_dim, v_dim)) + block_gate = triton.next_power_of_2(gate_dim) + _pack_gdn_decode_kernel[(batch,)]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + BLOCK_QKV=block_qkv, + BLOCK_GATE=block_gate, + num_warps=4, + ) + return q, k, v, z, a, b + + +@triton.jit +def _conv_pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + stride_s_b: tl.constexpr, + stride_s_d: tl.constexpr, + stride_s_w: tl.constexpr, + stride_w_d: tl.constexpr, + stride_w_w: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + conv_dim: tl.constexpr, + HAS_BIAS: tl.constexpr, + APPLY_SILU: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offs = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < conv_dim + state_idx = tl.load(conv_state_indices + row) + + x = tl.load(mixed_qkv + row * stride_m_b + offs * stride_m_d, mask=mask, other=0.0).to(tl.float32) + s0 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s1 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s2 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + w0 = tl.load(conv_weight + offs * stride_w_d + 0 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w1 = tl.load(conv_weight + offs * stride_w_d + 1 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w2 = tl.load(conv_weight + offs * stride_w_d + 2 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w3 = tl.load(conv_weight + offs * stride_w_d + 3 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y = s0 * w0 + s1 * w1 + s2 * w2 + x * w3 + if HAS_BIAS: + bias = tl.load(conv_bias + offs, mask=mask, other=0.0).to(tl.float32) + y += bias + if APPLY_SILU: + y = y * tl.sigmoid(y) + + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, s1, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, s2, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, x, mask=mask) + + q_mask = offs < q_dim + k_mask = (offs >= q_dim) & (offs < q_dim + k_dim) + v_mask = (offs >= q_dim + k_dim) & (offs < conv_dim) + tl.store(q_out + row * q_dim + offs, y, mask=q_mask) + tl.store(k_out + row * k_dim + (offs - q_dim), y, mask=k_mask) + tl.store(v_out + row * v_dim + (offs - q_dim - k_dim), y, mask=v_mask) + + z_mask = offs < v_dim + z_vals = tl.load(z_raw + row * stride_z_b + offs, mask=z_mask, other=0.0) + tl.store(z_out + row * v_dim + offs, z_vals, mask=z_mask) + + gate_mask = offs < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + offs * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + offs * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + offs, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + offs, b_vals, mask=gate_mask) + + +@torch.no_grad() +def conv_pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + conv_state: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + conv_state_indices: torch.Tensor, + activation: str, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + conv_dim = q_dim + k_dim + v_dim + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_size = 256 + grid = (batch, triton.cdiv(conv_dim, block_size)) + _conv_pack_gdn_decode_kernel[grid]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + conv_dim, + HAS_BIAS=conv_bias is not None, + APPLY_SILU=activation in ["silu", "swish"], + BLOCK_SIZE=block_size, + num_warps=8, + ) + return q, k, v, z, a, b diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py new file mode 100644 index 0000000000..c2b110def6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -0,0 +1,108 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _add_shared_expert_gate_kernel( + hidden, + shared, + gate, + stride_h_m: tl.constexpr, + stride_h_n: tl.constexpr, + stride_s_m: tl.constexpr, + stride_s_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + + hidden_ptrs = hidden + row * stride_h_m + offs * stride_h_n + shared_vals = tl.load(shared + row * stride_s_m + offs * stride_s_n, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32) + gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + out = hidden_vals + shared_vals * gate_vals + tl.store(hidden_ptrs, out.to(hidden.dtype.element_ty), mask=mask) + + +@triton.jit +def _sigmoid_mul_kernel( + x, + gate, + stride_x_m: tl.constexpr, + stride_x_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + x_ptrs = x + row * stride_x_m + offs * stride_x_n + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) + + +@torch.no_grad() +def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + hidden_arg = hidden.view(-1, hidden.shape[-1]) + shared_arg = shared.view(-1, hidden.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert hidden_arg.shape == shared_arg.shape + assert gate_arg.shape[0] == hidden_arg.shape[0] and gate_arg.shape[1] in (1, hidden_arg.shape[1]) + _, n = hidden_arg.shape + block_n = triton.next_power_of_2(n) + _add_shared_expert_gate_kernel[(hidden_arg.shape[0],)]( + hidden_arg, + shared_arg, + gate_arg, + hidden_arg.stride(0), + hidden_arg.stride(1), + shared_arg.stride(0), + shared_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return hidden + + +@torch.no_grad() +def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + x_arg = x.view(-1, x.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) + _, n = x_arg.shape + block_n = triton.next_power_of_2(n) + _sigmoid_mul_kernel[(x_arg.shape[0],)]( + x_arg, + gate_arg, + x_arg.stride(0), + x_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return x diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1bdf8f3427..7cb0599d69 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -401,7 +401,7 @@ def make_argument_parser() -> argparse.ArgumentParser: default=["auto"], help="""decode attention kernel used in llm. auto: automatically select best backend based on GPU and available packages - (priority: flashinfer > fa3 > triton)""", + (priority: fa3 > flashinfer > triton)""", ) parser.add_argument( "--vit_att_backend", diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 0d934c44c9..df79913e23 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -314,6 +314,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "n": request.n, "best_of": request.n, "add_special_tokens": False, + "return_logprobs": request.logprobs is not None, "seed": request.seed, } @@ -822,6 +823,7 @@ async def completions_impl(request: CompletionRequest, raw_request: Request) -> "n": request.n, "best_of": request.best_of, "add_special_tokens": False, + "return_logprobs": request.logprobs is not None, "seed": request.seed, } if request.max_completion_tokens is not None: diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index c39559f5f6..3515fbf1a1 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -304,6 +304,7 @@ class SamplingParams(ctypes.Structure): ), # whether to add spaces between special tokens when decoding ("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True ("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache + ("return_logprobs", ctypes.c_bool), # whether generated token logprobs are required by the caller ("seed", ctypes.c_int64), # random seed ] @@ -340,6 +341,7 @@ def init(self, tokenizer, **kwargs): self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) self.print_eos_token = kwargs.get("print_eos_token", False) + self.return_logprobs = kwargs.get("return_logprobs", True) self.seed = kwargs.get("seed", -1) self.exponential_decay_length_penalty = ExponentialDecayLengthPenalty() @@ -486,6 +488,7 @@ def to_dict(self): "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, "print_eos_token": self.print_eos_token, "disable_prompt_cache": self.disable_prompt_cache, + "return_logprobs": self.return_logprobs, "seed": self.seed, } diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index dfb8866601..c0560419ff 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -145,7 +145,7 @@ async def wait_to_model_ready(self): "weight_dir": self.model_weightdir, "load_way": self.load_way, "max_total_token_num": self.max_total_token_num, - "max_req_num": self.args.running_max_req_size + 8, + "max_req_num": self.args.running_max_req_size, "max_seq_length": self.args.max_req_total_len + 8, # 留一点余量 "nccl_host": self.args.nccl_host, "nccl_port": self.args.nccl_port, diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 5c2d0d45fb..be8c022594 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -361,6 +361,11 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L if not self.is_linear_att_mixed_model: return + # 当 dynamic prompt cache 被禁用时 radix_cache 为 None,没有大页/小页缓冲可写, + # 线性层状态仅存于 req_manager 的 GPU buffer 即可,直接跳过跨请求缓存拷贝。 + if self.radix_cache is None: + return + # 大页对应的 linear att 的拷贝 big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num big_page_buffer_ids = [] @@ -384,6 +389,12 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer + # accept 数量改由 GPU 常驻的 req_to_accept_len 按 req_idx gather(不再读 req.mtp_accept_len)。 + req_idxs = torch.tensor( + [req.req_idx for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu" + ).cuda(non_blocking=True) + b_num_accepted_tokens = self.req_manager.req_to_accept_len[req_idxs] + copy_linear_att_state_to_kv_buffer( b_req_idx=b_req_idx, big_page_buffer_ids=big_page_buffer_ids, @@ -392,6 +403,7 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer, cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer, mtp_step=self.args.mtp_step, + b_num_accepted_tokens=b_num_accepted_tokens, ) assert not self.args.disable_chunked_prefill, "chunked prefill mode must be enabled for linear att mixed model" @@ -407,9 +419,20 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache() ) if req.tail_linear_att_small_page_buffer_id is not None: - src_buffer_idx = req.req_idx * (self.args.mtp_step + 1) - gpu_conv_state = self.req_manager.req_to_conv_state.buffer[:, src_buffer_idx, ...] - gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...] + # 冷路径(prefill 跨小页边界):单标量从 GPU buffer 读回做 Python 切片下标。 + accept_len = int(self.req_manager.req_to_accept_len[req.req_idx].item()) + assert 1 <= accept_len <= self.args.mtp_step + 1, ( + f"mtp_accept_len={accept_len} out of range " + f"[1, {self.args.mtp_step + 1}]; would slice past the widened conv slot" + ) + canonical_off = accept_len - 1 + conv_src_idx = req.req_idx + ssm_src_idx = req.req_idx * (self.args.mtp_step + 1) + canonical_off + narrow_w = self.req_manager.linear_config.get_persisted_conv_state_shape()[-1] + gpu_conv_state = self.req_manager.req_to_conv_state.buffer[ + :, conv_src_idx, ..., canonical_off : canonical_off + narrow_w + ] + gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, ssm_src_idx, ...] dst_buffer_idx = req.tail_linear_att_small_page_buffer_id dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache( diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a65dfb1bbb..bfd0b95856 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -16,7 +16,7 @@ from lightllm.common.linear_att_cache_manager import LinearAttCacheManager from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache -from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput +from lightllm.common.basemodel.batch_objs import ModelOutput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name @@ -41,10 +41,6 @@ ) from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack -from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel -from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel -from lightllm.models.mistral_mtp.model import MistralMTPModel -from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import PDChunckedTransTaskRet @@ -328,27 +324,20 @@ def init_mtp_draft_model(self, main_kvargs: dict): "mtp_previous_draft_models": self.draft_models.copy(), } - # Select MTP model class based on model type + # Select MTP model class based on model type (single source of truth: #10). + from lightllm.server.router.model_infer.mode_backend.mtp_model_factory import create_mtp_draft_model + model_type = mtp_model_cfg.get("model_type", "") - if model_type == "deepseek_v3": - assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] - self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) - elif model_type == "qwen3_moe": - assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] - self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) - elif model_type == "mistral": - assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] - self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "glm4_moe_lite": - assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] - self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) - else: - raise ValueError(f"Unsupported MTP model type: {model_type}") + self.draft_models.append(create_mtp_draft_model(model_type, self.args.mtp_mode, mtp_model_kvargs)) self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return - def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor): + def _async_copy_next_token_infos_to_pin_mem( + self, + next_token_ids: torch.Tensor, + next_token_logprobs: Optional[torch.Tensor], + ): """ 这个函数会把next token id和logprobs保存到pinned memory中 这样可以保障post_handle 函数可以读取到正常的输出结果。 @@ -357,9 +346,13 @@ def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, key="next_token_ids", gpu_tensor=next_token_ids, ) - next_token_logprobs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="next_token_logprobs", - gpu_tensor=next_token_logprobs, + next_token_logprobs_cpu = ( + None + if next_token_logprobs is None + else g_pin_mem_manager.async_copy_from_gpu_tensor( + key="next_token_logprobs", + gpu_tensor=next_token_logprobs, + ) ) return next_token_ids_cpu, next_token_logprobs_cpu @@ -712,7 +705,7 @@ def _post_handle( self, run_reqs: List[InferReq], next_token_ids: List[int], - next_token_logprobs: List[float], + next_token_logprobs: Optional[List[float]], run_reqs_update_packs: List[InferReqUpdatePack], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, pd_prefill_chunked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None, @@ -721,9 +714,18 @@ def _post_handle( extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 约束输出等模式,设置自己请求内部的状态机的状态,并添加额外的停止判定条件等。 """ - for req_obj, next_token_id, next_token_logprob, pack in zip( - run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs - ): + if next_token_logprobs is None: + iter_items = zip(run_reqs, next_token_ids, run_reqs_update_packs) + else: + iter_items = zip(run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs) + + for item in iter_items: + if next_token_logprobs is None: + req_obj, next_token_id, pack = item + next_token_logprob = 0.0 + else: + req_obj, next_token_id, next_token_logprob, pack = item + req_obj: InferReq = req_obj pack: InferReqUpdatePack = pack pack.handle( @@ -773,8 +775,7 @@ def _update_mtp_accept_ratio( def _gen_argmax_token_ids(self, model_output: ModelOutput): logits = model_output.logits - probs = torch.softmax(logits, dim=-1) - draft_next_token_ids_gpu = torch.argmax(probs, dim=-1) + draft_next_token_ids_gpu = torch.argmax(logits, dim=-1) return draft_next_token_ids_gpu def _sample_and_scatter_token( @@ -812,10 +813,38 @@ def _sample_and_scatter_token( mask=b_has_out, ) next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( - next_token_ids, next_token_logprobs + next_token_ids, + next_token_logprobs, ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _can_decode_pre_post_before_prev_post_handle( + self, + run_reqs: List[InferReq], + extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + ) -> bool: + if not self.support_overlap: + return False + if extra_post_req_handle_func is not None or self.decode_mask_func is not None: + return False + if self.args.mtp_mode: + return False + + for req_obj in run_reqs: + if req_obj.mtp_step != 0: + return False + if req_obj.infer_aborted or req_obj.finish_status.is_finished(): + return False + + shm_param = req_obj.sampling_param.shm_param + if not shm_param.ignore_eos: + return False + if len(req_obj.stop_sequences) != 0: + return False + if req_obj.cur_output_len + 1 >= shm_param.max_new_tokens: + return False + return True + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 792a10a788..2ae7cac322 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -1,5 +1,6 @@ import torch import time +import copy from typing import List, Optional, Callable, Dict, Any from queue import Queue from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend @@ -19,6 +20,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.common.basemodel.triton_kernel.mtp_utils import ( mtp_scatter_next_token_ids, + scatter_mtp_accept_len, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id @@ -163,11 +165,22 @@ def decode_normal( sync_event.record() # 第二阶段 - event_pack.notify_post_handle_and_wait_pre_post_handle() - update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) + can_pre_post_early = self._can_decode_pre_post_before_prev_post_handle( + run_reqs=run_reqs, + extra_post_req_handle_func=self.extra_post_req_handle_func, + ) + if can_pre_post_early: + event_pack.notify_post_handle_event.set() + update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) + event_pack.notify_forward_event.set() + event_pack.wait_pre_post_handle_event.wait() + else: + event_pack.notify_post_handle_and_wait_pre_post_handle() + update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=False) # 第三阶段 - event_pack.notify_forward_and_wait_post_handle() + if not can_pre_post_early: + event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() self._post_handle( run_reqs=run_reqs, @@ -241,22 +254,20 @@ def decode_mtp( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids - b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] - b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list( - key="b_req_mtp_start_loc", - data=b_req_mtp_start_loc, - dtype=torch.int32, - ).cuda(non_blocking=True) + n_real = model_input.batch_size // (self.mtp_step + 1) + b_req_mtp_start_loc = torch.arange(n_real, dtype=torch.int32, device="cuda") * (self.mtp_step + 1) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, b_req_idx=model_input.b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, model_input.b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index e6b9d1c18d..9c83a5f352 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -1,3 +1,4 @@ +import copy import torch import time import torch.nn.functional as F @@ -20,7 +21,7 @@ from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager -from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids +from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids, scatter_mtp_accept_len from .control_state import DPControlState @@ -462,6 +463,9 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): b_req_idx=b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, @@ -587,7 +591,6 @@ def _draft_decode_eagle( real_req_num = req_num // (self.mtp_step + 1) padded_req_num = model_input.batch_size // (self.mtp_step + 1) - real_req_num - eagle_mem_indexes_cpu = None if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) @@ -742,7 +745,6 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = model_input1.b_mtp_index with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -773,6 +775,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_req_idx=b_req_idx, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + scatter_mtp_accept_len( + self.model.req_manager.req_to_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_accept_len + ) accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="accepted_index", gpu_tensor=accepted_index, @@ -879,7 +884,7 @@ def _draft_decode_vanilla_overlap( draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda") if req_num0 > 0: draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True) - if req_num1 > 1: + if req_num1 > 0: draft_next_token_ids_gpu1[0:req_num1].copy_( next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True ) @@ -937,7 +942,7 @@ def _draft_decode_eagle_overlap( draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda") if req_num0 > 0: draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True) - if req_num1 > 1: + if req_num1 > 0: draft_next_token_ids_gpu1[0:req_num1].copy_( next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index 5b29ea0510..14fd0c21c6 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,5 +1,5 @@ import torch -from typing import List, Tuple +from typing import List, Optional, Tuple, Union from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids @@ -7,8 +7,24 @@ from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.utils.envs_utils import get_env_start_args +_flashinfer_top_k_top_p_sampling_from_logits = None +_flashinfer_top_k_top_p_sampling_from_logits_checked = False +_flashinfer_top_k_top_p_sampling_from_probs = None +_flashinfer_top_k_top_p_sampling_from_probs_checked = False +_flashinfer_top_p_sampling_from_probs = None +_flashinfer_top_p_sampling_from_probs_checked = False +_flashinfer_top_k_sampling_from_probs = None +_flashinfer_top_k_sampling_from_probs_checked = False +_uniform_tensor_cache = {} +_softmax_out_cache = {} +_is_flashinfer_sampling_backend = None + def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): + fast_next_token_ids = _try_flashinfer_sample_without_penalty(logits, reqs) + if fast_next_token_ids is not None: + return fast_next_token_ids.view(-1), None + ( b_req_idx, b_temperatures, @@ -23,6 +39,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): skip_top_k, skip_top_p, exist_req_use_random_seed, + need_logprobs, ) = _get_post_sample_tensors(reqs) eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) @@ -75,7 +92,18 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): cu_invalid_token_num=cu_invalid_token_num, ) - logits.div_(b_temperatures.view((-1, 1))) + if b_temperatures is not None: + logits.div_(b_temperatures.view((-1, 1))) + + if is_all_greedy and not need_logprobs: + batch_next_token_ids = torch.argmax(logits, -1) + if get_env_start_args().mtp_mode: + batch_next_token_logprobs = torch.zeros( + batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device + ) + return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) + return batch_next_token_ids.view(-1), None + probs = torch.softmax(logits, dim=-1) if is_all_greedy: @@ -86,16 +114,82 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): elif skip_top_k and skip_top_p: # topk 等于整个词表,topp 等于1.0,等价于不进行topk topp过滤,直接进行随机采样,可以提升采样速度 batch_next_token_ids = _random_sample(probs, reqs, exist_req_use_random_seed) + if not need_logprobs: + return batch_next_token_ids.view(-1), None batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) else: batch_next_token_ids, batch_next_token_logprobs = _top_p_top_k_sample( - reqs, probs, b_top_ps, b_top_ks, exist_req_use_random_seed + reqs, + probs, + b_top_ps, + b_top_ks, + skip_top_k, + skip_top_p, + exist_req_use_random_seed, + need_logprobs, ) + if batch_next_token_logprobs is None: + return batch_next_token_ids.view(-1), None return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) +def _try_flashinfer_sample_without_penalty(logits: torch.Tensor, reqs: List[InferReq]) -> Optional[torch.Tensor]: + if not _is_flashinfer_sampling() or not reqs: + return None + + first_param = reqs[0].sampling_param.shm_param + top_p = first_param.top_p + top_k = first_param.top_k + temperature = first_param.temperature + vocab_size = reqs[0].vocab_size + + if top_k <= 1 or (top_k == vocab_size and top_p == 1.0): + return None + + for req in reqs: + shm_param = req.sampling_param.shm_param + if shm_param.return_logprobs: + return None + if req.generator is not None: + return None + if len(req.sampling_param.invalid_token_ids) != 0: + return None + if not shm_param.ignore_eos: + return None + if shm_param.presence_penalty != 0.0: + return None + if shm_param.frequency_penalty != 0.0: + return None + if shm_param.repetition_penalty != 1.0: + return None + if shm_param.temperature != temperature: + return None + if shm_param.top_p != top_p: + return None + if shm_param.top_k != top_k: + return None + + if temperature != 1.0: + logits.div_(temperature) + + if top_k == vocab_size and top_p != 1.0: + top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) + return _flashinfer_top_p_sample_from_logits(logits, top_p_tensor) + + top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) + top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device) + return _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor) + + +def _is_flashinfer_sampling() -> bool: + global _is_flashinfer_sampling_backend + if _is_flashinfer_sampling_backend is None: + _is_flashinfer_sampling_backend = get_env_start_args().sampling_backend == "flashinfer" + return _is_flashinfer_sampling_backend + + def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): probs_sort, probs_idx = probs.sort(dim=-1, descending=True) @@ -107,13 +201,123 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor return probs_sort, probs_idx +def _flashinfer_top_p_top_k_sample_from_logits( + logits: torch.Tensor, + b_top_ps: Union[torch.Tensor, float], + b_top_ks: Union[torch.Tensor, int], +) -> Optional[torch.Tensor]: + global _flashinfer_top_k_top_p_sampling_from_logits + global _flashinfer_top_k_top_p_sampling_from_logits_checked + + if not _flashinfer_top_k_top_p_sampling_from_logits_checked: + try: + from flashinfer.sampling import top_k_top_p_sampling_from_logits + except ImportError: + top_k_top_p_sampling_from_logits = None + _flashinfer_top_k_top_p_sampling_from_logits = top_k_top_p_sampling_from_logits + _flashinfer_top_k_top_p_sampling_from_logits_checked = True + + if _flashinfer_top_k_top_p_sampling_from_logits is None: + return None + + return _flashinfer_top_k_top_p_sampling_from_logits( + logits, + b_top_ks, + b_top_ps, + filter_apply_order="joint", + deterministic=True, + check_nan=False, + ) + + +def _flashinfer_top_p_sample_from_logits( + logits: torch.Tensor, top_p: Union[torch.Tensor, float] +) -> Optional[torch.Tensor]: + probs = _softmax_out(logits) + return _flashinfer_top_p_sample_from_probs(probs, top_p) + + +def _get_uniform_tensor(value: Union[float, int], size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + key = (str(device), dtype, size, value) + tensor = _uniform_tensor_cache.get(key) + if tensor is None: + tensor = torch.full((size,), value, dtype=dtype, device=device) + _uniform_tensor_cache[key] = tensor + return tensor + + +def _softmax_out(logits: torch.Tensor) -> torch.Tensor: + key = (str(logits.device), logits.dtype, tuple(logits.shape)) + probs = _softmax_out_cache.get(key) + if probs is None: + probs = torch.empty_like(logits) + _softmax_out_cache[key] = probs + torch.ops.aten._softmax.out(logits, -1, False, out=probs) + return probs + + +def _get_flashinfer_top_k_top_p_sampling_from_probs(): + global _flashinfer_top_k_top_p_sampling_from_probs + global _flashinfer_top_k_top_p_sampling_from_probs_checked + + if not _flashinfer_top_k_top_p_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_k_top_p_sampling_from_probs + except ImportError: + top_k_top_p_sampling_from_probs = None + _flashinfer_top_k_top_p_sampling_from_probs = top_k_top_p_sampling_from_probs + _flashinfer_top_k_top_p_sampling_from_probs_checked = True + return _flashinfer_top_k_top_p_sampling_from_probs + + +def _flashinfer_top_p_sample_from_probs( + probs: torch.Tensor, top_p: Union[torch.Tensor, float] +) -> Optional[torch.Tensor]: + global _flashinfer_top_p_sampling_from_probs + global _flashinfer_top_p_sampling_from_probs_checked + + if not _flashinfer_top_p_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_p_sampling_from_probs + except ImportError: + top_p_sampling_from_probs = None + _flashinfer_top_p_sampling_from_probs = top_p_sampling_from_probs + _flashinfer_top_p_sampling_from_probs_checked = True + + if _flashinfer_top_p_sampling_from_probs is None: + return None + + return _flashinfer_top_p_sampling_from_probs(probs, top_p, deterministic=True, check_nan=False) + + +def _flashinfer_top_k_sample_from_probs(probs: torch.Tensor, top_k: Union[torch.Tensor, int]) -> Optional[torch.Tensor]: + global _flashinfer_top_k_sampling_from_probs + global _flashinfer_top_k_sampling_from_probs_checked + + if not _flashinfer_top_k_sampling_from_probs_checked: + try: + from flashinfer.sampling import top_k_sampling_from_probs + except ImportError: + top_k_sampling_from_probs = None + _flashinfer_top_k_sampling_from_probs = top_k_sampling_from_probs + _flashinfer_top_k_sampling_from_probs_checked = True + + if _flashinfer_top_k_sampling_from_probs is None: + return None + + return _flashinfer_top_k_sampling_from_probs(probs, top_k, deterministic=True, check_nan=False) + + def _top_p_top_k_sample( reqs: List[InferReq], probs: torch.Tensor, - b_top_ps: torch.Tensor, - b_top_ks: torch.Tensor, + b_top_ps: Union[torch.Tensor, float], + b_top_ks: Union[torch.Tensor, int], + skip_top_k: bool, + skip_top_p: bool, exist_req_use_random_seed: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: + need_logprobs: bool, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: sampling_backend = get_env_start_args().sampling_backend if sampling_backend == "triton": @@ -123,19 +327,32 @@ def _top_p_top_k_sample( else: sampled_index = _random_sample(probs_sort, reqs, exist_req_use_random_seed).view(-1, 1) next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index) + if not need_logprobs: + return next_token_ids.view(-1), None next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index)) return next_token_ids.view(-1), next_token_logprobs.view(-1) elif sampling_backend == "flashinfer": - from flashinfer.sampling import top_k_top_p_sampling_from_probs - - batch_next_token_ids = top_k_top_p_sampling_from_probs( - probs, - b_top_ks, - b_top_ps, - filter_apply_order="joint", - check_nan=False, - ) + if skip_top_k: + batch_next_token_ids = _flashinfer_top_p_sample_from_probs(probs, b_top_ps) + elif skip_top_p: + batch_next_token_ids = _flashinfer_top_k_sample_from_probs(probs, b_top_ks) + else: + top_k_top_p_sampling_from_probs = _get_flashinfer_top_k_top_p_sampling_from_probs() + if top_k_top_p_sampling_from_probs is None: + raise ImportError("flashinfer.sampling.top_k_top_p_sampling_from_probs is not available") + batch_next_token_ids = top_k_top_p_sampling_from_probs( + probs, + b_top_ks, + b_top_ps, + filter_apply_order="joint", + deterministic=True, + check_nan=False, + ) + if batch_next_token_ids is None: + raise ImportError("flashinfer sampling op is not available") + if not need_logprobs: + return batch_next_token_ids.view(-1), None int64_batch_next_token_ids = torch.empty_like(batch_next_token_ids, dtype=torch.int64) int64_batch_next_token_ids[:] = batch_next_token_ids batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1)) @@ -165,6 +382,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_k = True skip_top_p = True exist_req_use_random_seed = False + need_logprobs = False + all_temperature_one = True # invalid token ids invalid_token_ids: List[int] = [] @@ -192,6 +411,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_p = False if req_obj.generator is not None: exist_req_use_random_seed = True + if shm_param.return_logprobs: + need_logprobs = True + if shm_param.temperature != 1.0: + all_temperature_one = False req_idxes.append(req_obj.req_idx) invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) cu_invalid_token_num.append(invalid_token_num_start) @@ -200,13 +423,25 @@ def _get_post_sample_tensors(reqs: List[InferReq]): invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) - temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) - top_ps_cpu = g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32) - top_ks_cpu = g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32) length_penalty_param_cpu = g_pin_mem_manager.gen_from_list( key="length_penalty_param", data=length_penalty_param, dtype=torch.int32 ) mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool) + temperatures_cpu = ( + None + if all_temperature_one + else g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) + ) + sampling_backend = get_env_start_args().sampling_backend + need_top_k_top_p_tensors = (not is_all_greedy) and (not (skip_top_k and skip_top_p)) + need_top_ps_tensor = need_top_k_top_p_tensors and (sampling_backend != "flashinfer" or not skip_top_p) + need_top_ks_tensor = need_top_k_top_p_tensors and (sampling_backend != "flashinfer" or not skip_top_k) + top_ps_cpu = ( + g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32) if need_top_ps_tensor else None + ) + top_ks_cpu = ( + g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32) if need_top_ks_tensor else None + ) if has_invalid_token_ids: invalid_token_ids_cpu = g_pin_mem_manager.gen_from_list( @@ -218,9 +453,9 @@ def _get_post_sample_tensors(reqs: List[InferReq]): return ( req_idxes_cpu.cuda(non_blocking=True), - temperatures_cpu.cuda(non_blocking=True), - top_ps_cpu.cuda(non_blocking=True), - top_ks_cpu.cuda(non_blocking=True), + temperatures_cpu.cuda(non_blocking=True) if temperatures_cpu is not None else None, + top_ps_cpu.cuda(non_blocking=True) if top_ps_cpu is not None else None, + top_ks_cpu.cuda(non_blocking=True) if top_ks_cpu is not None else None, length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, @@ -230,4 +465,5 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_k, skip_top_p, exist_req_use_random_seed, + need_logprobs, ) diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py b/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py new file mode 100644 index 0000000000..1b4ade1ac0 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/mtp_model_factory.py @@ -0,0 +1,33 @@ +from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel +from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel + + +def create_mtp_draft_model(model_type: str, mtp_mode: str, mtp_model_kvargs: dict): + """Single source of truth for (model_type, mtp_mode) -> MTP draft model (#10). + Shared by base_backend and the static MTP benchmark.""" + if model_type == "deepseek_v3": + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + return Deepseek3MTPModel(mtp_model_kvargs) + elif model_type == "qwen3_moe": + assert mtp_mode in ["vanilla_no_att", "eagle_no_att"] + return Qwen3MOEMTPModel(mtp_model_kvargs) + elif model_type == "mistral": + assert mtp_mode in ["vanilla_no_att", "eagle_no_att"] + return MistralMTPModel(mtp_model_kvargs) + elif model_type == "glm4_moe_lite": + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + return Glm4MoeLiteMTPModel(mtp_model_kvargs) + elif model_type in ("qwen3_5", "qwen3_5_text"): + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel + + return Qwen3_5MTPModel(mtp_model_kvargs) + elif model_type in ("qwen3_5_moe", "qwen3_5_moe_text"): + assert mtp_mode in ["vanilla_with_att", "eagle_with_att"] + from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel + + return Qwen3_5MoeMTPModel(mtp_model_kvargs) + else: + raise ValueError(f"Unsupported MTP model type: {model_type}") diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 494908cb10..ff5ad0127b 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -120,8 +120,8 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": if args.mtp_mode is not None: # TODO 可能会存在不同mtp模式的精度问题 - assert is_linear_att_mixed_model(args.model_dir) is False, "linear att mixed model does not support mtp mode" - cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() + if not is_linear_att_mixed_model(args.model_dir): + cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num() cpu_cache_page_num = int( (args.cpu_cache_storage_size * 1024 * 1024 * 1024) / (cpu_cache_meta.calcu_one_page_size()) diff --git a/lightllm/utils/sgl_utils.py b/lightllm/utils/sgl_utils.py index b48a62506d..b79a554f48 100644 --- a/lightllm/utils/sgl_utils.py +++ b/lightllm/utils/sgl_utils.py @@ -17,14 +17,16 @@ ) try: - from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, get_scheduler_metadata flash_attn_varlen_func = flash_attn_varlen_func flash_attn_with_kvcache = flash_attn_with_kvcache + get_scheduler_metadata = get_scheduler_metadata merge_state_v2 = sgl_ops.merge_state_v2 except: flash_attn_varlen_func = None flash_attn_with_kvcache = None + get_scheduler_metadata = None merge_state_v2 = None logger.warning( "sgl_kernel is not installed, or the installed version did not support fa3. \ diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index f2c900af09..c640a1152c 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -36,7 +36,7 @@ def test_model_inference(args): "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, "mem_fraction": args.mem_fraction, - "max_req_num": 2048, + "max_req_num": args.static_max_req_num, "batch_max_tokens": 1024, "run_mode": "normal", "max_seq_length": args.max_req_total_len, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 72f06a919c..7689fc97df 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -69,7 +69,7 @@ def test_model_inference_mtp(args): "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, "mem_faction": args.mem_fraction, - "max_req_num": 2000, + "max_req_num": args.static_max_req_num, "batch_max_tokens": 2048, "run_mode": "normal", "max_seq_length": args.max_req_total_len, diff --git a/test/benchmark/static_inference/test_model.py b/test/benchmark/static_inference/test_model.py index 5b3751bcc3..78b841bbeb 100644 --- a/test/benchmark/static_inference/test_model.py +++ b/test/benchmark/static_inference/test_model.py @@ -30,6 +30,12 @@ def test_model_infer(self): parser.add_argument("--batch_size", type=int, default=None, help="batch size") parser.add_argument("--input_len", type=int, default=64, help="input sequence length") parser.add_argument("--output_len", type=int, default=128, help="output sequence length") + parser.add_argument( + "--static_max_req_num", + type=int, + default=2048, + help="max_req_num used by the standalone static benchmark harness", + ) parser.add_argument( "--profile", action="store_true", diff --git a/test/common/test_req_manager.py b/test/common/test_req_manager.py new file mode 100644 index 0000000000..7ab9062d7c --- /dev/null +++ b/test/common/test_req_manager.py @@ -0,0 +1,20 @@ +import ast +from pathlib import Path + + +def test_linear_att_state_buffer_log_reports_shape_and_memory(): + source = Path("lightllm/common/req_manager.py").read_text() + module = ast.parse(source) + + class_node = next( + node for node in module.body if isinstance(node, ast.ClassDef) and node.name == "ReqManagerForMamba" + ) + init_node = next(node for node in class_node.body if isinstance(node, ast.FunctionDef) and node.name == "__init__") + init_source = ast.unparse(init_node) + + assert "logger.info" in init_source + assert "conv_state shape=" in init_source + assert "ssm_state shape=" in init_source + assert "total memory=" in init_source + assert "_format_nbytes(conv_nbytes)" in init_source + assert "_format_nbytes(ssm_nbytes)" in init_source diff --git a/test/router/test_model_kvargs.py b/test/router/test_model_kvargs.py new file mode 100644 index 0000000000..6f8f9d2e36 --- /dev/null +++ b/test/router/test_model_kvargs.py @@ -0,0 +1,17 @@ +import ast +from pathlib import Path + + +def test_model_kvargs_uses_running_max_req_size_without_extra_padding(): + source = Path("lightllm/server/router/manager.py").read_text() + module = ast.parse(source) + + for node in ast.walk(module): + if not isinstance(node, ast.Dict): + continue + for key, value in zip(node.keys, node.values): + if isinstance(key, ast.Constant) and key.value == "max_req_num": + assert ast.unparse(value) == "self.args.running_max_req_size" + return + + raise AssertionError("max_req_num kvarg was not found")