@@ -102,9 +102,6 @@ def alloc_kv_move_buffer(self, max_req_total_len):
102102 return
103103
104104 def alloc_paged_kv_move_buffer (self , page_num , page_size ) -> torch .Tensor :
105- if isinstance (self , MemoryManager ) and type (self ) is not MemoryManager :
106- raise NotImplementedError ("subclass need reimpl this method" )
107-
108105 num_kv_head = get_num_key_value_heads (get_env_start_args ().model_dir )
109106 self .kv_move_buffer = torch .empty (
110107 (page_num , page_size , self .layer_num , 2 * num_kv_head , self .head_dim ), dtype = self .dtype , device = "cuda"
@@ -121,7 +118,10 @@ def write_mem_to_page_kv_move_buffer(
121118 dp_index : int ,
122119 mem_managers : List ["MemoryManager" ],
123120 dp_world_size : int ,
121+ page_kind : str = "kv" ,
122+ req_idx : int = None ,
124123 ):
124+ assert page_kind == "kv" , f"{ type (self ).__name__ } does not support page_kind={ page_kind } "
125125 cur_page = self .kv_move_buffer [page_index ]
126126 pin_mem_indexes = self ._buffer_mem_indexes_tensors [page_index ][0 : len (mem_indexes )]
127127 pin_mem_indexes .numpy ()[:] = mem_indexes
@@ -150,7 +150,10 @@ def read_page_kv_move_buffer_to_mem(
150150 dp_index : int ,
151151 mem_managers : List ["MemoryManager" ],
152152 dp_world_size : int ,
153+ page_kind : str = "kv" ,
154+ req_idx : int = None ,
153155 ):
156+ assert page_kind == "kv" , f"{ type (self ).__name__ } does not support page_kind={ page_kind } "
154157 cur_page = self .kv_move_buffer [page_index ]
155158 pin_mem_indexes = self ._buffer_mem_indexes_tensors [page_index ][0 : len (mem_indexes )]
156159 pin_mem_indexes .numpy ()[:] = mem_indexes
0 commit comments