Skip to content

Commit f775043

Browse files
author
niushengxiao
committed
fix: fix bugs
1 parent 53748ef commit f775043

5 files changed

Lines changed: 78 additions & 6 deletions

File tree

lightllm/common/basemodel/basemodel.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph
2424
from lightllm.common.quantization import Quantcfg
2525
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed
26+
from lightllm.utils.config_utils import _derive_max_req_total_len_from_model_config
2627
from lightllm.utils.log_utils import init_logger
2728
from lightllm.utils.dist_utils import get_dp_world_size
2829
from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
@@ -105,8 +106,8 @@ def __init__(self, kvargs):
105106
self._init_quant()
106107

107108
self._init_weights()
108-
self._init_req_manager()
109109
self._init_mem_manager()
110+
self._init_req_manager()
110111
# 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state
111112
# 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值
112113
self.req_manager.mem_manager = self.mem_manager
@@ -210,6 +211,26 @@ def _init_kv_move_buffer(self):
210211
if self.run_mode in ["prefill", "decode"]:
211212
self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size)
212213

214+
# 推导出的max_req_total_len如果显存预算支持不了,需要进一步截断到可支持的长度
215+
def _safe_clamp_auto_max_req_total_len(self):
216+
max_total_token_num = self.mem_manager.size
217+
if self.max_seq_length is None or self.max_seq_length <= max_total_token_num:
218+
return
219+
220+
# 只截断推导生成的max_req_total_len
221+
old_max_req_total_len = self.max_seq_length - 8
222+
derived_max_req_total_len = _derive_max_req_total_len_from_model_config(self.weight_dir_)
223+
if derived_max_req_total_len is None or old_max_req_total_len != derived_max_req_total_len:
224+
return
225+
226+
supported_max_req_total_len = max(max_total_token_num - 8, 1)
227+
self.args.max_req_total_len = supported_max_req_total_len
228+
self.max_seq_length = supported_max_req_total_len + 8
229+
230+
if self.graph_max_len_in_batch == old_max_req_total_len:
231+
self.args.graph_max_len_in_batch = min(self.args.graph_max_len_in_batch, supported_max_req_total_len)
232+
self.graph_max_len_in_batch = self.args.graph_max_len_in_batch
233+
213234
def _check_mem_size(self):
214235
self.max_total_token_num = self.mem_manager.size
215236

@@ -232,6 +253,7 @@ def _check_mem_size(self):
232253
return
233254

234255
def _init_req_manager(self):
256+
self._safe_clamp_auto_max_req_total_len()
235257
create_max_seq_len = 0
236258

237259
if self.batch_max_tokens is not None:

lightllm/common/kv_cache_mem_manager/mem_manager.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
1515
from lightllm.distributed.pynccl import PyNcclCommunicator
1616
from lightllm.utils.dist_utils import get_current_device_id
17-
from lightllm.utils.config_utils import get_num_key_value_heads
17+
from lightllm.utils.config_utils import get_num_key_value_heads, get_vocab_size
1818
from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io
1919
from lightllm.utils.device_utils import kv_trans_use_p2p
2020
from lightllm.utils.shm_utils import create_or_link_shm
@@ -61,22 +61,37 @@ def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]:
6161
def get_cell_size(self):
6262
return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
6363

64+
def get_req_manager_reserve_bytes(self):
65+
args = get_env_start_args()
66+
max_request_num = args.running_max_req_size + 8
67+
max_sequence_length = max(args.batch_max_tokens or 0, (args.max_req_total_len or 0) + 8)
68+
req_state_num = max_request_num + 1
69+
70+
reserve_bytes = req_state_num * max_sequence_length * torch._utils._element_size(torch.int32)
71+
reserve_bytes += req_state_num * 4 * torch._utils._element_size(torch.float32)
72+
reserve_bytes += req_state_num * 8 * torch._utils._element_size(torch.int64)
73+
if args.penalty_counter_mode == "gpu_counter":
74+
reserve_bytes += req_state_num * get_vocab_size(args.model_dir) * torch._utils._element_size(torch.int32)
75+
return reserve_bytes
76+
6477
def profile_size(self, mem_fraction):
6578
if self.size is not None:
6679
return
6780

