Skip to content

Commit 479dd39

Browse files
authored
Remove cudaStreamSynchronize from CUDA LLM ops for CUDA graph capture compatibility (microsoft#27484)
This pull request refactors validation logic for CUDA attention masks and tensor scatter operations to move error checking from host-side (CPU) to device-side (GPU) using CUDA kernel assertions (`CUDA_KERNEL_ASSERT`). This change eliminates synchronous host-device memory transfers and stream synchronizations, improving performance and simplifying code. Corresponding test cases are updated to only expect validation failures on the CPU, as CUDA errors are now asynchronous. Key changes: **Attention mask validation (GQA path):** - Removes host-side validation and memory copies for boolean attention masks in `attention.cc`; mask validity (right-padding, contiguous True/False) is now checked asynchronously via `CUDA_KERNEL_ASSERT` in the CUDA kernel. [[1]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL385-L387) [[2]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL414-L418) [[3]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL427-L448) - Updates the CUDA kernel and its interface to drop the `validation_result` buffer and rely on device assertions for mask validation. Documentation is updated to reflect this asynchronous error checking. [[1]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49L10-R17) [[2]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49L34) [[3]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49L81-R76) [[4]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49L104-R92) [[5]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49L118) [[6]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49L137) [[7]](diffhunk://#diff-8aa9a15a92d7dc138346dce5de055911895d940ba2183b4ba45bd95ac0e5bfc9L37-L45) **TensorScatter write_indices validation:** - Removes host-side validation and synchronization for `write_indices` in `tensorscatter.cc`; index bounds checking is now performed asynchronously inside the CUDA kernel via `CUDA_KERNEL_ASSERT`. [[1]](diffhunk://#diff-d69233ff3987fe3093132a31710b6b64cc0a32140e2a5a415a2f1f0907bd22d2L75-R76) [[2]](diffhunk://#diff-1694a04b8ba9963cc06d651ec6a3be8aa9cb2bcb73c2438dc251ca8cdcb2eb41L31-R37) **Test updates:** - Updates negative test cases for `TensorScatter` to run only on CPU, since CUDA now validates asynchronously and will not synchronously return errors to the host. [[1]](diffhunk://#diff-8c90e642cc0cf4e68b2f3d4e4b3f1e21bf6d07f01663d424bc52c75ad0db2dfeR300) [[2]](diffhunk://#diff-8c90e642cc0cf4e68b2f3d4e4b3f1e21bf6d07f01663d424bc52c75ad0db2dfeL311-R319) [[3]](diffhunk://#diff-8c90e642cc0cf4e68b2f3d4e4b3f1e21bf6d07f01663d424bc52c75ad0db2dfeL327-R339) [[4]](diffhunk://#diff-8c90e642cc0cf4e68b2f3d4e4b3f1e21bf6d07f01663d424bc52c75ad0db2dfeL342-R354)
1 parent 5f94d6c commit 479dd39

6 files changed

Lines changed: 30 additions & 77 deletions

File tree

onnxruntime/core/providers/cuda/llm/attention.cc

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
382382
// the MHA path below, where 2D masks follow ONNX broadcasting: [A, B] → [1, 1, A, B], so
383383
// 2D = (q_seq_len, total_seq_len) with both batch and heads broadcast.
384384
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
385-
// Allocate validation result buffer on GPU
386-
auto validation_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream());
387-
388385
// Get mask dimensions for broadcasting
389386
// attn_mask can be 2D, 3D, or 4D and broadcasts to (batch_size, num_heads, q_seq_len, total_seq_len)
390387
const auto& mask_shape = attn_mask->Shape();
@@ -411,11 +408,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
411408
"Boolean attn_mask must be 2D, 3D, or 4D. Got ", mask_dims, "D.");
412409
}
413410

414-
// Launch CUDA kernel to convert mask to seqlens_k and validate
411+
// Launch CUDA kernel to convert mask to seqlens_k.
412+
// Mask validity (right-padding, contiguous) is checked asynchronously via CUDA_KERNEL_ASSERT.
415413
ORT_RETURN_IF_ERROR(LaunchConvertMaskToSeqlensK(
416414
attn_mask->Data<bool>(),
417415
seqlens_k_buffer.get(),
418-
validation_buffer.get(),
419416
parameters.batch_size,
420417
parameters.total_sequence_length,
421418
mask_dims,
@@ -424,28 +421,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
424421
mask_dim2,
425422
cuda_stream,
426423
device_prop.maxThreadsPerBlock));
427-
428-
// Copy validation results to CPU and check for errors
429-
std::vector<int> validation_host(parameters.batch_size);
430-
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(validation_host.data(), validation_buffer.get(),
431-
sizeof(int) * parameters.batch_size,
432-
cudaMemcpyDeviceToHost, cuda_stream));
433-
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
434-
435-
for (int b = 0; b < parameters.batch_size; ++b) {
436-
if (validation_host[b] == 1) {
437-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
438-
"Boolean attn_mask for batch ", b,
439-
" does not start with True. "
440-
"GQA path only supports right-padding masks where valid tokens come first.");
441-
} else if (validation_host[b] == 2) {
442-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
443-
"Boolean attn_mask for batch ", b,
444-
" is not contiguous. "
445-
"GQA path only supports right-padding masks with contiguous True values "
446-
"followed by contiguous False values (no interleaving).");
447-
}
448-
}
449424
} else if (attn_mask != nullptr) {
450425
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
451426
"Non-boolean attn_mask is not supported yet in GQA path of Attention op (CUDA).");

onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,14 @@
77
namespace onnxruntime {
88
namespace cuda {
99

10-
// Validation error codes (stored in validation_result buffer)
11-
constexpr int kValidationOK = 0;
12-
constexpr int kValidationErrorNotStartWithTrue = 1;
13-
constexpr int kValidationErrorNotContiguous = 2;
14-
1510
// CUDA kernel to convert boolean attention mask to sequence lengths.
16-
// Also validates that the mask follows right-padding convention.
11+
// Also validates that the mask follows right-padding convention via CUDA_KERNEL_ASSERT.
1712
//
1813
// The kernel processes one batch per thread.
1914
// For each batch, it finds the first False in the mask row, which indicates
2015
// where padding starts. The sequence length is the index of first False.
2116
//
22-
// Validation:
17+
// Validation (via CUDA_KERNEL_ASSERT, reported asynchronously):
2318
// - The mask must start with True (first element must be True)
2419
// - After the first False, all remaining elements must be False (contiguous padding)
2520
//
@@ -31,7 +26,6 @@ constexpr int kValidationErrorNotContiguous = 2;
3126
__global__ void ConvertMaskToSeqlensKernel(
3227
const bool* __restrict__ attn_mask,
3328
int* __restrict__ seqlens_k,
34-
int* __restrict__ validation_result,
3529
const int batch_size,
3630
const int total_seq_len,
3731
const int mask_dims,
@@ -78,15 +72,8 @@ __global__ void ConvertMaskToSeqlensKernel(
7872
mask_row = attn_mask + effective_batch * batch_stride + h_idx * head_stride + q_idx * q_stride;
7973
}
8074

81-
// Initialize validation result for this batch
82-
validation_result[batch_idx] = kValidationOK;
83-
84-
// Check that mask starts with True
85-
if (!mask_row[0]) {
86-
validation_result[batch_idx] = kValidationErrorNotStartWithTrue;
87-
seqlens_k[batch_idx] = -1; // Invalid
88-
return;
89-
}
75+
// Validate that mask starts with True (right-padding convention)
76+
CUDA_KERNEL_ASSERT(mask_row[0]); // mask must start with True
9077

9178
// Find the first False (where padding starts)
9279
// All elements before this should be True, all after should be False
@@ -101,10 +88,8 @@ __global__ void ConvertMaskToSeqlensKernel(
10188
seq_len = i;
10289
found_first_false = true;
10390
} else if (found_first_false && current) {
104-
// Found True after False - this is invalid (not contiguous)
105-
validation_result[batch_idx] = kValidationErrorNotContiguous;
106-
seqlens_k[batch_idx] = -1; // Invalid
107-
return;
91+
// Found True after False - mask is not contiguous (invalid)
92+
CUDA_KERNEL_ASSERT(false); // mask must be contiguous (no True after False)
10893
}
10994
}
11095

@@ -115,7 +100,6 @@ __global__ void ConvertMaskToSeqlensKernel(
115100
Status LaunchConvertMaskToSeqlensK(
116101
const bool* attn_mask_bool,
117102
int* seqlens_k,
118-
int* validation_result,
119103
int batch_size,
120104
int total_seq_len,
121105
int mask_dims,
@@ -134,7 +118,6 @@ Status LaunchConvertMaskToSeqlensK(
134118
ConvertMaskToSeqlensKernel<<<blocks, threads, 0, stream>>>(
135119
attn_mask_bool,
136120
seqlens_k,
137-
validation_result,
138121
batch_size,
139122
total_seq_len,
140123
mask_dims,

onnxruntime/core/providers/cuda/llm/attention_mask_impl.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,13 @@ namespace cuda {
3434
//
3535
// Returns:
3636
// Status::OK() on success
37-
// Error status if mask is invalid (not right-padding, doesn't start with True, etc.)
3837
//
39-
// Note: This function validates the mask on GPU and will return an error if:
40-
// - The mask doesn't start with True for any batch
41-
// - The True/False values are not contiguous (e.g., True, False, True pattern)
38+
// Note: Mask validity (right-padding convention, starts with True, contiguous True/False)
39+
// is checked asynchronously via CUDA_KERNEL_ASSERT inside the kernel. Invalid masks will
40+
// trigger a device-side assertion failure.
4241
Status LaunchConvertMaskToSeqlensK(
4342
const bool* attn_mask_bool,
4443
int* seqlens_k,
45-
int* validation_result, // GPU buffer for validation, size = batch_size
4644
int batch_size,
4745
int total_seq_len,
4846
int mask_dims,

onnxruntime/core/providers/cuda/llm/tensorscatter.cc

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,24 +72,8 @@ Status TensorScatter::ComputeInternal(OpKernelContext* context) const {
7272
write_indices_tensor->Shape()[0] == batch_size,
7373
"TensorScatter: write_indices must have shape [batch_size]");
7474
write_indices = write_indices_tensor->Data<int64_t>();
75-
76-
// Copy write_indices to host for validation (batch_size elements, negligible overhead).
77-
std::vector<int64_t> host_write_indices(static_cast<size_t>(batch_size));
78-
CUDA_RETURN_IF_ERROR(
79-
cudaMemcpyAsync(host_write_indices.data(), write_indices,
80-
static_cast<size_t>(batch_size) * sizeof(int64_t),
81-
cudaMemcpyDeviceToHost, Stream(context)));
82-
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context)));
83-
84-
for (int64_t b = 0; b < batch_size; ++b) {
85-
int64_t wi = host_write_indices[static_cast<size_t>(b)];
86-
ORT_ENFORCE(wi >= 0, "TensorScatter: write_indices[", b, "] = ", wi, " is negative");
87-
if (!circular_) {
88-
ORT_ENFORCE(wi + sequence_length <= max_sequence_length,
89-
"TensorScatter linear mode: write_indices[", b, "] + sequence_length (",
90-
wi, " + ", sequence_length, ") exceeds max_sequence_length (", max_sequence_length, ")");
91-
}
92-
}
75+
// write_indices values (non-negative, in-bounds) are validated asynchronously
76+
// inside the CUDA kernel via CUDA_KERNEL_ASSERT to avoid cudaStreamSynchronize.
9377
}
9478

9579
// Allocate output with the same shape as past_cache.

onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ __global__ void _TensorScatterKernel(
2828

2929
int64_t batch_idx = prefix_idx / prefix_stride_for_batch;
3030
int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
31-
// write_indices are validated on the host before kernel launch.
31+
CUDA_KERNEL_ASSERT(wi >= 0);
3232
int64_t cache_pos;
3333
if (circular) {
3434
cache_pos = (wi + seq_idx) % max_seq_len;
3535
} else {
3636
cache_pos = wi + seq_idx;
37+
CUDA_KERNEL_ASSERT(cache_pos < max_seq_len);
3738
}
3839

3940
int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;

onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ TEST(TensorScatterTest, InPlace_IOBinding) {
297297
}
298298

299299
// Negative write_indices should fail validation.
300+
// Run CPU-only: CUDA validates asynchronously via CUDA_KERNEL_ASSERT.
300301
TEST(TensorScatterTest, Linear_NegativeWriteIndex) {
301302
OpTester test("TensorScatter", 24);
302303
test.AddAttribute<std::string>("mode", "linear");
@@ -308,10 +309,14 @@ TEST(TensorScatterTest, Linear_NegativeWriteIndex) {
308309
test.AddOutput<float>("present_cache", {1, 4, 3},
309310
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
310311

311-
test.Run(OpTester::ExpectResult::kExpectFailure, "is negative");
312+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
313+
execution_providers.push_back(DefaultCpuExecutionProvider());
314+
test.Run(OpTester::ExpectResult::kExpectFailure, "is negative",
315+
{}, nullptr, &execution_providers);
312316
}
313317

314318
// Linear mode: write_indices + sequence_length > max_sequence_length should fail.
319+
// Run CPU-only: CUDA validates asynchronously via CUDA_KERNEL_ASSERT.
315320
TEST(TensorScatterTest, Linear_OutOfBoundsWriteIndex) {
316321
OpTester test("TensorScatter", 24);
317322
test.AddAttribute<std::string>("mode", "linear");
@@ -324,10 +329,14 @@ TEST(TensorScatterTest, Linear_OutOfBoundsWriteIndex) {
324329
test.AddOutput<float>("present_cache", {1, 4, 3},
325330
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
326331

327-
test.Run(OpTester::ExpectResult::kExpectFailure, "exceeds max_sequence_length");
332+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
333+
execution_providers.push_back(DefaultCpuExecutionProvider());
334+
test.Run(OpTester::ExpectResult::kExpectFailure, "exceeds max_sequence_length",
335+
{}, nullptr, &execution_providers);
328336
}
329337

330338
// Circular mode: negative write_indices should still fail.
339+
// Run CPU-only: CUDA validates asynchronously via CUDA_KERNEL_ASSERT.
331340
TEST(TensorScatterTest, Circular_NegativeWriteIndex) {
332341
OpTester test("TensorScatter", 24);
333342
test.AddAttribute<std::string>("mode", "circular");
@@ -339,7 +348,10 @@ TEST(TensorScatterTest, Circular_NegativeWriteIndex) {
339348
test.AddOutput<float>("present_cache", {1, 4, 2},
340349
{0, 0, 0, 0, 0, 0, 0, 0});
341350

342-
test.Run(OpTester::ExpectResult::kExpectFailure, "is negative");
351+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
352+
execution_providers.push_back(DefaultCpuExecutionProvider());
353+
test.Run(OpTester::ExpectResult::kExpectFailure, "is negative",
354+
{}, nullptr, &execution_providers);
343355
}
344356

345357
} // namespace test

0 commit comments

Comments
 (0)