Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,7 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
int const max_seq_len = max(decoder_seq_len, encoder_seq_len);

// Only the first chunk needs to store encoder kv input to the kv cache.
bool const store_encoder_kv_cache = (decoder_seq_len == decoder_cache_seq_len);
bool const store_encoder_kv_cache = params.cross_kv_input != nullptr && (decoder_seq_len == decoder_cache_seq_len);

// Offsets and strides.
int const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
Expand Down
13 changes: 7 additions & 6 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ nb::tuple trtllmGenContextPreprocessBinding(torch::Tensor qkv_input, torch::Tens
double rotary_embedding_scale, int64_t rotary_embedding_max_positions, int64_t position_embedding_type,
double bmm1_scale, double bmm2_scale, int64_t attention_chunk_size, bool fp8_context_fmha, bool paged_context_fmha,
bool is_mla_enable, int64_t multi_processor_count, int64_t total_num_blocks, int64_t kv_factor,
bool need_build_kv_cache_metadata)
bool need_build_kv_cache_metadata, std::optional<torch::Tensor> cross_kv, bool cross_attention)
{
auto result = [&]()
{
Expand All @@ -70,7 +70,7 @@ nb::tuple trtllmGenContextPreprocessBinding(torch::Tensor qkv_input, torch::Tens
max_past_kv_length, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, bmm1_scale, bmm2_scale,
attention_chunk_size, fp8_context_fmha, paged_context_fmha, is_mla_enable, multi_processor_count,
total_num_blocks, kv_factor, need_build_kv_cache_metadata);
total_num_blocks, kv_factor, need_build_kv_cache_metadata, cross_kv, cross_attention);
}();

return nb::make_tuple(std::get<0>(result), optionalToObject(std::get<1>(result)),
Expand All @@ -92,7 +92,7 @@ nb::tuple trtllmGenGenerationPreprocessBinding(torch::Tensor qkv_input, torch::T
int64_t rotary_embedding_scale_type, double rotary_embedding_scale, int64_t rotary_embedding_max_positions,
int64_t position_embedding_type, double bmm1_scale, double bmm2_scale, bool fp8_context_fmha,
int64_t predicted_tokens_per_seq, int64_t attention_chunk_size, int64_t multi_processor_count,
int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata)
int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata, bool cross_attention)
{
auto result = [&]()
{
Expand All @@ -106,7 +106,7 @@ nb::tuple trtllmGenGenerationPreprocessBinding(torch::Tensor qkv_input, torch::T
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
rotary_embedding_max_positions, position_embedding_type, bmm1_scale, bmm2_scale, fp8_context_fmha,
predicted_tokens_per_seq, attention_chunk_size, multi_processor_count, total_num_blocks, kv_factor,
need_build_kv_cache_metadata);
need_build_kv_cache_metadata, cross_attention);
}();

return nb::make_tuple(std::get<0>(result), optionalToObject(std::get<1>(result)),
Expand Down Expand Up @@ -273,7 +273,8 @@ void initBindings(nb::module_& m)
nb::arg("position_embedding_type"), nb::arg("bmm1_scale"), nb::arg("bmm2_scale"),
nb::arg("attention_chunk_size"), nb::arg("fp8_context_fmha"), nb::arg("paged_context_fmha"),
nb::arg("is_mla_enable"), nb::arg("multi_processor_count"), nb::arg("total_num_blocks"), nb::arg("kv_factor"),
nb::arg("need_build_kv_cache_metadata") = true, "Fused nanobind context preprocess for trtllm-gen attention.");
nb::arg("need_build_kv_cache_metadata") = true, nb::arg("cross_kv").none() = nb::none(),
nb::arg("cross_attention") = false, "Fused nanobind context preprocess for trtllm-gen attention.");

m.def("trtllm_gen_context_postprocess", &torch_ext::trtllmGenContextPostprocess, nb::arg("qkv_input"),
nb::arg("workspace"), nb::arg("sequence_lengths"), nb::arg("context_lengths"),
Expand Down Expand Up @@ -332,6 +333,6 @@ void initBindings(nb::module_& m)
nb::arg("position_embedding_type"), nb::arg("bmm1_scale"), nb::arg("bmm2_scale"), nb::arg("fp8_context_fmha"),
nb::arg("predicted_tokens_per_seq"), nb::arg("attention_chunk_size"), nb::arg("multi_processor_count"),
nb::arg("total_num_blocks"), nb::arg("kv_factor"), nb::arg("need_build_kv_cache_metadata") = true,
"Fused nanobind generation preprocess for trtllm-gen attention.");
nb::arg("cross_attention") = false, "Fused nanobind generation preprocess for trtllm-gen attention.");
}
} // namespace tensorrt_llm::nanobind::thop
5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/thop/trtllmGenFusedOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor
double rotary_embedding_scale, int64_t rotary_embedding_max_positions, int64_t position_embedding_type,
double bmm1_scale, double bmm2_scale, int64_t attention_chunk_size, bool fp8_context_fmha, bool paged_context_fmha,
bool is_mla_enable, int64_t multi_processor_count, int64_t total_num_blocks, int64_t kv_factor,
bool need_build_kv_cache_metadata);
bool need_build_kv_cache_metadata, std::optional<torch::Tensor> cross_kv = std::nullopt,
bool cross_attention = false);

