feat: refactor kv buffer#1265

Open

blueswhen wants to merge 2 commits into main from kv_buffer

Conversation

@blueswhen (Collaborator)

No description provided.

@gemini-code-assist bot left a comment (Contributor)


Code Review

This pull request refactors the KV cache management system by introducing a KvBuffer abstraction and a corresponding KvBufferAdapter, separating storage logic from business operations such as page IO and CPU offloading. The change affects the base memory manager, its quantized variants, and model-specific implementations such as DeepSeek-V2 and Qwen3Next. Feedback: the new buffer classes need __setitem__ and multi-dimensional indexing support to prevent runtime errors during item assignment, and layer-index lookups should cache their pointer-to-index mappings for performance.
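The separation the summary describes can be sketched roughly as follows. This is a hypothetical illustration of the pattern, not the PR's actual code: all constructor arguments, shapes, and the `offload_to_cpu` method name are assumptions.

```python
# Rough sketch of the storage/business-logic split (hypothetical signatures):
# KvBuffer owns the pre-allocated storage; KvBufferAdapter layers operations
# such as CPU offloading on top of it without touching allocation details.
import torch


class KvBuffer:
    """Storage only: a pre-allocated [layers, tokens, hidden] tensor (assumed layout)."""

    def __init__(self, num_layers: int, num_tokens: int, hidden: int):
        self._buffer = torch.zeros(num_layers, num_tokens, hidden)

    def __getitem__(self, item):
        return self._buffer[item]


class KvBufferAdapter:
    """Business operations over a KvBuffer (method name is illustrative)."""

    def __init__(self, kv_buffer: KvBuffer):
        self.kv_buffer = kv_buffer

    def offload_to_cpu(self, layer_index: int) -> torch.Tensor:
        # Offloading reads through the buffer abstraction, not a raw tensor.
        return self.kv_buffer[layer_index].to("cpu")


adapter = KvBufferAdapter(KvBuffer(num_layers=2, num_tokens=8, hidden=16))
print(adapter.offload_to_cpu(0).shape)
```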

Comment on lines +25 to +26
    def __getitem__(self, item):
        return self._buffer[item]

critical

The KvBuffer class is missing the __setitem__ method. Since self.kv_buffer in MemoryManager is now a KvBuffer object rather than a raw tensor, direct item assignment (e.g., self.kv_buffer[slice, ...] = value) will raise a TypeError. This is critical for methods like MemoryManager._write_kv_move_data and Deepseek2MemoryManager._write_kv_move_data which rely on this behavior for PD separation.

Suggested change

    def __getitem__(self, item):
        return self._buffer[item]

    def __setitem__(self, key, value):
        self._buffer[key] = value
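A minimal self-contained sketch of the fix, assuming a KvBuffer that simply wraps a pre-allocated tensor (the constructor and shapes here are illustrative, not the PR's code). Without `__setitem__`, `buf[1] = value` raises a TypeError even though reads work, which is exactly the failure mode `_write_kv_move_data` would hit.

```python
# Hypothetical minimal KvBuffer showing why __setitem__ is required:
# reads already delegate via __getitem__, but writes need their own hook.
import torch


class KvBuffer:
    def __init__(self, buffer: torch.Tensor):
        self._buffer = buffer

    def __getitem__(self, item):
        return self._buffer[item]

    def __setitem__(self, key, value):
        # Delegate assignment so `kv_buffer[slice, ...] = value` works,
        # as the PD-separation write path expects.
        self._buffer[key] = value


buf = KvBuffer(torch.zeros(4, 2, 8))
buf[1] = torch.ones(2, 8)  # item assignment now delegates to the tensor
print(buf[1].sum().item())
```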

Comment on lines +50 to +51
    def __getitem__(self, item):
        return self._buffers[item]

high

HybridKvBuffer does not support multi-dimensional indexing or item assignment, which are used in MemoryManager (e.g., lines 234 and 271). Because self._buffers is a list, indexing it with a tuple or slice will fail. This will cause crashes in models using HybridKvBuffer (like Qwen3Next) when features like PD separation are enabled.

Suggested change

    def __getitem__(self, item):
        if isinstance(item, tuple):
            return self._buffers[item[0]][item[1:]]
        return self._buffers[item]

    def __setitem__(self, key, value):
        if isinstance(key, tuple):
            self._buffers[key[0]][key[1:]] = value
        else:
            self._buffers[key] = value
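The suggested tuple handling can be exercised in isolation. Since `self._buffers` is a plain Python list, a key like `buf[layer, token_slice]` must be split: the first element picks the layer, the remainder indexes into that layer's tensor. A sketch under assumed shapes (the per-layer tensor sizes are illustrative):

```python
# Sketch of the tuple-indexing fix for a list-backed hybrid buffer:
# lists reject tuple keys, so the first tuple element must select the layer.
import torch


class HybridKvBuffer:
    def __init__(self, buffers):
        self._buffers = buffers  # list of per-layer tensors (or None)

    def __getitem__(self, item):
        if isinstance(item, tuple):
            # First element picks the layer; the rest index into its tensor.
            return self._buffers[item[0]][item[1:]]
        return self._buffers[item]

    def __setitem__(self, key, value):
        if isinstance(key, tuple):
            self._buffers[key[0]][key[1:]] = value
        else:
            self._buffers[key] = value


buf = HybridKvBuffer([torch.zeros(3, 4) for _ in range(2)])
buf[0, 1:3] = torch.ones(2, 4)  # tuple key: layer 0, rows 1..2
print(buf[0, 1].sum().item())
```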

Comment on lines +61 to +65
    def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
        key = min(k.data_ptr(), v.data_ptr())
        find_dict = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
        assert key in find_dict
        return find_dict[key]

medium

The find_layer_index method builds a dictionary of layer pointers on every call. Since the KV buffer is pre-allocated and its layer pointers are constant, this dictionary should be cached to avoid unnecessary overhead during inference.

Suggested change

    def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
        key = min(k.data_ptr(), v.data_ptr())
        if not hasattr(self, "_layer_ptr_to_idx"):
            self._layer_ptr_to_idx = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
        assert key in self._layer_ptr_to_idx
        return self._layer_ptr_to_idx[key]
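The caching pattern works because the layer tensors are pre-allocated, so each layer's `data_ptr()` is stable for the buffer's lifetime. A runnable sketch with an assumed constructor (separate allocations guarantee distinct pointers):

```python
# Sketch of the cached pointer-to-index lookup: the dict is built lazily on
# the first call, then every later call is a single dict lookup.
import torch


class KvBuffer:
    def __init__(self, num_layers: int):
        # Pre-allocated per-layer tensors; data_ptr() values never change.
        self._buffer = [torch.zeros(2, 4) for _ in range(num_layers)]

    def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
        key = min(k.data_ptr(), v.data_ptr())
        if not hasattr(self, "_layer_ptr_to_idx"):
            self._layer_ptr_to_idx = {
                self._buffer[i].data_ptr(): i for i in range(len(self._buffer))
            }
        assert key in self._layer_ptr_to_idx
        return self._layer_ptr_to_idx[key]


buf = KvBuffer(4)
layer = buf._buffer[2]
print(buf.find_layer_index(layer, layer))
```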

Comment on lines +87 to +95
    def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
        key = min(k.data_ptr(), v.data_ptr())
        find_dict = {
            layer_buffer.data_ptr(): layer_index
            for layer_index, layer_buffer in enumerate(self._buffers)
            if layer_buffer is not None
        }
        assert key in find_dict
        return find_dict[key]

medium

Similar to KvBuffer, the dictionary of layer pointers in HybridKvBuffer.find_layer_index should be cached to improve performance.

Suggested change

    def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
        key = min(k.data_ptr(), v.data_ptr())
        if not hasattr(self, "_layer_ptr_to_idx"):
            self._layer_ptr_to_idx = {
                layer_buffer.data_ptr(): layer_index
                for layer_index, layer_buffer in enumerate(self._buffers)
                if layer_buffer is not None
            }
        assert key in self._layer_ptr_to_idx
        return self._layer_ptr_to_idx[key]
