|
| 1 | +from types import SimpleNamespace |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | + |
| 6 | +class _Buf: |
| 7 | + def __init__(self, tensor): |
| 8 | + self.buffer = tensor |
| 9 | + |
| 10 | + |
| 11 | +def _make_config(): |
| 12 | + return SimpleNamespace( |
| 13 | + tp_world_size=1, |
| 14 | + linear_layer_num=1, |
| 15 | + conv_kernel_size=4, |
| 16 | + global_linear_k_heads=1, |
| 17 | + global_linear_v_heads=1, |
| 18 | + num_linear_k_heads=1, |
| 19 | + num_linear_v_heads=1, |
| 20 | + head_linear_k_dim=2, |
| 21 | + head_linear_v_dim=3, |
| 22 | + ) |
| 23 | + |
| 24 | + |
| 25 | +def _make_mem(mtp_step=2, req_slots=4): |
| 26 | + config = _make_config() |
| 27 | + conv_dim = ( |
| 28 | + 2 * config.num_linear_k_heads * config.head_linear_k_dim |
| 29 | + + config.num_linear_v_heads * config.head_linear_v_dim |
| 30 | + ) |
| 31 | + narrow_w = config.conv_kernel_size - 1 |
| 32 | + conv = torch.full( |
| 33 | + (config.linear_layer_num, req_slots, conv_dim, narrow_w + mtp_step), |
| 34 | + -9.0, |
| 35 | + dtype=torch.float32, |
| 36 | + ) |
| 37 | + ssm = torch.full( |
| 38 | + ( |
| 39 | + config.linear_layer_num, |
| 40 | + req_slots * (mtp_step + 1), |
| 41 | + config.num_linear_v_heads, |
| 42 | + config.head_linear_k_dim, |
| 43 | + config.head_linear_v_dim, |
| 44 | + ), |
| 45 | + -11.0, |
| 46 | + dtype=torch.float32, |
| 47 | + ) |
| 48 | + return SimpleNamespace( |
| 49 | + linear_config=config, |
| 50 | + req_to_conv_state=_Buf(conv), |
| 51 | + req_to_ssm_state=_Buf(ssm), |
| 52 | + kv_move_buffer=torch.zeros((1, 4096), dtype=torch.uint8), |
| 53 | + ) |
| 54 | + |
| 55 | + |
| 56 | +def test_page_helper_writes_req_conv_slot_and_narrow_width(monkeypatch): |
| 57 | + import lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager as qwen3next_mem_manager |
| 58 | + from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextLinearAttPageHelper |
| 59 | + |
| 60 | + mtp_step = 2 |
| 61 | + req_idx = 2 |
| 62 | + monkeypatch.setattr(qwen3next_mem_manager, "get_env_start_args", lambda: SimpleNamespace(mtp_step=mtp_step)) |
| 63 | + |
| 64 | + mem = _make_mem(mtp_step=mtp_step) |
| 65 | + helper = Qwen3NextLinearAttPageHelper(mem) |
| 66 | + mem.kv_move_buffer = torch.zeros((1, helper.state_nbytes), dtype=torch.uint8) |
| 67 | + |
| 68 | + narrow_w = helper.conv_shape[-1] |
| 69 | + marker_conv = torch.arange( |
| 70 | + helper.conv_shape[0] * helper.conv_shape[1] * narrow_w, |
| 71 | + dtype=torch.float32, |
| 72 | + ).view(helper.conv_shape) |
| 73 | + marker_ssm = torch.arange( |
| 74 | + helper.ssm_shape[0] * helper.ssm_shape[1] * helper.ssm_shape[2] * helper.ssm_shape[3], |
| 75 | + dtype=torch.float32, |
| 76 | + ).view(helper.ssm_shape) |
| 77 | + |
| 78 | + mem.req_to_conv_state.buffer[:, req_idx, :, :narrow_w] = marker_conv |
| 79 | + mem.req_to_conv_state.buffer[:, req_idx, :, narrow_w:] = 999.0 |
| 80 | + mem.req_to_ssm_state.buffer[:, req_idx * (mtp_step + 1), ...] = marker_ssm |
| 81 | + |
| 82 | + helper.write_req_to_page(page_index=0, req_idx=req_idx, dp_mems=[mem]) |
| 83 | + |
| 84 | + conv_page, ssm_page = helper.view_page_to_linear_att_state(page_index=0) |
| 85 | + torch.testing.assert_close(conv_page, marker_conv) |
| 86 | + torch.testing.assert_close(ssm_page, marker_ssm) |
| 87 | + |
| 88 | + |
| 89 | +def test_page_helper_restores_narrow_conv_to_req_slot(monkeypatch): |
| 90 | + import lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager as qwen3next_mem_manager |
| 91 | + from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextLinearAttPageHelper |
| 92 | + |
| 93 | + mtp_step = 2 |
| 94 | + req_idx = 2 |
| 95 | + monkeypatch.setattr(qwen3next_mem_manager, "get_env_start_args", lambda: SimpleNamespace(mtp_step=mtp_step)) |
| 96 | + |
| 97 | + mem = _make_mem(mtp_step=mtp_step) |
| 98 | + helper = Qwen3NextLinearAttPageHelper(mem) |
| 99 | + mem.kv_move_buffer = torch.zeros((1, helper.state_nbytes), dtype=torch.uint8) |
| 100 | + conv_page, ssm_page = helper.view_page_to_linear_att_state(page_index=0) |
| 101 | + |
| 102 | + marker_conv = torch.arange(conv_page.numel(), dtype=torch.float32).view_as(conv_page) |
| 103 | + marker_ssm = torch.arange(ssm_page.numel(), dtype=torch.float32).view_as(ssm_page) |
| 104 | + conv_page.copy_(marker_conv) |
| 105 | + ssm_page.copy_(marker_ssm) |
| 106 | + |
| 107 | + helper.read_page_to_req(page_index=0, req_idx=req_idx, dp_mems=[mem]) |
| 108 | + |
| 109 | + narrow_w = helper.conv_shape[-1] |
| 110 | + torch.testing.assert_close(mem.req_to_conv_state.buffer[:, req_idx, :, :narrow_w], marker_conv) |
| 111 | + assert torch.all(mem.req_to_conv_state.buffer[:, req_idx, :, narrow_w:] == -9.0) |
| 112 | + torch.testing.assert_close(mem.req_to_ssm_state.buffer[:, req_idx * (mtp_step + 1), ...], marker_ssm) |
0 commit comments