|
1 | 1 | import torch |
| 2 | +import torch.distributed as dist |
2 | 3 | from lightllm.utils.log_utils import init_logger |
3 | 4 | from lightllm.common.kv_cache_mem_manager.kv_buffer.hybrid_kv_buffer import HybridKvBuffer |
4 | 5 | from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager |
| 6 | +from lightllm.utils.envs_utils import get_env_start_args |
| 7 | +from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory |
5 | 8 |
|
6 | 9 | logger = init_logger(__name__) |
7 | 10 |
|
@@ -45,6 +48,52 @@ def __init__( |
45 | 48 |
|
46 | 49 | super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) |
47 | 50 |
|
| 51 | + def profile_size(self, mem_fraction): |
| 52 | + if self.size is not None: |
| 53 | + return |
| 54 | + |
| 55 | + world_size = dist.get_world_size() |
| 56 | + total_memory = get_total_gpu_memory() |
| 57 | + available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) |
| 58 | + |
| 59 | + conv_dim = ( |
| 60 | + self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads |
| 61 | + ) |
| 62 | + mamba_cell_size = ( |
| 63 | + self.linear_attn_layer_num |
| 64 | + * conv_dim |
| 65 | + * (self.conv_kernel_size - 1) |
| 66 | + * torch._utils._element_size(self.conv_state_dtype) |
| 67 | + ) + ( |
| 68 | + self.linear_attn_layer_num |
| 69 | + * self.num_linear_v_heads |
| 70 | + * self.head_linear_k_dim |
| 71 | + * self.head_linear_v_dim |
| 72 | + * torch._utils._element_size(self.ssm_state_dtype) |
| 73 | + ) |
| 74 | + |
| 75 | + if self.linear_attn_cache_size is None: |
| 76 | + start_args = get_env_start_args() |
| 77 | + mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5 |
| 78 | + self.linear_attn_cache_size = int(available_memory * mamba_cache_ratio * 1024 ** 3 / mamba_cell_size) |
| 79 | + reserved_mamba_memory = self.linear_attn_cache_size * mamba_cell_size / (1024 ** 3) |
| 80 | + available_memory -= reserved_mamba_memory |
| 81 | + |
| 82 | + cell_size = self.get_cell_size() |
| 83 | + self.size = int(available_memory * 1024 ** 3 / cell_size) |
| 84 | + if world_size > 1: |
| 85 | + tensor = torch.tensor(self.size, dtype=torch.int64, device="cuda") |
| 86 | + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) |
| 87 | + self.size = tensor.item() |
| 88 | + |
| 89 | + logger.info( |
| 90 | + f"{available_memory} GB space is available for full attention kv cache after reserving " |
| 91 | + f"{reserved_mamba_memory} GB for mamba cache\n" |
| 92 | + f"{cell_size / 1024 ** 2} MB is the size of one token kv cache\n" |
| 93 | + f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n" |
| 94 | + ) |
| 95 | + return |
| 96 | + |
48 | 97 | def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): |
49 | 98 | # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ..., |
50 | 99 | # None, kv_cache, mtp_kv_cache, mtp_kv_cache] |
|
0 commit comments