Conversation
Code Review
This pull request refactors the KV cache management system by introducing a `KvBuffer` abstraction and a corresponding `KvBufferAdapter` that separate storage logic from business operations such as page IO and CPU offloading. The change affects the base memory manager, its quantized variants, and model-specific implementations such as DeepSeek-V2 and Qwen3Next. Feedback: the new buffer classes need to implement `__setitem__` and handle multi-dimensional indexing to prevent runtime errors during item assignment, and layer-index lookups should cache their pointer-to-index mappings for performance.
```python
def __getitem__(self, item):
    return self._buffer[item]
```
The `KvBuffer` class is missing the `__setitem__` method. Since `self.kv_buffer` in `MemoryManager` is now a `KvBuffer` object rather than a raw tensor, direct item assignment (e.g., `self.kv_buffer[slice, ...] = value`) will raise a `TypeError`. This is critical for methods like `MemoryManager._write_kv_move_data` and `Deepseek2MemoryManager._write_kv_move_data`, which rely on this behavior for PD separation.
Suggested change:

```python
def __getitem__(self, item):
    return self._buffer[item]

def __setitem__(self, key, value):
    self._buffer[key] = value
```
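A minimal standalone sketch (plain Python lists in place of the real tensors, class names purely illustrative) shows why the missing dunder matters: Python routes `obj[key] = value` through `__setitem__`, and a class that only defines `__getitem__` rejects item assignment outright.

```python
class ReadOnlyBuffer:
    """Wrapper with __getitem__ only, like the current KvBuffer."""

    def __init__(self, buf):
        self._buffer = buf

    def __getitem__(self, item):
        return self._buffer[item]


class ReadWriteBuffer(ReadOnlyBuffer):
    """Adds __setitem__ so item assignment is forwarded to the raw buffer."""

    def __setitem__(self, key, value):
        self._buffer[key] = value


ro = ReadOnlyBuffer([0, 0, 0])
try:
    ro[0] = 1                 # no __setitem__: assignment is rejected
except TypeError as e:
    print("TypeError:", e)

rw = ReadWriteBuffer([0, 0, 0])
rw[0] = 1                     # forwarded to the underlying buffer
print(rw[0])                  # 1
```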
```python
def __getitem__(self, item):
    return self._buffers[item]
```
`HybridKvBuffer` does not support multi-dimensional indexing or item assignment, both of which are used in `MemoryManager` (e.g., lines 234 and 271). Because `self._buffers` is a list, indexing it with a tuple or slice will fail. This will crash models that use `HybridKvBuffer` (such as Qwen3Next) when features like PD separation are enabled.
Suggested change:

```python
def __getitem__(self, item):
    if isinstance(item, tuple):
        return self._buffers[item[0]][item[1:]]
    return self._buffers[item]

def __setitem__(self, key, value):
    if isinstance(key, tuple):
        self._buffers[key[0]][key[1:]] = value
    else:
        self._buffers[key] = value
```
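A stdlib-only sketch of the suggested behavior: the first element of a tuple key selects the layer, and the remaining elements are forwarded to that layer's tensor. `FakeTensor` and `HybridDemo` are illustrative stand-ins, not LightLLM's API (the real per-layer buffers are torch tensors, which accept tuple indices natively).

```python
class FakeTensor:
    """Tuple-indexable stand-in for a per-layer torch tensor."""

    def __init__(self, rows, cols):
        self._data = [[0.0] * cols for _ in range(rows)]
        self.shape = (rows, cols)

    def __getitem__(self, idx):
        r, c = idx
        return self._data[r][c]

    def __setitem__(self, idx, value):
        r, c = idx
        self._data[r][c] = value


class HybridDemo:
    """Toy hybrid buffer: a list of per-layer tensors behind one index space."""

    def __init__(self, buffers):
        self._buffers = buffers

    def __getitem__(self, item):
        if isinstance(item, tuple):
            # first element selects the layer; the rest index into that layer
            return self._buffers[item[0]][item[1:]]
        return self._buffers[item]

    def __setitem__(self, key, value):
        if isinstance(key, tuple):
            self._buffers[key[0]][key[1:]] = value
        else:
            self._buffers[key] = value


hb = HybridDemo([FakeTensor(4, 2), FakeTensor(4, 2)])
hb[1, 0, 1] = 7.0            # tuple key: layer 1, row 0, column 1
print(hb[1, 0, 1])           # 7.0
print(hb[0].shape)           # (4, 2): a plain int still returns the whole layer
```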
```python
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
    key = min(k.data_ptr(), v.data_ptr())
    find_dict = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
    assert key in find_dict
    return find_dict[key]
```
The `find_layer_index` method rebuilds a dictionary of layer pointers on every call. Since the KV buffer is pre-allocated and its layer pointers never change, this dictionary should be cached to avoid unnecessary overhead during inference.
Suggested change:

```python
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
    key = min(k.data_ptr(), v.data_ptr())
    if not hasattr(self, "_layer_ptr_to_idx"):
        self._layer_ptr_to_idx = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
    assert key in self._layer_ptr_to_idx
    return self._layer_ptr_to_idx[key]
```
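Since `__init__` is not shown in this diff, here is a hedged sketch of the eager alternative to the lazy `hasattr` check: building the cache once at construction time, which the fixed buffer layout permits. The class name, method name, and the use of `id()` in place of `torch.Tensor.data_ptr()` are all illustrative assumptions so the sketch runs without torch.

```python
class LayerIndexDemo:
    """Illustrative sketch: build the pointer->index map once, eagerly."""

    def __init__(self, buffers):
        self._buffer = buffers
        # The KV buffer is pre-allocated and its per-layer pointers never
        # change, so the mapping can be built once here instead of lazily.
        self._layer_ptr_to_idx = {id(b): i for i, b in enumerate(self._buffer)}

    def find_layer_index_by_ptr(self, ptr: int) -> int:
        assert ptr in self._layer_ptr_to_idx
        return self._layer_ptr_to_idx[ptr]


layers = [bytearray(16) for _ in range(3)]   # stand-ins for per-layer tensors
demo = LayerIndexDemo(layers)
print(demo.find_layer_index_by_ptr(id(layers[2])))   # 2
```

Eager construction avoids the per-call attribute check and keeps the cached state visible in `__init__`; the lazy version in the suggestion above has the advantage of requiring no change outside the method itself.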
```python
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
    key = min(k.data_ptr(), v.data_ptr())
    find_dict = {
        layer_buffer.data_ptr(): layer_index
        for layer_index, layer_buffer in enumerate(self._buffers)
        if layer_buffer is not None
    }
    assert key in find_dict
    return find_dict[key]
```
As with `KvBuffer`, the pointer-to-index dictionary in `HybridKvBuffer.find_layer_index` should be cached rather than rebuilt on every call.
Suggested change:

```python
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
    key = min(k.data_ptr(), v.data_ptr())
    if not hasattr(self, "_layer_ptr_to_idx"):
        self._layer_ptr_to_idx = {
            layer_buffer.data_ptr(): layer_index
            for layer_index, layer_buffer in enumerate(self._buffers)
            if layer_buffer is not None
        }
    assert key in self._layer_ptr_to_idx
    return self._layer_ptr_to_idx[key]
```