Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _copy_kv_buffer_to_cpu_cache(
head_scale_size,
BLOCK: tl.constexpr,
):
block_index_start = tl.program_id(0)
split_index_start = tl.program_id(0)
grid_num = tl.num_programs(0)
# 将 所有stride 切成 tl.int64
cpu_cache_full_att_stride_p = tl.cast(cpu_cache_full_att_stride_p, tl.int64)
Expand All @@ -62,7 +62,7 @@ def _copy_kv_buffer_to_cpu_cache(
cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, tl.int64)
cpu_kv_ssm_stride_d = tl.cast(cpu_kv_ssm_stride_d, tl.int64)

for block_index in range(block_index_start, page_num, grid_num):
for block_index in range(page_num):
cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
run_flag = 1
if cpu_page_index == -1:
Expand All @@ -76,7 +76,7 @@ def _copy_kv_buffer_to_cpu_cache(
head_flag = 0

mem_start_ptr = mem_indexes_ptr + big_page_token_num * block_index
for i in range(tl.cdiv(gpu_full_att_tail_dim, BLOCK) * run_flag * head_flag):
for i in range(split_index_start, tl.cdiv(gpu_full_att_tail_dim, BLOCK) * run_flag * head_flag, grid_num):
gpu_start_i = i * BLOCK + tl.arange(0, BLOCK)
mask = gpu_start_i < gpu_full_att_tail_dim
per_token_size = gpu_full_att_tail_dim // big_page_token_num
Expand All @@ -103,7 +103,7 @@ def _copy_kv_buffer_to_cpu_cache(

big_page_idx = tl.load(big_page_buffer_ids + block_index)

for i in range(tl.cdiv(cpu_kv_conv_tail_dim, BLOCK) * run_flag):
for i in range(split_index_start, tl.cdiv(cpu_kv_conv_tail_dim, BLOCK) * run_flag, grid_num):
gpu_start_i = i * BLOCK + tl.arange(0, BLOCK)
mask = gpu_start_i < cpu_kv_conv_tail_dim
cpu_kv_conv_data = tl.load(
Expand All @@ -119,7 +119,7 @@ def _copy_kv_buffer_to_cpu_cache(
)
tl.store(dest_cpu_cache_conv_ptr, cpu_kv_conv_data, mask=mask)

for i in range(tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK) * run_flag):
for i in range(split_index_start, tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK) * run_flag, grid_num):
gpu_start_i = i * BLOCK + tl.arange(0, BLOCK)
mask = gpu_start_i < cpu_kv_ssm_tail_dim

Expand Down Expand Up @@ -149,7 +149,7 @@ def copy_kv_buffer_to_cpu_cache(
tp_world_size: int,
big_page_token_num: int,
linear_config: LinearAttCacheConfig,
grid_num: int = 16,
grid_num: int = 12,
):
assert len(page_indexes) == len(page_readies) == len(big_page_buffer_ids)
assert len(mem_indexes) % len(page_indexes) == 0
Expand All @@ -172,15 +172,25 @@ def copy_kv_buffer_to_cpu_cache(
else:
cpu_cache_full_att = cpu_cache_tensor[:, 0:a].view(cpu_page_num, linear_config.full_att_all_num_kv_heads, -1)

cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1)
cpu_cache_ssm = cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1)
cpu_cache_full_att = cpu_cache_full_att.view(dtype=torch.uint64)
# 保证可以以128bit对齐的方式进行数据的load 和 store。
assert cpu_cache_full_att.shape[-1] % 2 == 0

cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64)
cpu_cache_ssm = (
cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64)
)

gpu_kv_full_att_state = gpu_kv_full_att_state.view(
gpu_kv_full_att_state.shape[0], gpu_kv_full_att_state.shape[1], -1
).view(dtype=torch.uint8)
).view(dtype=torch.uint64)

gpu_kv_full_att_state = gpu_kv_full_att_state.permute(1, 0, 2) # [s, layer_num, xxdim]
cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint8)
cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint8)
# 保证可以以128bit对齐的方式进行数据的load 和 store。
assert gpu_kv_full_att_state.shape[-1] % 2 == 0

cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint64)
cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint64)

