Replace cachegen kernels with more performant coder kernels #209
colinreyn wants to merge 1 commit into
Code Review
This pull request replaces the existing cachegen implementation with a new PAC (Prepare, Encode, Decode) kernel suite, stubbing out the previous kernels and introducing new C++ and Python logic. The review feedback identifies several performance improvement opportunities, such as reducing redundant tensor clones and allocations in the encoding loop, avoiding inefficient device transfers by allocating directly on the NPU, and removing blocking synchronizations to allow for better task overlap. Additionally, a typo was identified in the naming of the metadata preparation function.
    local_out_buf = output_buffer.clone()
    local_output_lengths = output_lengths.clone()
    tmp_in = encode_input[:, start:end, :].clone()

    lmc_ops.pac_encode(tmp_in, meta_data, local_out_buf, local_output_lengths)
    max_len = local_output_lengths[-1, -1]

    data_chunks.append(
        CacheGenGPUBytestream(
            bytestream=local_out_buf.flatten()[0:max_len],
            bytestream_lengths=local_output_lengths,
            ntokens=end - start,
        )
    )
The current loop implementation is inefficient due to multiple redundant tensor clones and allocations.
local_out_buf and local_output_lengths are cloned from the full-sized buffers on every iteration. tmp_in is cloned unnecessarily; slicing returns a view and the C++ kernel takes a const reference.
Instead, reuse the pre-allocated buffers for the kernel call and only clone the resulting slice to ensure the data is preserved for the CacheGenGPUBytestream objects.
    - local_out_buf = output_buffer.clone()
    - local_output_lengths = output_lengths.clone()
    - tmp_in = encode_input[:, start:end, :].clone()
    - lmc_ops.pac_encode(tmp_in, meta_data, local_out_buf, local_output_lengths)
    - max_len = local_output_lengths[-1, -1]
    - data_chunks.append(
    -     CacheGenGPUBytestream(
    -         bytestream=local_out_buf.flatten()[0:max_len],
    -         bytestream_lengths=local_output_lengths,
    -         ntokens=end - start,
    -     )
    - )
    + tmp_in = encode_input[:, start:end, :]
    + output_buffer.zero_()
    + output_lengths.zero_()
    + lmc_ops.pac_encode(tmp_in, meta_data, output_buffer, output_lengths)
    + max_len = output_lengths[-1, -1]
    + data_chunks.append(
    +     CacheGenGPUBytestream(
    +         bytestream=output_buffer.flatten()[0:max_len].clone(),
    +         bytestream_lengths=output_lengths.clone(),
    +         ntokens=end - start,
    +     )
    + )
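The reuse-then-copy pattern in this suggestion can be sketched without any NPU dependency. The following is only a minimal illustration in plain Python: `fake_encode`, `encode_chunks`, the buffer capacity, and the chunk contents are hypothetical stand-ins for `lmc_ops.pac_encode` and the real tensors, not the actual kernel interface.

```python
def fake_encode(chunk: bytes, out: bytearray) -> int:
    """Hypothetical stand-in for the encode kernel.

    Writes the "encoded" bytes into `out` and returns how many bytes were used.
    """
    out[: len(chunk)] = chunk
    return len(chunk)


def encode_chunks(chunks, capacity=16):
    scratch = bytearray(capacity)  # allocated once, reused for every chunk
    results = []
    for chunk in chunks:
        scratch[:] = bytes(capacity)  # reset, analogous to output_buffer.zero_()
        used = fake_encode(chunk, scratch)
        # Copy only the used prefix so later iterations cannot overwrite it.
        results.append(bytes(scratch[:used]))
    return results


encoded = encode_chunks([b"abc", b"defgh"])
```

The point of the sketch is that the full-sized scratch buffer is never copied; only the bytes that were actually produced are duplicated per chunk.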
Gemini is wrong here about cloning the slice for tmp_in. input_sym.sizes() returns different results depending on whether the clone is present. The number of tokens is detected from this, so removing the clone would require broader changes. As the copy is not actually a performance bottleneck, I'm ignoring that part of the suggestion (moving the other clones is done).
    static constexpr uint32_t AIV_MAX = 20;

    void pac_perpare_enc_metadata(const at::Tensor &input_sym,
There is a typo in the function name: pac_perpare_enc_metadata should be pac_prepare_enc_metadata. This typo is also present in the header file, the pybind registration, and the Python caller.
    - void pac_perpare_enc_metadata(const at::Tensor &input_sym,
    + void pac_prepare_enc_metadata(const at::Tensor &input_sym,
    auto output_lengths_data_ptr =
        static_cast<uint8_t *>(output_lengths.data_ptr());

    auto workGM = torch::zeros({40 * 32}, input_sym.options().dtype(torch::kI32));
Perhaps, but it looks to be around 0.5% (30us vs 6ms) of the overhead and, implemented as it is, removes any concerns around concurrency or generality. I'm inclined to keep this simple and pay the small overhead cost.
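The 0.5% figure follows directly from the two timings quoted in this comment. A quick check, assuming the 6 ms is the total time that the 30 us allocation is being compared against (the variable names are illustrative only):

```python
alloc_time_s = 30e-6  # ~30 us quoted for the per-call workspace allocation
total_time_s = 6e-3   # ~6 ms quoted as the time it is compared against

# Fraction of the total spent on the allocation, expressed as a percentage.
overhead_percent = alloc_time_s / total_time_s * 100
```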
As discussed, it is worth changing this to Int8 instead of Int32.
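The thread does not state the motivation for the dtype change, so the following only shows the size difference for the 40 * 32 element workspace in the snippet under review. The element widths are the standard 4 bytes for Int32 and 1 byte for Int8; the variable names are illustrative.

```python
num_elements = 40 * 32  # shape passed to torch::zeros for workGM

bytes_int32 = num_elements * 4  # 4 bytes per Int32 element
bytes_int8 = num_elements * 1   # 1 byte per Int8 element

# Factor by which the zero-initialised workspace would shrink.
shrink_factor = bytes_int32 // bytes_int8
```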
    key, value = out.float()

    e = decode_stream.record_event(e)
    e.synchronize()
The AI-proposed fix isn't workable, but revisiting the problem the sync is fixing, it might actually be addressable by adding some (seemingly) redundant copies instead. The first indication is that this works and is more efficient.
The first indication was wrong: the correctness issues come back. In principle I agree that this isn't perfect, but practically speaking it works.
    meta_data = torch.zeros((nlayers, nchannels, n_bins), dtype=torch.int16).to(
        device="npu"
    )
    output_buffer = torch.zeros((nlayers, nchannels, chunk_size), dtype=torch.uint8).to(
        device="npu"
    )
    output_lengths = torch.zeros((nlayers, nchannels), dtype=torch.int32).to(
        device="npu"
    )
8633b30 to 83313d6
                "Number of channels exceeds that supported be encode, contact "
                "LMCache Ascend about changing this limitation");
    TORCH_CHECK(nbins <= 32,
                "Number of bins exceeds that supported be encode, contact."
nit: contact. -> contact
                "of 32, contact LMCache Ascend about changing this limitation");

    const c10::OptionalDeviceGuard device_guard(device_of(input_sym));
    const aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
nit: replace stream() with stream(false) if synchronization of the task queue is not required.
                "chunking the input."
                "Contact LMCache Ascend about changing this limitation");
    TORCH_CHECK(nbins <= 32,
                "Number of bins exceeds that supported be encode, contact."
nit: contact. -> contact
    ) -> List[Optional[MemoryObj]]:
        source_bufs = old_batched_get_blocking(self, keys)

        allocator = self.get_allocator_backend()
As discussed, please only target cachegen backend for now.
#143 wired up a series of cachegen kernels. They were, however, limited by the interface imposed by LMCache, which invoked an expensive RepeatInterleave in encode. This PR replaces those kernels with more performant kernels that remove the need for the work done in collect_bytes, which was previously a bottleneck. An incisive change is injected into the path of LMCache's cachegen encode/decode to properly invoke the new kernels. These kernels are in themselves faster (~4x) and, on the encode side, remove other expensive operations, leading to a substantial performance improvement.
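For readers unfamiliar with the operation, a repeat-interleave expands each input element by its own repeat count, so the output length is the sum of the counts. The sketch below shows only those semantics in plain Python; `repeat_interleave` here is an illustrative helper, and how LMCache actually invoked the operation in encode is not shown in this PR description.

```python
from itertools import chain, repeat


def repeat_interleave(values, repeats):
    """Repeat values[i] exactly repeats[i] times, preserving order."""
    return list(chain.from_iterable(repeat(v, r) for v, r in zip(values, repeats)))


expanded = repeat_interleave([10, 20, 30], [1, 3, 2])
# expanded == [10, 20, 20, 20, 30, 30]; its length equals sum(repeats)
```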
Here are some representative timings showing a few key results:
A minimal drop in accuracy is measured against the gsm8k benchmark, although it should be noted that these results have been observed to depend on chunk size and model.