|
2 | 2 | # modify from: https://github.com/vllm-project/vllm |
3 | 3 | import json |
4 | 4 | import math |
| 5 | +from collections.abc import Sequence |
5 | 6 | from dataclasses import dataclass |
| 7 | +from operator import index as as_index |
6 | 8 |
|
7 | 9 | import torch |
8 | 10 |
|
@@ -565,3 +567,79 @@ def get_cache_state_size(state_shapes: list[tuple[tuple[int], torch.dtype]]) -> |
565 | 567 | def state_caches(self): |
566 | 568 | """State caches.""" |
567 | 569 | return self._state_caches |
| 570 | + |
| 571 | + @staticmethod |
| 572 | + def _index_list(idx: int | Sequence[int]): |
| 573 | + """Normalize host-side cache indices.""" |
| 574 | + if isinstance(idx, torch.Tensor): |
| 575 | + raise TypeError('State cache copy indices must be host integers, not torch.Tensor.') |
| 576 | + if isinstance(idx, (str, bytes)): |
| 577 | + raise TypeError('State cache copy indices must be an int or a sequence of ints.') |
| 578 | + try: |
| 579 | + return [as_index(idx)] |
| 580 | + except TypeError: |
| 581 | + pass |
| 582 | + if not isinstance(idx, Sequence): |
| 583 | + raise TypeError('State cache copy indices must be an int or a sequence of ints.') |
| 584 | + if any(isinstance(item, torch.Tensor) for item in idx): |
| 585 | + raise TypeError('State cache copy indices must be host integers, not torch.Tensor.') |
| 586 | + return [as_index(item) for item in idx] |
| 587 | + |
| 588 | + @staticmethod |
| 589 | + def _validate_index_bounds(indices: Sequence[int], num_caches: int): |
| 590 | + """Check normalized cache indices are valid state slots.""" |
| 591 | + for idx in indices: |
| 592 | + if idx < 0 or idx >= num_caches: |
| 593 | + raise ValueError(f'State cache index {idx} is out of range [0, {num_caches}).') |
| 594 | + |
| 595 | + @staticmethod |
| 596 | + def _copy_ranges(src_list: list[int], dst_list: list[int]): |
| 597 | + """Yield contiguous copy ranges as (src_start, dst_start, length).""" |
| 598 | + pairs = sorted(zip(src_list, dst_list)) |
| 599 | + if len(pairs) == 0: |
| 600 | + return |
| 601 | + start_src = prev_src = pairs[0][0] |
| 602 | + start_dst = prev_dst = pairs[0][1] |
| 603 | + length = 1 |
| 604 | + for src, dst in pairs[1:]: |
| 605 | + if src == prev_src + 1 and dst == prev_dst + 1: |
| 606 | + prev_src = src |
| 607 | + prev_dst = dst |
| 608 | + length += 1 |
| 609 | + continue |
| 610 | + yield start_src, start_dst, length |
| 611 | + start_src = prev_src = src |
| 612 | + start_dst = prev_dst = dst |
| 613 | + length = 1 |
| 614 | + yield start_src, start_dst, length |
| 615 | + |
| 616 | + def copy_caches(self, src_idx: int | Sequence[int], dst_idx: int | Sequence[int]): |
| 617 | + """Copy state cache slots. |
| 618 | +
|
| 619 | + This is the low-level primitive needed by SSM prefix caching: a frozen |
| 620 | + state checkpoint can be copied into a newly allocated runtime slot |
| 621 | + before the next forward. |
| 622 | + """ |
| 623 | + if len(self._state_caches) <= 0: |
| 624 | + return |
| 625 | + |
| 626 | + src_list = self._index_list(src_idx) |
| 627 | + dst_list = self._index_list(dst_idx) |
| 628 | + if len(src_list) != len(dst_list): |
| 629 | + raise ValueError('src_idx and dst_idx must have the same number of elements.') |
| 630 | + if len(src_list) == 0: |
| 631 | + return |
| 632 | + num_caches = self.mem_pool.size(0) |
| 633 | + self._validate_index_bounds(src_list, num_caches) |
| 634 | + self._validate_index_bounds(dst_list, num_caches) |
| 635 | + dst_set = set(dst_list) |
| 636 | + if len(dst_set) != len(dst_list): |
| 637 | + raise ValueError('dst_idx must not contain duplicate entries.') |
| 638 | + if not set(src_list).isdisjoint(dst_set): |
| 639 | + raise ValueError('src_idx and dst_idx must not overlap for stream-ordered state copies.') |
| 640 | + |
| 641 | + for src, dst, length in self._copy_ranges(src_list, dst_list): |
| 642 | + if length == 1: |
| 643 | + self.mem_pool[dst].copy_(self.mem_pool[src], non_blocking=True) |
| 644 | + else: |
| 645 | + self.mem_pool[dst:dst + length].copy_(self.mem_pool[src:src + length], non_blocking=True) |
0 commit comments