|
1 | 1 | import torch |
| 2 | +import torch.distributed as dist |
2 | 3 | import triton |
3 | 4 | from lightllm.utils.log_utils import init_logger |
4 | 5 | from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager |
5 | 6 | from lightllm.utils.envs_utils import get_env_start_args |
| 7 | +from lightllm.utils.dist_utils import get_current_device_id |
| 8 | +from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory |
6 | 9 | from lightllm.common.linear_att_cache_manager import LinearAttCacheConfig, LinearAttCacheManager |
7 | 10 | from .operator import LinearAttMemOperator |
8 | 11 | from typing import Tuple, Any |
@@ -32,6 +35,38 @@ def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: |
32 | 35 | layer_index = layer_index // self.linear_config.full_attention_interval |
33 | 36 | return super().get_att_input_params(layer_index) |
34 | 37 |
|
| 38 | + def profile_size(self, mem_fraction): |
| 39 | + if self.size is not None: |
| 40 | + return |
| 41 | + |
| 42 | + torch.cuda.empty_cache() |
| 43 | + args = get_env_start_args() |
| 44 | + reserve_bytes = self.get_req_manager_reserve_bytes() |
| 45 | + req_state_num = (args.running_max_req_size + 8 + 1) * (args.mtp_step + 1) |
| 46 | + reserve_bytes += ( |
| 47 | + req_state_num |
| 48 | + * self.linear_config.linear_layer_num |
| 49 | + * (self.linear_config.get_conv_state_bytes_per_layer() + self.linear_config.get_ssm_state_bytes_per_layer()) |
| 50 | + ) |
| 51 | + reserve_gb = reserve_bytes / (1024 ** 3) |
| 52 | + |
| 53 | + world_size = dist.get_world_size() |
| 54 | + available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction) |
| 55 | + available_memory -= reserve_gb |
| 56 | + cell_size = self.get_cell_size() |
| 57 | + self.size = max(int(available_memory * 1024 ** 3 / cell_size), 1) |
| 58 | + if world_size > 1: |
| 59 | + tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") |
| 60 | + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) |
| 61 | + self.size = tensor.item() |
| 62 | + logger.info( |
| 63 | + f"{str(available_memory)} GB space is available after load the model weight " |
| 64 | + f"and reserve {reserve_gb} GB for qwen3next req_manager\n" |
| 65 | + f"{str(cell_size / 1024 ** 2)} MB is the size of one token kv cache\n" |
| 66 | + f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n" |
| 67 | + ) |
| 68 | + return |
| 69 | + |
35 | 70 | def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): |
36 | 71 | super()._init_buffers(size, dtype, head_num, head_dim, layer_num) |
37 | 72 | # TODO 初始化线性 att 对应的部分 buffer. |
|
0 commit comments