Skip to content

Replace cachegen kernels with more performant coder kernels#209

Open
colinreyn wants to merge 1 commit into
LMCache:mainfrom
colinreyn:cr1/pac_coder
Open

Replace cachegen kernels with more performant coder kernels#209
colinreyn wants to merge 1 commit into
LMCache:mainfrom
colinreyn:cr1/pac_coder

Conversation

@colinreyn
Copy link
Copy Markdown
Contributor

#143 wired up a series of cachegen kernels. They were however limited by the interface imposed by LMCache that invoked an expensive RepeatInterleave in encode. This PR replaces those kernels with more performant kernels that remove the need for the work, done in collect_bytes, that was previously imposing a bottle neck. An incisive change is injected into path of LMCache's cachegen encode/decode to properly invoke the new kernels

These kernels are in-themselves faster (~4x) and on the encode side remove other expensive operations leading to a substantial performance improvement.

Here are some representative timings showing a few key results:

  • Cachegen decode is comparable to a naive serde with a remote backend (model and bandwidth dependenant)
  • Cachegen encode is slower by 30 - 50% compared to naive
  • The compression ratio ranges from 3.5x - 6x depending strongly on chunk size
image image image image

A minimal drop in accuracy is measured against the gsm8k benchmark although it should be noted that these results have been observed to depend on chunk size and model

Qwen3-8B - No Cache
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8817|±  |0.0089|
|     |       |strict-match    |     5|exact_match|↑  |0.8772|±  |0.0090|

Qwen3-8B - Pure cachegen cache
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8560|±  |0.0097|
|     |       |strict-match    |     5|exact_match|↑  |0.8605|±  |0.0095|
Qwen3-30B-A3B - No Cache
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8423|±  |0.0100|
|     |       |strict-match    |     5|exact_match|↑  |0.8825|±  |0.0089|

Qwen3-30B-A3B - Pure cachegen cache
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8400|±  |0.0101|
|     |       |strict-match    |     5|exact_match|↑  |0.8431|±  |0.0100|
Qwen2.5-7B-Instruct - No Cache
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8196|±  |0.0106|
|     |       |strict-match    |     5|exact_match|↑  |0.7877|±  |0.0113|

Qwen2.5-7B-Instruct - Cacgen Cache
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6710|±  |0.0129|
|     |       |strict-match    |     5|exact_match|↑  |0.6399|±  |0.0132|

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request replaces the existing cachegen implementation with a new PAC (Prepare, Encode, Decode) kernel suite, stubbing out the previous kernels and introducing new C++ and Python logic. The review feedback identifies several performance improvement opportunities, such as reducing redundant tensor clones and allocations in the encoding loop, avoiding inefficient device transfers by allocating directly on the NPU, and removing blocking synchronizations to allow for better task overlap. Additionally, a typo was identified in the naming of the metadata preparation function.

Comment thread lmcache_ascend/serde/pac.py Outdated
Comment on lines +120 to +133
local_out_buf = output_buffer.clone()
local_output_lengths = output_lengths.clone()
tmp_in = encode_input[:, start:end, :].clone()

lmc_ops.pac_encode(tmp_in, meta_data, local_out_buf, local_output_lengths)
max_len = local_output_lengths[-1, -1]