void trtllmGenContextPostprocess(torch::Tensor qkv_input, torch::Tensor workspace, torch::Tensor sequence_lengths,
torch::Tensor context_lengths, std::optional<torch::Tensor> kv_cache_block_offsets,
Expand Down Expand Up @@ -72,7 +73,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace,
int64_t rotary_embedding_scale_type, double rotary_embedding_scale, int64_t rotary_embedding_max_positions,
int64_t position_embedding_type, double bmm1_scale, double bmm2_scale, bool fp8_context_fmha,
int64_t predicted_tokens_per_seq, int64_t attention_chunk_size, int64_t multi_processor_count,
int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata);
int64_t total_num_blocks, int64_t kv_factor, bool need_build_kv_cache_metadata, bool cross_attention = false);

} // namespace torch_ext

Expand Down
53 changes: 35 additions & 18 deletions cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,22 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor
int64_t const position_embedding_type, double const bmm1_scale, double const bmm2_scale,
int64_t const attention_chunk_size, bool const fp8_context_fmha, bool const paged_context_fmha,
bool const is_mla_enable, int64_t const multi_processor_count, int64_t const total_num_blocks,
int64_t const kv_factor, bool const need_build_kv_cache_metadata)
int64_t const kv_factor, bool const need_build_kv_cache_metadata, std::optional<torch::Tensor> cross_kv,
bool const cross_attention)
{
(void) bmm2_scale;
TORCH_CHECK(host_kv_cache_pool_pointers.has_value(), "host_kv_cache_pool_pointers is required.");
TORCH_CHECK(host_kv_cache_pool_mapping.has_value(), "host_kv_cache_pool_mapping is required.");
TORCH_CHECK(kv_cache_block_offsets.has_value(), "kv_cache_block_offsets is required.");
TORCH_CHECK(!cross_attention || !is_mla_enable, "trtllm-gen cross attention does not support MLA.");

bool const separateQKvOutput = paged_context_fmha || fp8_context_fmha;
bool const separateQKvOutput = paged_context_fmha || fp8_context_fmha || cross_attention;
auto const qkvScalarType = qkv_input.scalar_type();
auto const qkvElementSize = static_cast<size_t>(qkv_input.element_size());
auto const quantMode = tensorrt_llm::common::QuantMode(static_cast<uint32_t>(kv_cache_quant_mode));
int64_t const effectiveMaxAttentionWindowSize = cross_attention ? max_past_kv_length : max_attention_window_size;
int64_t const effectiveCyclicAttentionWindowSize
= cross_attention ? max_past_kv_length : cyclic_attention_window_size;
auto const views = [&]
{
auto const layout = TrtllmAttentionWorkspaceManager::buildContextLayout(
Expand Down Expand Up @@ -307,8 +312,8 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor
decoderInfoParams.fmhaBmm2Scale = ptrs.fmhaBmm2ScalePtr;
decoderInfoParams.batchSize = static_cast<int>(batch_size);
decoderInfoParams.maxQSeqLength = static_cast<int>(input_seq_length);
decoderInfoParams.maxEncoderQSeqLength = 0;
decoderInfoParams.attentionWindowSize = static_cast<int>(cyclic_attention_window_size);
decoderInfoParams.maxEncoderQSeqLength = cross_attention ? static_cast<int>(max_past_kv_length) : 0;
decoderInfoParams.attentionWindowSize = static_cast<int>(effectiveCyclicAttentionWindowSize);
decoderInfoParams.numTokens = static_cast<int>(num_tokens);
decoderInfoParams.removePadding = true;
decoderInfoParams.attentionMaskType = static_cast<AttentionMaskType>(mask_type);
Expand All @@ -333,12 +338,13 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor
{
return buildPagedKvCacheBuffers(kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, quantMode, layer_idx, batch_size, tokens_per_block, num_kv_heads, head_size,
cyclic_attention_window_size, max_attention_window_size, 0, 0, is_mla_enable, qkvElementSize);
effectiveCyclicAttentionWindowSize, effectiveMaxAttentionWindowSize, 0, 0, is_mla_enable,
qkvElementSize);
}();

QKVPreprocessingParams<void, KVBlockArray> qkvParams{};
qkvParams.qkv_input = qkv_input.data_ptr();
qkvParams.cross_kv_input = nullptr;
qkvParams.cross_kv_input = optPtr<void>(cross_kv);
qkvParams.quantized_qkv_output = nullptr;
qkvParams.q_output = ptrs.qBufPtr;
qkvParams.kv_cache_buffer = kvArrays.kvCacheBuffer;
Expand All @@ -353,8 +359,9 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor
qkvParams.logn_scaling = nullptr;
qkvParams.tokens_info = ptrs.tokensInfoPtr;
qkvParams.seq_lens = static_cast<int*>(context_lengths.data_ptr());
qkvParams.cache_seq_lens = static_cast<int*>(sequence_lengths.data_ptr());
qkvParams.encoder_seq_lens = nullptr;
qkvParams.cache_seq_lens = cross_attention ? static_cast<int*>(context_lengths.data_ptr())
: static_cast<int*>(sequence_lengths.data_ptr());
qkvParams.encoder_seq_lens = cross_attention ? static_cast<int*>(sequence_lengths.data_ptr()) : nullptr;
qkvParams.cu_seq_lens = ptrs.cuQSeqlensPtr;
qkvParams.cu_kv_seq_lens = ptrs.cuKvSeqlensPtr;
qkvParams.sparse_kv_offsets = nullptr;
Expand All @@ -367,11 +374,11 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor
qkvParams.batch_size = static_cast<int>(batch_size);
qkvParams.max_input_seq_len = static_cast<int>(input_seq_length);
qkvParams.max_kv_seq_len = static_cast<int>(max_past_kv_length);
qkvParams.cyclic_kv_cache_len = static_cast<int>(cyclic_attention_window_size);
qkvParams.cyclic_kv_cache_len = static_cast<int>(effectiveCyclicAttentionWindowSize);
qkvParams.token_num = static_cast<int>(num_tokens);
qkvParams.remove_padding = true;
qkvParams.is_last_chunk = attention_chunk_size == 0 || input_seq_length == max_past_kv_length;
qkvParams.cross_attention = false;
qkvParams.cross_attention = cross_attention;
qkvParams.head_num = static_cast<int>(num_heads);
qkvParams.kv_head_num = static_cast<int>(num_kv_heads);
qkvParams.qheads_per_kv_head = static_cast<int>(num_heads / num_kv_heads);
Expand Down Expand Up @@ -438,7 +445,9 @@ trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, tor

// FlashInfer paged context launches trtllm-gen with multi-CTA-KV mode disabled, so it does not
// consume the counter slab reserved at the head of the workspace.
auto const windowLeft = computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size);
auto const windowLeft = cross_attention
? int64_t{-1}
: computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size);
return {qProcessed, kvPool, blockTables, kvScalePool, views.fmhaBmm1Scale, views.fmhaBmm2Scale,
views.trtllmGenWorkspace, views.cuQSeqlens, views.cuKvSeqlens, input_seq_length, max_past_kv_length,
windowLeft};
Expand Down Expand Up @@ -577,17 +586,22 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace,
int64_t const rotary_embedding_max_positions, int64_t const position_embedding_type, double const bmm1_scale,
double const bmm2_scale, bool const fp8_context_fmha, int64_t const predicted_tokens_per_seq,
int64_t const attention_chunk_size, int64_t const multi_processor_count, int64_t const total_num_blocks,
int64_t const kv_factor, bool const need_build_kv_cache_metadata)
int64_t const kv_factor, bool const need_build_kv_cache_metadata, bool const cross_attention)
{
TORCH_CHECK(host_kv_cache_pool_pointers.has_value(), "host_kv_cache_pool_pointers is required.");
TORCH_CHECK(host_kv_cache_pool_mapping.has_value(), "host_kv_cache_pool_mapping is required.");
TORCH_CHECK(kv_cache_block_offsets.has_value(), "kv_cache_block_offsets is required.");
(void) bmm2_scale;

bool const isMultiTokenGen = spec_decoding_generation_lengths.has_value() && predicted_tokens_per_seq > 1;
TORCH_CHECK(
!cross_attention || !isMultiTokenGen, "trtllm-gen cross attention does not support multi-token generation.");
auto const qkvScalarType = qkv_input.scalar_type();
auto const qkvElementSize = static_cast<size_t>(qkv_input.element_size());
auto const quantMode = tensorrt_llm::common::QuantMode(static_cast<uint32_t>(kv_cache_quant_mode));
int64_t const effectiveMaxAttentionWindowSize = cross_attention ? max_past_kv_length : max_attention_window_size;
int64_t const effectiveCyclicAttentionWindowSize
= cross_attention ? max_past_kv_length : cyclic_attention_window_size;
auto const views = [&]
{
auto const layout = TrtllmAttentionWorkspaceManager::buildGenerationLayout(
Expand Down Expand Up @@ -617,7 +631,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace,
decoderInfoParams.fmhaBmm2Scale = nullptr;
decoderInfoParams.batchSize = static_cast<int>(batch_beam);
decoderInfoParams.maxQSeqLength = static_cast<int>(input_seq_length);
decoderInfoParams.maxEncoderQSeqLength = 0;
decoderInfoParams.maxEncoderQSeqLength = cross_attention ? static_cast<int>(max_past_kv_length) : 0;
decoderInfoParams.attentionWindowSize = 0;
decoderInfoParams.sinkTokenLength = 0;
decoderInfoParams.numTokens = static_cast<int>(num_tokens);
Expand Down Expand Up @@ -655,7 +669,8 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace,
{
return buildPagedKvCacheBuffers(kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, quantMode, layer_idx, batch_beam, tokens_per_block, num_kv_heads, head_size,
cyclic_attention_window_size, max_attention_window_size, 1, seq_offset, false, qkvElementSize);
effectiveCyclicAttentionWindowSize, effectiveMaxAttentionWindowSize, 1, seq_offset, false,
qkvElementSize);
}();

QKVPreprocessingParams<void, KVBlockArray> qkvParams{};
Expand All @@ -676,7 +691,7 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace,
qkvParams.tokens_info = isMultiTokenGen ? views.tokensInfoPtr : nullptr;
qkvParams.seq_lens = isMultiTokenGen ? optPtr<int>(spec_decoding_generation_lengths) : nullptr;
qkvParams.cache_seq_lens = static_cast<int*>(sequence_lengths.data_ptr());
qkvParams.encoder_seq_lens = nullptr;
qkvParams.encoder_seq_lens = cross_attention ? static_cast<int*>(sequence_lengths.data_ptr()) : nullptr;
qkvParams.cu_seq_lens = buildDecoderInfoNeeded ? views.cuSeqlensPtr : nullptr;
qkvParams.cu_kv_seq_lens = buildDecoderInfoNeeded ? views.cuKvSeqlensPtr : nullptr;
qkvParams.sparse_kv_offsets = nullptr;
Expand All @@ -690,11 +705,11 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace,
qkvParams.batch_size = static_cast<int>(batch_beam);
qkvParams.max_input_seq_len = static_cast<int>(input_seq_length);
qkvParams.max_kv_seq_len = static_cast<int>(max_past_kv_length);
qkvParams.cyclic_kv_cache_len = static_cast<int>(cyclic_attention_window_size);
qkvParams.cyclic_kv_cache_len = static_cast<int>(effectiveCyclicAttentionWindowSize);
qkvParams.token_num = static_cast<int>(num_tokens);
qkvParams.remove_padding = true;
qkvParams.is_last_chunk = false;
qkvParams.cross_attention = false;
qkvParams.cross_attention = cross_attention;
qkvParams.head_num = static_cast<int>(num_heads);
qkvParams.kv_head_num = static_cast<int>(num_kv_heads);
qkvParams.qheads_per_kv_head = static_cast<int>(num_heads / num_kv_heads);
Expand Down Expand Up @@ -751,7 +766,9 @@ trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace,

auto qProcessed = views.qBuf.view({num_tokens, num_heads, head_size});

auto const windowLeft = computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size);
auto const windowLeft = cross_attention
? int64_t{-1}
: computeWindowLeft(cyclic_attention_window_size, max_past_kv_length, attention_chunk_size);
return {qProcessed, kvPool, blockTables, kvScalePool, views.bmm1Scale, views.bmm2Scale, views.trtllmGenWorkspace,
cuSeqlens, input_seq_length, max_past_kv_length, windowLeft, isMultiTokenGen};
}
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/attention_backend/fmha/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def forward(
block_ids_per_seq=metadata.block_ids_per_seq,
tokens_per_block=metadata.tokens_per_block,
max_num_requests=metadata.max_num_requests,
beam_width=metadata.effective_beam_width,
beam_width=metadata.beam_width,
use_paged_context_fmha=metadata.use_paged_context_fmha,
helix_position_offsets=metadata.helix_position_offsets,
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
Expand Down
Loading
Loading