gpu_full_att_tail_dim = gpu_kv_full_att_state.shape[-1] * gpu_kv_full_att_state.shape[-2] * big_page_token_num
cpu_kv_conv_tail_dim = cpu_kv_conv_state.shape[-1]
Expand All @@ -195,6 +205,7 @@ def copy_kv_buffer_to_cpu_cache(
assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1]
assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1]
assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1]
assert gpu_kv_full_att_state.stride(2) == 1
assert (
gpu_full_att_tail_dim % big_page_token_num == 0
and (gpu_full_att_tail_dim // big_page_token_num) % full_att_layer_num == 0
Expand Down Expand Up @@ -278,7 +289,7 @@ def _copy_cpu_cache_to_kv_buffer(
head_scale_size,
BLOCK: tl.constexpr,
):
block_index_start = tl.program_id(0)
split_index_start = tl.program_id(0)
grid_num = tl.num_programs(0)
# 将 所有stride 切成 tl.int64
cpu_cache_full_att_stride_p = tl.cast(cpu_cache_full_att_stride_p, tl.int64)
Expand All @@ -298,11 +309,11 @@ def _copy_cpu_cache_to_kv_buffer(
cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, tl.int64)
cpu_kv_ssm_stride_d = tl.cast(cpu_kv_ssm_stride_d, tl.int64)

for block_index in range(block_index_start, page_num, grid_num):
for block_index in range(page_num):
cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)

mem_start_ptr = mem_indexes_ptr + big_page_token_num * block_index
for i in range(tl.cdiv(gpu_full_att_tail_dim, BLOCK)):
for i in range(split_index_start, tl.cdiv(gpu_full_att_tail_dim, BLOCK), grid_num):
gpu_start_i = i * BLOCK + tl.arange(0, BLOCK)
mask = gpu_start_i < gpu_full_att_tail_dim
per_token_size = gpu_full_att_tail_dim // big_page_token_num
Expand All @@ -318,20 +329,26 @@ def _copy_cpu_cache_to_kv_buffer(
+ (tp_rank // head_scale_size) * cpu_cache_full_att_stride_h
+ gpu_start_i
)
cpu_full_att_data = tl.load(src_cpu_cache_full_att_ptr, mask=mask & (mem_index != -1), other=0)
# 标记主要是为了让编译器可以以128bit的方式生成指令进行拉取
mem_mask = mem_index != -1
mem_mask = tl.max_constancy(mem_mask, [2])
dim_index = tl.max_contiguous(dim_index, [2])
mem_index = tl.max_constancy(mem_index, [2])

cpu_full_att_data = tl.load(src_cpu_cache_full_att_ptr, mask=mask & mem_mask, other=0)

tl.store(
gpu_kv_full_att_state
+ mem_index * gpu_kv_full_att_stride_s
+ layer_index * gpu_kv_full_att_stride_l
+ dim_index * gpu_kv_full_att_stride_d,
+ dim_index,
cpu_full_att_data,
mask=mask & (mem_index != -1),
mask=mask & mem_mask,
)

big_page_idx = tl.load(big_page_buffer_ids + block_index)

for i in range(tl.cdiv(cpu_kv_conv_tail_dim, BLOCK)):
for i in range(split_index_start, tl.cdiv(cpu_kv_conv_tail_dim, BLOCK), grid_num):
gpu_start_i = i * BLOCK + tl.arange(0, BLOCK)
mask = gpu_start_i < cpu_kv_conv_tail_dim

Expand All @@ -349,7 +366,7 @@ def _copy_cpu_cache_to_kv_buffer(
mask=mask,
)

for i in range(tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK)):
for i in range(split_index_start, tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK), grid_num):
gpu_start_i = i * BLOCK + tl.arange(0, BLOCK)
mask = gpu_start_i < cpu_kv_ssm_tail_dim

Expand Down Expand Up @@ -379,8 +396,9 @@ def copy_cpu_cache_to_kv_buffer(
tp_world_size: int,
big_page_token_num: int,
linear_config: LinearAttCacheConfig,
grid_num: int = 16,
grid_num: int = 12,
):

assert len(mem_indexes) % len(page_indexes) == 0

BLOCK = 4096
Expand All @@ -400,15 +418,25 @@ def copy_cpu_cache_to_kv_buffer(
else:
cpu_cache_full_att = cpu_cache_tensor[:, 0:a].view(cpu_page_num, linear_config.full_att_all_num_kv_heads, -1)

cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1)
cpu_cache_ssm = cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1)
cpu_cache_full_att = cpu_cache_full_att.view(dtype=torch.uint64)
# 保证可以以128bit对齐的方式进行数据的load 和 store。
assert cpu_cache_full_att.shape[-1] % 2 == 0

cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64)
cpu_cache_ssm = (
cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64)
)

gpu_full_att_kv_state = gpu_full_att_kv_state.view(
gpu_full_att_kv_state.shape[0], gpu_full_att_kv_state.shape[1], -1
).view(dtype=torch.uint8)
).view(dtype=torch.uint64)
gpu_full_att_kv_state = gpu_full_att_kv_state.permute(1, 0, 2) # [s, layer_num, xxdim]
cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint8)
cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint8)

# 保证可以以128bit对齐的方式进行数据的load 和 store。
assert gpu_full_att_kv_state.shape[-1] % 2 == 0

cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint64)
cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint64)

gpu_full_att_tail_dim = gpu_full_att_kv_state.shape[-1] * gpu_full_att_kv_state.shape[-2] * big_page_token_num
cpu_kv_conv_tail_dim = cpu_kv_conv_state.shape[-1]
Expand All @@ -418,6 +446,7 @@ def copy_cpu_cache_to_kv_buffer(
assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1]
assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1]
assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1]
assert gpu_full_att_kv_state.stride(2) == 1

assert (tp_rank // head_scale_size) < linear_config.full_att_all_num_kv_heads

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):

if self.need_sync_compute_stream():
# TODO fa3 现在必须使用同步模式, 未来需要移除
g_infer_context.get_overlap_stream().synchronize()
torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
# g_infer_context.get_overlap_stream().synchronize()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The commented-out code should be removed to maintain a clean codebase.


mem_manager = self.backend.model.mem_manager
req_manager = self.backend.model.req_manager
Expand Down
Loading
Loading