Skip to content

Commit bde19ca

Browse files
bugfix
1 parent 622777d commit bde19ca

3 files changed

Lines changed: 19 additions & 19 deletions

File tree

ucm/sparse/gsa_on_device/gsa_on_device.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
from ucm.sparse.gsa_on_device.gsa_on_device_config import GSAOnDeviceConfig
3838
from ucm.sparse.gsa_on_device.hash_encoder import HashEncoder
39+
from ucm.sparse.utils import cdiv
3940
from ucm.utils import Config
4041

4142
logger = init_logger(__name__)
@@ -255,13 +256,14 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
255256

256257
def init_for_pc(self):
257258
# for pc hit
258-
self.prefix_slot_mapping_buf = self._make_buffer(
259-
self.max_num_tokens * self.max_batch_size, dtype=torch.int64
259+
self.prefix_slot_mapping_buf = torch.empty(
260+
self.max_num_tokens * self.max_batch_size,
261+
device=self.device,
262+
dtype=torch.int64,
260263
)
261-
self.prefix_block_ids_buf = self._make_buffer(
262-
(self.max_num_tokens * self.max_batch_size + self.block_size - 1)
263-
// self.block_size
264-
+ 1,
264+
self.prefix_block_ids_buf = torch.empty(
265+
cdiv(self.max_num_tokens * self.max_batch_size, self.block_size),
266+
device=self.device,
265267
dtype=torch.int32,
266268
)
267269
self.token_idx_buf = torch.arange(
@@ -275,11 +277,6 @@ def _make_buffer(
275277
*size, dtype=dtype, device=self.device, pin_memory=True, with_numpy=numpy
276278
)
277279

278-
def _clear_buffer(self) -> None:
279-
self.decode_req_ids_buf.clear()
280-
self.prefix_slot_mapping_buf.clear()
281-
self.prefix_block_ids_buf.clear()
282-
283280
def hash_code(
284281
self,
285282
nope: Optional[torch.Tensor] = None,
@@ -709,7 +706,8 @@ def rebuild_prefix_cache_info_for_req(
709706
assert 0 <= qlen <= num_prompt_tokens
710707
num_prefix_tokens = num_prompt_tokens - qlen
711708
if num_prefix_tokens <= 0:
712-
return [], []
709+
empty = block_table_row[:0]
710+
return 0, 0, empty, empty
713711

714712
num_prefix_blocks = (num_prefix_tokens + block_size - 1) // block_size
715713
prefix_block_ids = block_table_row[:num_prefix_blocks] # [prefix_blocks]
@@ -749,7 +747,7 @@ def build_sparse_meta(
749747
compute_q_lens = (
750748
attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1]
751749
)
752-
self._clear_buffer()
750+
self.decode_req_ids_buf.clear()
753751

754752
self.num_reqs = len(scheduler_output.num_scheduled_tokens)
755753
for (
@@ -804,10 +802,10 @@ def build_sparse_meta(
804802
block_size=self.block_size,
805803
)
806804

807-
self.prefix_slot_mapping_buf.gpu[
805+
self.prefix_slot_mapping_buf[
808806
all_prefix_tokens : all_prefix_tokens + num_prefix_tokens
809807
] = prefix_slot_mapping
810-
self.prefix_block_ids_buf.gpu[
808+
self.prefix_block_ids_buf[
811809
all_prefix_blocks : all_prefix_blocks + num_prefix_blocks
812810
] = prefix_block_ids
813811

@@ -853,12 +851,10 @@ def build_sparse_meta(
853851

854852
self.has_pc_hit = num_pc_hit > 0
855853
if self.has_pc_hit:
856-
self.prefix_slot_mapping = self.prefix_slot_mapping_buf.gpu[
854+
self.prefix_slot_mapping = self.prefix_slot_mapping_buf[
857855
:all_prefix_tokens
858856
]
859-
self.prefix_block_ids = self.prefix_block_ids_buf.gpu[
860-
:all_prefix_blocks
861-
]
857+
self.prefix_block_ids = self.prefix_block_ids_buf[:all_prefix_blocks]
862858

863859
def maybe_init_cudagraph_buffers_for_topk(self, n, tile_scheduler_metadata):
864860
sm_parts = tile_scheduler_metadata.size(0)

ucm/sparse/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def round_up(x: int, y: int) -> int:
5454
return ((x + y - 1) // y) * y
5555

5656

57+
def cdiv(a: int, b: int) -> int:
    """Return the ceiling of a / b using exact integer arithmetic.

    Equivalent to ``-(a // -b)``: floor-divide, then bump the quotient by
    one whenever the division leaves a remainder. Works for negative
    divisors as well, since ``divmod`` gives a remainder with the sign of
    ``b`` that is nonzero exactly when ``b`` does not divide ``a``.
    """
    quotient, remainder = divmod(a, b)
    if remainder:
        quotient += 1
    return quotient
59+
60+
5761
def get_type_size(dtype: torch.dtype) -> int:
5862
return torch.tensor([], dtype=dtype).element_size()
5963

0 commit comments

Comments
 (0)