data_chunks.append(
CacheGenGPUBytestream(
bytestream=local_out_buf.flatten()[0:max_len],
bytestream_lengths=local_output_lengths,
ntokens=end - start,
)
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current loop implementation is inefficient due to multiple redundant tensor clones and allocations.

  1. local_out_buf and local_output_lengths are cloned from the full-sized buffers on every iteration.
  2. tmp_in is cloned unnecessarily; slicing returns a view and the C++ kernel takes a const reference.

Instead, reuse the pre-allocated buffers for the kernel call and only clone the resulting slice to ensure the data is preserved for the CacheGenGPUBytestream objects.

Suggested change
local_out_buf = output_buffer.clone()
local_output_lengths = output_lengths.clone()
tmp_in = encode_input[:, start:end, :].clone()
lmc_ops.pac_encode(tmp_in, meta_data, local_out_buf, local_output_lengths)
max_len = local_output_lengths[-1, -1]
data_chunks.append(
CacheGenGPUBytestream(
bytestream=local_out_buf.flatten()[0:max_len],
bytestream_lengths=local_output_lengths,
ntokens=end - start,
)
)
tmp_in = encode_input[:, start:end, :]
output_buffer.zero_()
output_lengths.zero_()
lmc_ops.pac_encode(tmp_in, meta_data, output_buffer, output_lengths)
max_len = output_lengths[-1, -1]
data_chunks.append(
CacheGenGPUBytestream(
bytestream=output_buffer.flatten()[0:max_len].clone(),
bytestream_lengths=output_lengths.clone(),
ntokens=end - start,
)
)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gemini is wrong here about cloning the slice for tmp_in . input_sym.sizes() will return different results depending on if the clone is present or not. The number of tokens is detected from this so removing the clone would require broader changes. As the copy is not actually a performance bottleneck I'm ignoring that part of this suggestion. (moving the other clones is done)

Comment thread csrc/pac_kernels.cpp Outdated

static constexpr uint32_t AIV_MAX = 20;

void pac_perpare_enc_metadata(const at::Tensor &input_sym,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a typo in the function name: pac_perpare_enc_metadata should be pac_prepare_enc_metadata. This typo is also present in the header file, the pybind registration, and the Python caller.

Suggested change
void pac_perpare_enc_metadata(const at::Tensor &input_sym,
void pac_prepare_enc_metadata(const at::Tensor &input_sym,

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack

Comment thread csrc/pac_kernels.cpp
auto output_lengths_data_ptr =
static_cast<uint8_t *>(output_lengths.data_ptr());

auto workGM = torch::zeros({40 * 32}, input_sym.options().dtype(torch::kI32));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Allocating a workspace tensor (workGM) on every call to pac_encode introduces unnecessary overhead. Since the size is constant (1280 elements), consider using a pre-allocated buffer or a more efficient memory management strategy to avoid repeated NPU allocations during the encoding process.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, but it looks to be around 0.5% (30us vs 6ms) of the overhead and, implemented as it is, removes any concerns around concurrency or generality. I'm inclined to keep this simple and pay the small overhead cost

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, worth changing the Int8 instead of Int32.

key, value = out.float()

e = decode_stream.record_event(e)
e.synchronize()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of e.synchronize() blocks the host CPU until the NPU finishes the decoding process. This prevents the caller from overlapping other tasks with the NPU work. Consider removing this synchronization and letting the caller handle it, or only synchronize when the result is actually needed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AI proposed fix isn't workable but revisiting the problem the sync is fixing, actually it might be addressable by adding some (seemingly) redundant copies instead. First indication is that it works and is more efficient

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First indication was wrong, correctness issues come back in - in principle I agree that this isn't perfect but practically speaking it works

Comment thread lmcache_ascend/serde/pac.py Outdated
Comment on lines +102 to +104
meta_data = torch.zeros((nlayers, nchannels, n_bins), dtype=torch.int16).to(
device="npu"
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Allocating a tensor on the CPU and then moving it to the NPU using .to(device="npu") is less efficient than direct allocation on the device. Use the device argument in the factory function.

    meta_data = torch.zeros((nlayers, nchannels, n_bins), dtype=torch.int16, device="npu")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack

Comment thread lmcache_ascend/serde/pac.py Outdated
Comment on lines +108 to +113
output_buffer = torch.zeros((nlayers, nchannels, chunk_size), dtype=torch.uint8).to(
device="npu"
)
output_lengths = torch.zeros((nlayers, nchannels), dtype=torch.int32).to(
device="npu"
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Directly allocate these tensors on the NPU to avoid the overhead of CPU allocation followed by a device transfer.

    output_buffer = torch.zeros((nlayers, nchannels, chunk_size), dtype=torch.uint8, device="npu")
    output_lengths = torch.zeros((nlayers, nchannels), dtype=torch.int32, device="npu")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack

Comment thread csrc/pac_kernels.cpp
"Number of channels exceeds that supported be encode, contact "
"LMCache Ascend about changing this limitation");
TORCH_CHECK(nbins <= 32,
"Number of bins exceeds that supported be encode, contact."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: contact. -> contact

Comment thread csrc/pac_kernels.cpp
"of 32, contact LMCache Ascend about changing this limitation");

const c10::OptionalDeviceGuard device_guard(device_of(input_sym));
const aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
Copy link
Copy Markdown
Collaborator

@matthewygf matthewygf Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: replace stream() -> stream(false) if do not require synchronization of the task queue.

Comment thread csrc/pac_kernels.cpp
"chunking the input."
"Contact LMCache Ascend about changing this limitation");
TORCH_CHECK(nbins <= 32,
"Number of bins exceeds that supported be encode, contact."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: contact. -> contact

Comment thread csrc/pac_kernels.cpp
auto output_lengths_data_ptr =
static_cast<uint8_t *>(output_lengths.data_ptr());

auto workGM = torch::zeros({40 * 32}, input_sym.options().dtype(torch::kI32));
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, worth changing the Int8 instead of Int32.

) -> List[Optional[MemoryObj]]:
source_bufs = old_batched_get_blocking(self, keys)

allocator = self.get_allocator_backend()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, please only target cachegen backend for now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants