Commit 479dd39
Remove cudaStreamSynchronize from CUDA LLM ops for CUDA graph capture compatibility (microsoft#27484)
This pull request moves validation for CUDA attention masks and tensor
scatter operations from the host (CPU) to the device (GPU), using CUDA
kernel assertions (`CUDA_KERNEL_ASSERT`). This removes the synchronous
host-device memory transfers and stream synchronizations that the
host-side checks required, which improves performance, simplifies the
code, and keeps these ops compatible with CUDA graph capture. Test cases
are updated to expect validation failures only on the CPU, because CUDA
validation errors are now reported asynchronously.
Key changes:
**Attention mask validation (GQA path):**
- Removes host-side validation and memory copies for boolean attention
masks in `attention.cc`; mask validity (right-padding, contiguous
True/False) is now checked asynchronously via `CUDA_KERNEL_ASSERT` in
the CUDA kernel.
- Updates the CUDA kernel and its interface to drop the
`validation_result` buffer and rely on device assertions for mask
validation. Documentation is updated to reflect this asynchronous error
checking.
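The mask rule named above (right-padding, contiguous True/False) can be sketched as a host-side reference in plain C++. This is not the kernel from this commit: the function names `IsRightPaddedRow`, `ValidLength`, and `CheckRow` are hypothetical, `assert` stands in for `CUDA_KERNEL_ASSERT`, and accepting all-True and all-False rows is an assumption the source does not confirm.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of the per-row predicate: a boolean mask row
// is right-padded when every True comes before every False.
// Assumption: all-True, all-False, and empty rows are accepted.
bool IsRightPaddedRow(const std::vector<bool>& row) {
  bool seen_false = false;
  for (bool v : row) {
    if (!v) {
      seen_false = true;
    } else if (seen_false) {
      return false;  // A True after a False breaks the contiguous True run.
    }
  }
  return true;
}

// Number of leading True entries, i.e. the unpadded length of a valid row.
std::size_t ValidLength(const std::vector<bool>& row) {
  std::size_t n = 0;
  while (n < row.size() && row[n]) ++n;
  return n;
}

// Host stand-in for a device-side assertion: aborts instead of returning
// a status, so the caller never sees a recoverable error.
void CheckRow(const std::vector<bool>& row) {
  assert(IsRightPaddedRow(row));
}
```

For example, `{true, true, false, false}` passes with a valid length of 2, while `{true, false, true}` fails. The assert-based `CheckRow` mirrors the trade-off described above: the check no longer produces a status the host can inspect synchronously.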
**TensorScatter write_indices validation:**
- Removes host-side validation and synchronization for `write_indices`
in `tensorscatter.cc`; index bounds checking is now performed
asynchronously inside the CUDA kernel via `CUDA_KERNEL_ASSERT`.
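A bounds check of this kind might look like the following host-side sketch in plain C++. The exact rule is an assumption, since the source only says "index bounds checking": here a write that starts at `index` and covers `update_len` positions must fit inside a cache of `cache_len` positions. The names `IsWriteIndexInBounds` and `CheckWriteIndices` are hypothetical, and `assert` stands in for `CUDA_KERNEL_ASSERT`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed rule (not confirmed by the source): the span
// [index, index + update_len) must lie inside [0, cache_len).
// Written as index <= cache_len - update_len to avoid overflow in the sum.
bool IsWriteIndexInBounds(int64_t index, int64_t update_len, int64_t cache_len) {
  return index >= 0 && update_len >= 0 && index <= cache_len - update_len;
}

// Checks one index per batch entry. On the device this would be an
// assertion inside the kernel, with no status returned to the host.
void CheckWriteIndices(const std::vector<int64_t>& write_indices,
                       int64_t update_len, int64_t cache_len) {
  for (int64_t index : write_indices) {
    assert(IsWriteIndexInBounds(index, update_len, cache_len));
  }
}
```

With `update_len = 4` and `cache_len = 8`, start indices 0 through 4 pass, while 5 and -1 fail.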
**Test updates:**
- Updates negative test cases for `TensorScatter` to run only on CPU,
since CUDA now validates asynchronously and will not synchronously
return errors to the host.
1 parent 5f94d6c commit 479dd39
6 files changed
Lines changed: 30 additions & 77 deletions
File tree
- onnxruntime
- core/providers/cuda/llm
- test/providers/cpu/llm
Lines changed: 6 additions & 23 deletions
Lines changed: 15 additions & 3 deletions