[webgpu] Per-graph buffer manager for WebGPU multi-graph capture #28260
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -103,6 +103,16 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&` |
| | | `TensorShape output_qk_shape(output_qk_dims);` |
| | | `Tensor* output_qk = context.Output(3, output_qk_shape);` |
| | | |
| | | `// Match CPU EP semantics: when no present_key/present_value output is requested,` |
| | | `// ignore past_key/past_value. The CPU EP sets past_sequence_length=0 in this case,` |
| | | `// effectively treating the input as if there is no KV cache.` |
| | | `if (present_key == nullptr && present_value == nullptr) {` |
| | | `  past_key = nullptr;` |
| | | `  past_value = nullptr;` |
| | | `  parameters.past_sequence_length_ = 0;` |
| | | `  parameters.total_sequence_length_ = parameters.kv_sequence_length_;` |
| | | `}` |
| | | |
Comment on lines +109 to +114

Contributor

Was this change meant to be a part of this PR? It seems unrelated.

Contributor (Author)

Yes, this is intentional. It fixes the CI test
| | | `if (output_qk == nullptr &&  // Flash attention does not output QK scores` |
| | | `    CanApplyFlashAttention(parameters, context)) {` |
| | | `  if (bias != nullptr) {` |
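
To make the intent of the `MultiHeadAttention` hunk above easier to check in isolation, here is a minimal sketch of the same rule outside of ONNX Runtime. `Tensor`, `AttentionParams`, and `IgnorePastWithoutPresent` are hypothetical stand-ins invented for this illustration; they are not the real ORT types or functions, and only the four assignments mirror the diff.

```cpp
#include <cstdint>

// Hypothetical stand-in for an ORT tensor; only pointer identity matters here.
struct Tensor {};

// Hypothetical subset of the attention parameters touched by the diff.
struct AttentionParams {
  int64_t past_sequence_length = 0;
  int64_t kv_sequence_length = 0;
  int64_t total_sequence_length = 0;
};

// When neither present_key nor present_value is requested, ignore the past
// KV inputs and treat the call as if there were no KV cache.
inline void IgnorePastWithoutPresent(const Tensor* present_key,
                                     const Tensor* present_value,
                                     const Tensor*& past_key,
                                     const Tensor*& past_value,
                                     AttentionParams& params) {
  if (present_key == nullptr && present_value == nullptr) {
    past_key = nullptr;
    past_value = nullptr;
    params.past_sequence_length = 0;
    params.total_sequence_length = params.kv_sequence_length;
  }
}
```

If either present output is requested, the helper leaves the past inputs and the sequence lengths untouched, which is the behavior the diff preserves for the KV-cache path.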
Do we need to account for calls that may happen concurrently with `Ort::Session::Run()`? In general (though not for the WebGPU EP), `Run()` itself may be called concurrently.

Thread safety for concurrent `ReleaseCapturedGraph` + `Run()` is an EP implementation detail rather than a C API contract concern. For the WebGPU EP, `Run()` itself is not concurrent (single GPU queue), so this is not a realistic scenario. For EPs that do support concurrent `Run()` (e.g., CUDA), the EP would need to handle synchronization internally in its `ReleaseCapturedGraph` implementation, the same as for other EP callbacks. No other session-mutating APIs (e.g., `SetEpDynamicOptions`) document thread-safety constraints at the C API level, so adding one here would be inconsistent.

If an EP reports that `ConcurrentRunSupported()` is false, ORT will lock a mutex during the call to `Ort::Session::Run()` to ensure that only one actual run is happening at a time. This doesn't prevent users from calling `Run()` concurrently.

It may not make sense for an application to make concurrent calls to `Run()` + `ReleaseCapturedGraph()` if the WebGPU EP is the only EP to consider, but an application might support more than just the WebGPU EP.

I don't have a good answer to this now, but I think it may be worth some thought. If we provide no synchronization in ORT or EP code, I think the assumption should at least be documented somewhere.

Done. `ReleaseCapturedGraph` now acquires `session_mutex_` before delegating to the EP. For EPs like WebGPU where `ConcurrentRunSupported()` returns false, this provides mutual exclusion with `Run()`, since `Run()` also acquires `session_mutex_` in that case. For EPs that support concurrent runs, `Run()` does not acquire `session_mutex_`, so additional synchronization would need to be handled at the EP level if needed.
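
For readers following the thread, the locking arrangement described in the last reply can be sketched roughly as below. This is only an illustration of the pattern, not the actual ORT implementation: `SessionSketch`, its counters, and the `graph_id` parameter are hypothetical, and only the names `session_mutex_`, `Run`, `ReleaseCapturedGraph`, and the `ConcurrentRunSupported()` concept come from the discussion.

```cpp
#include <atomic>
#include <mutex>

// Hypothetical model of a session; not the real InferenceSession class.
class SessionSketch {
 public:
  explicit SessionSketch(bool concurrent_run_supported)
      : concurrent_run_supported_(concurrent_run_supported) {}

  // Run() takes session_mutex_ only when the EP does not support concurrent runs.
  void Run() {
    std::unique_lock<std::mutex> lock(session_mutex_, std::defer_lock);
    if (!concurrent_run_supported_) {
      lock.lock();
    }
    ++run_count_;  // Stand-in for executing the graph on the EP.
  }

  // ReleaseCapturedGraph() always takes session_mutex_ before delegating to the EP.
  void ReleaseCapturedGraph(int /*graph_id*/) {
    std::lock_guard<std::mutex> lock(session_mutex_);
    ++release_count_;  // Stand-in for the EP-level release callback.
  }

  int run_count() const { return run_count_.load(); }
  int release_count() const { return release_count_.load(); }

 private:
  const bool concurrent_run_supported_;
  std::mutex session_mutex_;
  // Atomic so the counters stay well-defined when Run() skips the mutex.
  std::atomic<int> run_count_{0};
  std::atomic<int> release_count_{0};
};
```

With `concurrent_run_supported == false` (the WebGPU case in the thread), `Run()` and `ReleaseCapturedGraph()` exclude each other through the shared mutex. With `true`, only `ReleaseCapturedGraph()` calls are serialized against each other, which is why the reply notes that such EPs would need their own synchronization.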