6881
torch.cuda.empty_cache()
6982
world_size = dist.get_world_size()
70-
71-
available_memory = get_available_gpu_memory(world_size) * mem_fraction
83+
available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction)
84+
req_manager_reserve_gb = self.get_req_manager_reserve_bytes() / (1024 ** 3)
85+
available_memory -= req_manager_reserve_gb
7286
cell_size = self.get_cell_size()
7387
self.size = int(available_memory * 1024 ** 3 / cell_size)
7488
if world_size > 1:
7589
tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}")
7690
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
7791
self.size = tensor.item()
7892
logger.info(
79-
f"{str(available_memory)} GB space is available after load the model weight\n"
93+
f"{str(available_memory)} GB space is available after load the model weight "
94+
f"and reserve {req_manager_reserve_gb} GB for req_manager\n"
8095
f"{str(cell_size / 1024 ** 2)} MB is the size of one token kv cache\n"
8196
f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n"
8297
)

lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import torch
2+
import torch.distributed as dist
23
import triton
34
from lightllm.utils.log_utils import init_logger
45
from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
56
from lightllm.utils.envs_utils import get_env_start_args
7+
from lightllm.utils.dist_utils import get_current_device_id
8+
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
69
from lightllm.common.linear_att_cache_manager import LinearAttCacheConfig, LinearAttCacheManager
710
from .operator import LinearAttMemOperator
811
from typing import Tuple, Any
@@ -32,6 +35,38 @@ def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]:
3235
layer_index = layer_index // self.linear_config.full_attention_interval
3336
return super().get_att_input_params(layer_index)
3437

38+
def profile_size(self, mem_fraction):
39+
if self.size is not None:
40+
return
41+
42+
torch.cuda.empty_cache()
43+
args = get_env_start_args()
44+
reserve_bytes = self.get_req_manager_reserve_bytes()
45+
req_state_num = (args.running_max_req_size + 8 + 1) * (args.mtp_step + 1)
46+
reserve_bytes += (
47+
req_state_num
48+
* self.linear_config.linear_layer_num
49+
* (self.linear_config.get_conv_state_bytes_per_layer() + self.linear_config.get_ssm_state_bytes_per_layer())
50+
)
51+
reserve_gb = reserve_bytes / (1024 ** 3)
52+
53+
world_size = dist.get_world_size()
54+
available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction)
55+
available_memory -= reserve_gb
56+
cell_size = self.get_cell_size()
57+
self.size = max(int(available_memory * 1024 ** 3 / cell_size), 1)
58+
if world_size > 1:
59+
tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}")
60+
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
61+
self.size = tensor.item()
62+
logger.info(
63+
f"{str(available_memory)} GB space is available after load the model weight "
64+
f"and reserve {reserve_gb} GB for qwen3next req_manager\n"
65+
f"{str(cell_size / 1024 ** 2)} MB is the size of one token kv cache\n"
66+
f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n"
67+
)
68+
return
69+
3570
def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
3671
super()._init_buffers(size, dtype, head_num, head_dim, layer_num)
3772
# TODO 初始化线性 att 对应的部分 buffer.

lightllm/models/qwen3next/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def _init_mem_manager(self):
8989
)
9090

9191
def _init_req_manager(self):
92+
self._safe_clamp_auto_max_req_total_len()
9293
create_max_seq_len = 0
9394

9495
if self.batch_max_tokens is not None:

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,3 @@ nixl==1.1.0
9898
xformers==0.0.35
9999
redis==7.3.0
100100
litellm>=1.52.0,<1.85
101-
flash-attn-4[13]==4.0.0b14

0 commit comments

Comments
 (0)