Skip to content

Commit 3bdb03a

Browse files
author
niushengxiao
committed
fix
1 parent 1a991b4 commit 3bdb03a

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

lightllm/common/mamba_cache_mem_manager/cache_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def profile_size(
199199
):
200200
start_args = get_env_start_args()
201201
if self.size is not None and not start_args.disable_dynamic_prompt_cache:
202-
assert self.size < start_args.running_max_req_size * 2, (
202+
assert self.size >= start_args.running_max_req_size * 2, (
203203
f"error mamba_cache_size({self.size}), ",
204204
f"mamba_cache_size should be at least running_max_req_size * 2",
205205
f"({start_args.running_max_req_size * 2}), ",

lightllm/models/qwen3next/mem_manager.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import torch
2+
import torch.distributed as dist
23
from lightllm.utils.log_utils import init_logger
34
from lightllm.common.kv_cache_mem_manager.kv_buffer.hybrid_kv_buffer import HybridKvBuffer
45
from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
6+
from lightllm.utils.envs_utils import get_env_start_args
7+
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
58

69
logger = init_logger(__name__)
710

@@ -45,6 +48,52 @@ def __init__(
4548

4649
super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction)
4750

51+
def profile_size(self, mem_fraction):
    """Profile free GPU memory and size both caches.

    Splits the memory budget left by ``mem_fraction`` between the mamba
    (linear-attention) state cache and the full-attention kv cache,
    filling in ``self.linear_attn_cache_size`` and ``self.size``.
    No-op when ``self.size`` was already set explicitly.
    """
    if self.size is not None:
        return

    GiB = 1024 ** 3
    nranks = dist.get_world_size()
    # GB this process may still use after honoring the mem_fraction cap.
    free_gb = get_available_gpu_memory(nranks) - get_total_gpu_memory() * (1 - mem_fraction)

    # Per-slot byte cost of one mamba cache cell: conv state + ssm state.
    conv_channels = (
        2 * self.num_linear_k_heads * self.head_linear_k_dim
        + self.num_linear_v_heads * self.head_linear_v_dim
    )
    conv_bytes = (
        self.linear_attn_layer_num
        * conv_channels
        * (self.conv_kernel_size - 1)
        * torch._utils._element_size(self.conv_state_dtype)
    )
    ssm_bytes = (
        self.linear_attn_layer_num
        * self.num_linear_v_heads
        * self.head_linear_k_dim
        * self.head_linear_v_dim
        * torch._utils._element_size(self.ssm_state_dtype)
    )
    mamba_cell_size = conv_bytes + ssm_bytes

    if self.linear_attn_cache_size is None:
        start_args = get_env_start_args()
        # Default: reserve half of the free memory for the mamba cache.
        ratio = 0.5 if start_args.mamba_cache_ratio is None else start_args.mamba_cache_ratio
        self.linear_attn_cache_size = int(free_gb * ratio * GiB / mamba_cell_size)
    reserved_mamba_memory = self.linear_attn_cache_size * mamba_cell_size / GiB
    free_gb -= reserved_mamba_memory

    cell_size = self.get_cell_size()
    self.size = int(free_gb * GiB / cell_size)
    if nranks > 1:
        # Ranks may observe different free memory; agree on the minimum.
        size_t = torch.tensor(self.size, dtype=torch.int64, device="cuda")
        dist.all_reduce(size_t, op=dist.ReduceOp.MIN)
        self.size = size_t.item()

    logger.info(
        f"{free_gb} GB space is available for full attention kv cache after reserving "
        f"{reserved_mamba_memory} GB for mamba cache\n"
        f"{cell_size / 1024 ** 2} MB is the size of one token kv cache\n"
        f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n"
    )
    return
96+
4897
def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
4998
# KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ...,
5099
# None, kv_cache, mtp_kv_cache, mtp_kv_cache]

0 commit comments

Comments
 (0)