Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,9 @@ class GQAAttentionBase {
args.buffer = reinterpret_cast<float*>(flash_buffer_alloc);
args.buffer_size_per_thread = buffer_size_per_thread;
args.query = Q;
args.q_batch_stride = packed_qkv
? static_cast<size_t>(packed_batch_stride)
: static_cast<size_t>(SafeInt<size_t>(num_heads_) * sequence_length * head_size);
args.k_cache = present_key_data;
args.v_cache = present_value_data;
args.k_scale = k_scale;
Expand Down Expand Up @@ -874,7 +877,11 @@ class GQAAttentionBase {
args.buffer_size_per_thread = buffer_size_per_thread;

// Offset Q and output for this batch
args.query = Q + static_cast<size_t>(b) * num_heads_ * sequence_length * head_size;
const ptrdiff_t q_batch_stride_elems = packed_batch_stride > 0
? packed_batch_stride
: static_cast<ptrdiff_t>(SafeInt<ptrdiff_t>(num_heads_) * sequence_length * head_size);
args.query = Q + static_cast<size_t>(SafeInt<size_t>(b) * static_cast<size_t>(q_batch_stride_elems));
args.q_batch_stride = static_cast<size_t>(q_batch_stride_elems);
args.k_cache = present_key_data +
static_cast<size_t>(b) * kv_num_heads_ * seqlen_present_kv_cache * packed_row_bytes;
args.v_cache = present_value_data +
Expand All @@ -884,10 +891,13 @@ class GQAAttentionBase {
args.output = output->MutableData<float>() +
static_cast<size_t>(b) * sequence_length * hidden_size;

// Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside)
// Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside).
// Bias shape is [batch|1, num_heads|1, S, T]; the batch stride uses the actual head
// extent (1 when the head dim is broadcast).
const float* batch_bias = attention_bias_data;
if (attention_bias_data != nullptr && !attention_bias_broadcast_batch) {
batch_bias += static_cast<size_t>(b) * num_heads_ * sequence_length * attention_bias_seqlen_stride;
const size_t bias_head_extent = attention_bias_broadcast_head ? 1 : static_cast<size_t>(num_heads_);
batch_bias += static_cast<size_t>(SafeInt<size_t>(b) * bias_head_extent * sequence_length * attention_bias_seqlen_stride);
}
args.attention_bias = batch_bias;
args.attention_bias_seqlen_stride = attention_bias_seqlen_stride;
Expand Down
62 changes: 32 additions & 30 deletions onnxruntime/core/mlas/inc/mlas_qkv_quant.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,46 +249,48 @@ MlasSVGemm(
* It avoids materializing the full [S, T] attention probability matrix.
*/
struct MlasFlashAttentionQuantizedKVArgs {
// NOTE(review): this listing is taken from a diff view, so it interleaves the
// removed (uninitialized) member declarations with the added (default-initialized)
// ones. Each group of bare declarations is immediately followed by its replacement.
int batch_size;
int num_heads; // Q heads
int kv_num_heads; // KV heads (for GQA sharing)
int sequence_length; // Q sequence length (new tokens)
int total_seqlen; // Total KV sequence length (past + new)
int head_size;
int past_seqlen; // For computing causal positions
int local_window_size; // -1 = disabled
int seqlen_present_kv; // Buffer dimension for present KV (may be > total_seqlen)
int q_block_size; // Br (query block size)
int kv_block_size; // Bc (KV block size)
float scale; // 1/sqrt(head_size) or user-specified
int batch_size = 0;
int num_heads = 0; // Q heads
int kv_num_heads = 0; // KV heads (for GQA sharing)
int sequence_length = 0; // Q sequence length (new tokens)
int total_seqlen = 0; // Total KV sequence length (past + new)
int head_size = 0;
int past_seqlen = 0; // For computing causal positions
int local_window_size = -1; // -1 = disabled
int seqlen_present_kv = 0; // Buffer dimension for present KV (may be > total_seqlen)
int q_block_size = 0; // Br (query block size)
int kv_block_size = 0; // Bc (KV block size)
float scale = 0.0f; // 1/sqrt(head_size) or user-specified

MLAS_KV_QUANT_TYPE quant_type;
bool per_channel_k; // Whether K uses per-channel scales
bool per_channel_v; // Whether V uses per-channel scales
MLAS_KV_QUANT_TYPE quant_type = MLAS_KV_QUANT_TYPE::S8_PerTensor;
bool per_channel_k = false; // Whether K uses per-channel scales
bool per_channel_v = false; // Whether V uses per-channel scales

int thread_count;
float* buffer;
size_t buffer_size_per_thread;
// NOTE(review): presumably `buffer` is scratch space carved into `thread_count`
// slices of `buffer_size_per_thread` each; units (bytes vs. floats) are not
// visible here -- confirm against the kernel before relying on this.
int thread_count = 1;
float* buffer = nullptr;
size_t buffer_size_per_thread = 0;

const float* query; // [B, N, S, H] FP32
const uint8_t* k_cache; // [B, kv_N, seqlen_present, packed_row_bytes] quantized
const uint8_t* v_cache; // [B, kv_N, seqlen_present, packed_row_bytes] quantized
const float* k_scale; // Scalar or per-channel scales for K
const float* v_scale; // Scalar or per-channel scales for V
float* output; // [B, S, N, H] FP32
const float* query = nullptr; // [B, N, S, H] FP32
size_t q_batch_stride = 0; // element stride between consecutive batches in `query`
// (num_heads*S*H for unpacked, (num_heads+2*kv_num_heads)*S*H for packed QKV)
// The kernels locate a batch's Q data at query + batch_idx * q_batch_stride, so
// leaving this at its default of 0 makes every batch read the same Q rows.
// Callers must set it explicitly whenever batch_size > 1.
const uint8_t* k_cache = nullptr; // [B, kv_N, seqlen_present, packed_row_bytes] quantized
const uint8_t* v_cache = nullptr; // [B, kv_N, seqlen_present, packed_row_bytes] quantized
const float* k_scale = nullptr; // Scalar or per-channel scales for K
const float* v_scale = nullptr; // Scalar or per-channel scales for V
float* output = nullptr; // [B, S, N, H] FP32

// Attention bias (additive, applied after QK GEMM before masking/softmax).
// Shape: [B|1, N|1, S, T] where dimensions of size 1 are broadcast.
const float* attention_bias; // nullptr if no bias
int attention_bias_seqlen_stride; // stride along the T (total_seqlen) dimension = shape[3]
bool attention_bias_broadcast_batch; // true if shape[0] == 1
bool attention_bias_broadcast_head; // true if shape[1] == 1
// With both broadcast flags left at their default of true, the kernels add no
// batch or head offset to `attention_bias`; an offset is only applied for a
// dimension whose flag is explicitly set to false.
const float* attention_bias = nullptr; // nullptr if no bias
int attention_bias_seqlen_stride = 0; // stride along the T (total_seqlen) dimension = shape[3]
bool attention_bias_broadcast_batch = true; // true if shape[0] == 1
bool attention_bias_broadcast_head = true; // true if shape[1] == 1

// Flash decoding fields (used when sequence_length == 1 and KV is split across threads).
// Partials buffer stores per-(batch, head, kv_chunk) intermediate results:
// [m_partial, l_partial, output_partial[head_size]] for each chunk.
float* flash_decoding_partials; // nullptr to disable flash decoding
int kv_chunk_count; // number of KV chunks = ceil(total_seqlen / kv_block_size)
float* flash_decoding_partials = nullptr; // nullptr to disable flash decoding
int kv_chunk_count = 0; // number of KV chunks = ceil(total_seqlen / kv_block_size)
};

/**
Expand Down
27 changes: 19 additions & 8 deletions onnxruntime/core/mlas/lib/flashattn_qkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ MlasFlashAttentionQuantizedKVThreaded(
? args->v_scale + kv_head_idx * static_cast<size_t>(head_size)
: args->v_scale;

// Q pointer: layout [batch, num_heads, seq, head_size] or packed
// Q pointer: layout [batch, num_heads, seq, head_size]. The batch stride is
// supplied separately (args->q_batch_stride) so the kernel works with both the
// standard BNSH layout and packed-QKV input where Q/K/V are interleaved per batch.
const float* q_ptr = args->query +
(static_cast<size_t>(batch_idx) * static_cast<size_t>(num_heads) +
static_cast<size_t>(head_idx)) * static_cast<size_t>(sequence_length) * static_cast<size_t>(head_size) +
static_cast<size_t>(batch_idx) * args->q_batch_stride +
static_cast<size_t>(head_idx) * static_cast<size_t>(sequence_length) * static_cast<size_t>(head_size) +
static_cast<size_t>(q_idx) * static_cast<size_t>(head_size);

// Iterate over KV blocks
Expand Down Expand Up @@ -162,10 +164,14 @@ MlasFlashAttentionQuantizedKVThreaded(
static_cast<ptrdiff_t>(args->attention_bias_seqlen_stride);
const ptrdiff_t bias_matrix_size =
static_cast<ptrdiff_t>(sequence_length) * bias_seqlen_stride;
// The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch
// stride uses the actual head extent (1 when the head dim is broadcast).
const ptrdiff_t bias_head_extent =
args->attention_bias_broadcast_head ? 1 : static_cast<ptrdiff_t>(num_heads);
ptrdiff_t bias_offset = 0;
if (!args->attention_bias_broadcast_batch) {
bias_offset += static_cast<ptrdiff_t>(batch_idx) *
static_cast<ptrdiff_t>(num_heads) * bias_matrix_size;
bias_head_extent * bias_matrix_size;
}
if (!args->attention_bias_broadcast_head) {
bias_offset += static_cast<ptrdiff_t>(head_idx) * bias_matrix_size;
Expand Down Expand Up @@ -378,10 +384,11 @@ MlasFlashDecodingQuantizedKVThreaded(
? args->v_scale + kv_head_idx * static_cast<size_t>(head_size)
: args->v_scale;

// Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1)
// Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1).
// The batch stride is supplied separately to support packed-QKV input.
const float* q_ptr = args->query +
(static_cast<size_t>(batch_idx) * static_cast<size_t>(num_heads) +
static_cast<size_t>(head_idx)) * static_cast<size_t>(head_size);
static_cast<size_t>(batch_idx) * args->q_batch_stride +
static_cast<size_t>(head_idx) * static_cast<size_t>(head_size);

// Step 1: QK^T GEMM for this KV chunk
const uint8_t* k_block = k_cache_head + static_cast<size_t>(ir) * packed_row_bytes;
Expand All @@ -405,10 +412,14 @@ MlasFlashDecodingQuantizedKVThreaded(
const ptrdiff_t bias_seqlen_stride =
static_cast<ptrdiff_t>(args->attention_bias_seqlen_stride);
const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1
// The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch stride
// uses the actual head extent (1 when the head dim is broadcast).
const ptrdiff_t bias_head_extent =
args->attention_bias_broadcast_head ? 1 : static_cast<ptrdiff_t>(num_heads);
ptrdiff_t bias_offset = 0;
if (!args->attention_bias_broadcast_batch) {
bias_offset += static_cast<ptrdiff_t>(batch_idx) *
static_cast<ptrdiff_t>(num_heads) * bias_matrix_size;
bias_head_extent * bias_matrix_size;
}
if (!args->attention_bias_broadcast_head) {
bias_offset += static_cast<ptrdiff_t>(head_idx) * bias_matrix_size;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/mlas/bench/bench_qkv_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ static void BM_GQA_Flash(benchmark::State& state) {
args.buffer = buffer.data();
args.buffer_size_per_thread = buffer_size_per_thread;
args.query = query.data();
args.q_batch_stride = static_cast<size_t>(num_heads) * seq_len * head_size;
args.k_cache = k_cache.data();
args.v_cache = v_cache.data();
args.k_scale = k_scale.data();
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/mlas/unittest/test_qkv_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ class MlasFlashAttentionQuantizedKVTest : public MlasTestBase {
args.buffer = flash_buffer;
args.buffer_size_per_thread = buffer_size_per_thread;
args.query = Q;
args.q_batch_stride = seq_len * head_size;
args.k_cache = k_quant;
args.v_cache = v_quant;
args.k_scale = k_scale;
Expand Down Expand Up @@ -592,6 +593,7 @@ class MlasFlashAttentionQuantizedKVTest : public MlasTestBase {
args.buffer = flash_buffer;
args.buffer_size_per_thread = buffer_size_per_thread;
args.query = Q;
args.q_batch_stride = head_size;
args.k_cache = k_quant;
args.v_cache = v_quant;
args.k_scale = k_scale_buf;
Expand Down
Loading
Loading