Skip to content

Commit 58a87dc

Browse files
authored
Add CUDA Graph support for the CUDA plugin EP (#28002)
## Description This PR brings CUDA graph capture/replay to the CUDA plugin execution provider so plugin-based CUDA deployments can get the same reduced CPU launch overhead that the in-tree CUDA EP already supports. It also adds the ORT framework and plugin-C-API plumbing needed to let graph-enabled plugin EPs participate safely in warmup, capture, and replay, while preserving compatibility with older plugins through version-gated fallbacks. ## Summary of Changes ### CUDA plugin EP runtime and allocator integration | File | Change | |------|--------| | `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implements plugin-side graph capture lifecycle callbacks, per-thread graph context management, graph replay, and stream selection for graph-enabled runs. | | `onnxruntime/core/providers/cuda/plugin/cuda_ep.h` | Adds CUDA graph configuration/state to the plugin EP, including per-thread graph context ownership. | | `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc` | Adds `CudaGraphSet`/`CudaGraphManager` to own captured graphs and coordinate warmup, capture, and replay by annotation ID. | | `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h` | Declares the new graph manager types and graph-related constants. | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Adds external-stream wrapping so graph-enabled runs can reuse the thread’s graph stream without taking ownership of it. | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Declares the external-stream initialization path and stream ownership tracking. | | `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Parses `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` provider/session options for the plugin EP. | | `onnxruntime/core/providers/cuda/plugin/cuda_mempool_allocator_plugin.cc` | Updates allocator behavior needed for CUDA native mempool compatibility during graph capture/replay. | | `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | Adjusts plugin kernel/device helpers used by the graph-enabled execution path. | | `onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h` | Adds supporting helpers used by the plugin CUDA graph flow. | ### ORT framework and plugin API support for graph replay | File | Change | |------|--------| | `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Documents and extends the plugin EP contract for graph-enabled runs, including replay behavior relative to `OnRunStart`/`OnRunEnd`. | | `include/onnxruntime/core/framework/execution_provider.h` | Adds graph-capture node-assignment policy support to the execution provider interface. | | `onnxruntime/core/session/inference_session.cc` | Generalizes the session replay path and warmup/capture retry loop so ORT can drive graph replay for graph-capable EPs. | | `onnxruntime/core/session/inference_session.h` | Updates replay-related messaging and supporting declarations for the new run flow. | | `onnxruntime/core/framework/session_state.cc` | Makes device-stream collection reuse thread-affine so warmup/capture/replay reuse stays on the owning thread. | | `onnxruntime/core/framework/session_state.h` | Adds supporting state for the thread-affine stream collection pool. | | `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc` | Bridges the new graph callbacks, hardens validation of plugin graph support, and exposes effective plugin provider options gathered from session config. | | `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h` | Stores provider options and declares the new accessor/graph bridge behavior. | | `onnxruntime/core/providers/webgpu/webgpu_execution_provider.h` | Aligns graph-capture policy support with the new execution-provider interface. | | `onnxruntime/core/providers/js/js_execution_provider.h` | Aligns graph-capture policy support with the new execution-provider interface. | ### Tests and validation coverage | File | Change | |------|--------| | `onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` | Adds end-to-end CUDA graph tests for warmup/capture/replay, replay after input updates, CUDA mempool mode, multiple graph annotation IDs, multi-GPU/device-id coverage, and a simple Add model. | ### Documentation | File | Change | |------|--------| | `docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md` | Adds a dedicated design/implementation document covering architecture, lifecycle, allocator interaction, concurrency, and verification guidance. | | `docs/cuda_plugin_ep/cuda_plugin_ep_design.md` | Updates the broader plugin EP design doc to reflect that CUDA graph support is implemented and documents the framework-level changes. | | `docs/cuda_plugin_ep/QUICK_START.md` | Updates quick-start/testing guidance and removes the outdated “no CUDA Graph support” limitation. | ## Testing - Build ONNX Runtime with `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON`, install the generated wheel, and deploy the CUDA plugin shared library as described in `docs/cuda_plugin_ep/QUICK_START.md`. - Run `python onnxruntime/test/python/transformers/test_cuda_plugin_ep.py`. - Pay particular attention to the new CUDA graph scenarios in that suite: warmup/capture/replay, replay after in-place input updates, CUDA mempool mode, multiple `gpu_graph_id` captures, and the second-device path when multiple GPUs are available. - Verify backward compatibility by confirming older plugins still load safely through the version-gated graph callback bridge, and that graph-disabled runs continue through the normal execution path. ## Motivation and Context The CUDA plugin EP exists to decouple CUDA EP delivery from core ONNX Runtime releases, but that model only works well if important runtime optimizations are also available through the plugin path. CUDA graph replay is one of the highest-value CUDA execution optimizations because it eliminates repeated kernel-launch overhead after capture, especially for steady-state inference workloads. Supporting that in the plugin EP required more than adding plugin-local capture code. ORT also needed a framework-level replay flow that works for plugin EPs, a plugin C API contract for graph support and node-assignment policy, and thread-affine stream reuse so captured graph resources and stream wrappers are not reused across unrelated threads. This PR packages those pieces together and documents the resulting behavior for future plugin EP work. It also depends on earlier plugin allocator work so warmup can stabilize allocations before capture begins. ## Checklist - [x] Tests added/updated - [x] Documentation updated (if applicable) - [x] No breaking changes (or documented in description)
1 parent ce91376 commit 58a87dc

24 files changed

Lines changed: 1365 additions & 223 deletions

docs/cuda_plugin_ep/QUICK_START.md

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ build.bat --cmake_generator "Visual Studio 17 2022" --config Release --build_whe
1212
--cudnn_home "D:\path\to\cudnn-installation-root" ^
1313
--use_vcpkg --use_binskim_compliant_compile_flags ^
1414
--cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=native" ^
15-
--cmake_extra_defines "onnxruntime_BUILD_UNIT_TESTS=ON" ^
1615
--cmake_extra_defines "onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON"
1716
```
1817

@@ -106,7 +105,7 @@ The focused validation script for the CUDA Plugin EP is `onnxruntime/test/python
106105

107106
### Test prerequisites
108107

109-
- Build ONNX Runtime with `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON` and `onnxruntime_BUILD_UNIT_TESTS=ON`.
108+
- Build ONNX Runtime with `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON`.
110109
- Install the built ONNX Runtime wheel.
111110
- Install Python test dependencies. `test_cuda_plugin_ep.py` uses PyTorch for CPU-side reference computations, so CPU-only PyTorch is sufficient.
112111

@@ -151,16 +150,10 @@ python test_cuda_plugin_ep.py
151150

152151
The script validates plugin registration, device enumeration, provider options, operator coverage, and that key nodes are actually assigned to `CudaPluginExecutionProvider`.
153152

154-
## Known Limitations
155-
* The plugin does not currently support CUDA Graphs.
156-
* The plugin direct-allocates memory using `cudaMalloc` resulting in a potential performance penalty compared to the integrated Memory Arena.
157153

158154
## Verification
159155
You can generate a parity report comparing the kernels available in the plugin EP versus the statically linked CUDA EP.
160156
```bash
161-
# Check static source registration parity:
162-
python tools/ci_build/cuda_plugin_parity_report.py
163-
164157
# Check runtime registry parity:
165158
python tools/ci_build/cuda_plugin_parity_report.py --runtime --plugin-ep-lib build/Linux/RelWithDebInfo/libonnxruntime_providers_cuda_plugin.so
166159
```
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# CUDA Graph Support for CUDA Plugin EP
2+
3+
## Design Overview
4+
5+
### Background
6+
7+
The CUDA Plugin EP is a standalone shared library (`libonnxruntime_providers_cuda_plugin.so`) that implements the OrtEp C API, allowing CUDA EP updates independent of ORT releases. CUDA graph capture/replay is a critical performance optimization that records a sequence of GPU operations into a graph, then replays it with minimal CPU overhead on subsequent runs.
8+
9+
The OrtEp C API (v1.26+) provides four graph-capture callbacks:
10+
11+
| Callback | Signature | Purpose |
12+
|----------|-----------|---------|
13+
| `IsGraphCaptureEnabled` | `bool(const OrtEp*)` | Report whether graph capture is enabled |
14+
| `IsGraphCaptured` | `bool(const OrtEp*, int graph_annotation_id)` | Check if a graph has been captured for a given annotation ID |
15+
| `ReplayGraph` | `OrtStatus*(OrtEp*, int graph_annotation_id)` | Launch a previously captured graph |
16+
| `GetGraphCaptureNodeAssignmentPolicy` | `OrtGraphCaptureNodeAssignmentPolicy(const OrtEp*)` | Specify validation strictness for node assignment |
17+
18+
These are supplemented by the existing `OnRunStart` / `OnRunEnd` lifecycle callbacks that drive the capture workflow.
19+
20+
### Architecture
21+
22+
```
23+
Session::Run()
24+
25+
├─ Run 1..N (warmup): OnRunStart → kernel dispatch → OnRunEnd (increment counter)
26+
27+
├─ Run N+1 (capture): OnRunStart → cudaStreamBeginCapture → kernel dispatch
28+
│ → OnRunEnd → cudaStreamEndCapture → cudaGraphInstantiate → Replay
29+
30+
└─ Run N+2+ (replay): IsGraphCaptured() → true → ReplayGraph() → cudaGraphLaunch
31+
(OnRunStart/OnRunEnd are NOT called during replay)
32+
```
33+
34+
**Key design choices:**
35+
36+
- Each thread gets its own dedicated graph `cudaStream_t`, `CudaGraphManager`, and capture bookkeeping for the EP instance. `CudaSyncStream::InitHandlesWithExternalStream()` wraps the thread's graph stream so graph capture sees the same stream as kernels. The manager stores captured `cudaGraphExec_t` executables keyed by annotation ID, allowing multiple graphs (e.g., different input shapes) for that thread.
37+
- Warm-up runs (default: 2) allow memory allocations to stabilize before capture begins.
38+
- Graph annotation IDs are parsed from `OrtRunOptions` key `"gpu_graph_id"`. ID `-1` skips capture; `0` is the default.
39+
40+
### New Components
41+
42+
- **`CudaGraphSet`** — Hash map storage for `cudaGraphExec_t`, keyed by annotation ID. Owns the CUDA graph exec resources.
43+
- **`CudaGraphManager`** — Orchestrates capture lifecycle: `CaptureBegin()`, `CaptureEnd()`, `Replay()`, warm-up tracking via `IncrementRegularRunCount()` / `IsGraphCaptureAllowed()`.
44+
- **`CudaEp::PerThreadContext`** — Per-thread owner for the graph stream, `CudaGraphManager`, and the pre-capture free-memory watermark. The context is owned by a thread-local cache keyed by `CudaEp*`, so it is destroyed automatically when that thread exits. `CudaEp` keeps weak references to live thread-local cache maps only so it can erase its entry during EP teardown, and it prunes expired cache-map references while creating new contexts.
45+
- **`CudaSyncStream::InitHandlesWithExternalStream()`** — Wraps an external (non-owned) `cudaStream_t` for registration/lifecycle tracking. Migrated kernels bind cuBLAS/cuDNN/cuBLASLt through thread-local fallback handles at dispatch time when the wrapper does not own library handles.
46+
47+
### Config Options
48+
49+
| Option Key | Type | Default | Description |
50+
|-----------|------|---------|-------------|
51+
| `ep.cudapluginexecutionprovider.enable_cuda_graph` | bool | false | Enable CUDA graph capture/replay |
52+
| `ep.cudapluginexecutionprovider.min_num_runs_before_cuda_graph_capture` | int | 2 | Warmup runs before capture |
53+
54+
Legacy aliases `ep.cuda.enable_cuda_graph` and `enable_cuda_graph` are also supported. For the warm-up count, `ep.cuda.min_num_runs_before_cuda_graph_capture` is also accepted.
55+
56+
---
57+
58+
## Implementation Summary
59+
60+
### Files Changed
61+
62+
| File | Change |
63+
|------|--------|
64+
| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implemented graph capture callbacks (`OnRunStartImpl`, `OnRunEndImpl`, `IsGraphCaptureEnabledImpl`, `IsGraphCapturedImpl`, `ReplayGraphImpl`, `IsConcurrentRunSupportedImpl`), updated `CreateSyncStreamForDeviceImpl` to use the current thread's graph stream when graph capture is enabled, added per-thread graph state, preserved `sync_stream` synchronization, and added a `cudaMemGetInfo` defensive allocation check |
65+
| `onnxruntime/core/providers/cuda/plugin/cuda_ep.h` | Added `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` config fields, graph callback declarations, and a per-thread graph context cache |
66+
| `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc` | **NEW** — Complete `CudaGraphSet` and `CudaGraphManager` implementation |
67+
| `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h` | **NEW** — Header for graph manager types and constants |
68+
| `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Added `InitHandlesWithExternalStream()`, updated destructor for `owns_stream_` |
69+
| `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Added `InitHandlesWithExternalStream()` declaration, `owns_stream_` member |
70+
| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` |
71+
| `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Added `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` callbacks and `OrtGraphCaptureNodeAssignmentPolicy` enum to `OrtEp` |
72+
| `include/onnxruntime/core/framework/execution_provider.h` | Added `GetGraphCaptureNodeAssignmentPolicy()` virtual to `IExecutionProvider` |
73+
| `onnxruntime/core/session/inference_session.cc` | Replaced hard-coded EP name list with policy-driven graph capture validation loop; added bounded recursion via `RunImpl()` with `kMaxGraphCaptureWarmupRuns`; graph-enabled runs now reacquire stream collections through ORT core's thread-affine pool across internal warm-up/capture recursion |
74+
| `onnxruntime/core/framework/session_state.cc` | Sharded the `DeviceStreamCollection` cache by caller thread using per-thread lifetime tokens, so stream wrappers are only reused on the creating thread |
75+
| `onnxruntime/core/framework/session_state.h` | Added thread-affine stream pool bucket state for `DeviceStreamCollection` reuse |
76+
| `onnxruntime/core/session/inference_session.h` | Added `RunImpl()` private method and `kMaxGraphCaptureWarmupRuns` constant |
77+
| `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc` | Added version-gated `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` bridge implementations |
78+
| `onnxruntime/core/providers/webgpu/ep/ep.cc` | Added graph capture callback delegation to underlying `IExecutionProvider` |
79+
80+
### Key Design Decisions
81+
82+
- **`GetGraphCaptureNodeAssignmentPolicy`**: Returns `ALLOW_CPU_FOR_SHAPES` — consistent with the non-plugin CUDA EP behavior and allows shape-inference nodes on CPU.
83+
- **Thread safety**: Mutable graph state and graph streams are stored per thread. ORT core's `DeviceStreamCollection` cache is also thread-affine, so graph-enabled runs can recycle stream wrappers without exposing them to a different thread.
84+
- **Scope**: Capture/replay pipeline plus allocator compatibility. Arena integration is complete — see the [Arena Allocator Integration](#arena-allocator-integration) section.
85+
- **Callback assignment**: `IsGraphCaptureEnabled` and `GetGraphCaptureNodeAssignmentPolicy` are always set. `OnRunStart`, `OnRunEnd` are conditional on `enable_cuda_graph`. `IsGraphCaptured` and `ReplayGraph` are always set (return false/error when disabled).
86+
- **Stream management**: `CreateSyncStreamForDevice` remains unconditional — it branches internally to use the current thread's graph stream (via `InitHandlesWithExternalStream`) when graph capture is enabled, or creates an owned stream when disabled.
87+
- **Run-end synchronization**: `OnRunEndImpl` honors the `sync_stream` flag without double-synchronizing replayed graphs, preserving the normal EP completion contract.
88+
- **Stream collection reuse**: ORT core now recycles `DeviceStreamCollection` objects into a thread-affine session pool keyed by a per-thread lifetime token. Warm-up, capture, replay, and later user-visible `Run()` calls on the same thread can reuse the same stream wrappers, while dead-thread buckets are pruned before they can be reused by another thread.
89+
- **Per-thread context lifecycle**: Thread-local caches hold the strong `PerThreadContext` references, so CUDA streams and captured graph executables are released when the owning thread exits. The EP tracks weak references to those cache maps to remove stale entries during EP destruction without keeping the contexts alive.
90+
91+
### Arena Allocator Integration
92+
93+
CUDA graph capture requires that all memory allocations happen during warmup, not during capture. The plugin arena allocator (PR #27931) is now landed and integrated with the graph capture path.
94+
95+
**Allocation-during-capture detection:**
96+
97+
- `OnRunStartImpl` records free GPU memory in the per-thread context via `cudaMemGetInfo` before `CaptureBegin`.
98+
- `OnRunEndImpl` compares post-capture free memory in the same per-thread context. If it decreased, a warning is logged advising the user to increase `min_num_runs_before_cuda_graph_capture`.
99+
- This `cudaMemGetInfo` check is retained as a last-line diagnostic after arena integration, because custom arena options, insufficient warm-up, or regressions can still surface allocation-during-capture issues.
100+
101+
**Arena integration details (now implemented):**
102+
103+
- Default CUDA device allocations come from the plugin-hosted arena (`CudaArenaAllocator`). During warmup runs, the arena grows to accommodate all needed chunks; during capture and replay, the same chunks are reused without `cudaMalloc` calls.
104+
- When `arena.use_cuda_mempool=1` is configured, CUDA device allocations come from `CudaMempoolOrtAllocator`, which wraps `cudaMallocFromPoolAsync`/`cudaFreeAsync`. These async allocation/free operations are CUDA-graph-safe since CUDA 11.4+ and become part of the captured graph topology.
105+
- Pinned allocations are also arena-backed, but remain non-stream-aware.
106+
- The graph stream created by `CudaEp::PerThreadContext` flows through `CudaSyncStream::InitHandlesWithExternalStream()` so stream-aware arena allocation uses the same `cudaStream_t` during warm-up, capture, and replay.
107+
- `CudaSyncStream::OnSessionRunEndImpl()` resets arena chunk-to-stream assignments via `factory_.ResetDeviceArenaChunksUsingStream()` at the end of each run, even for graph-enabled runs. `OnSessionRunEnd` executes before the stream collection is recycled into the current thread's pool bucket.
108+
- The plugin allocator's `OrtMemoryInfo::alloc_type` stays as `OrtDeviceAllocator`; the arena remains opaque to ORT core.
109+
110+
### Concurrent Run Support
111+
112+
Concurrent `Session::Run()` is supported with CUDA graph enabled:
113+
114+
- `CudaEp::PerThreadContext` owns the graph stream, graph manager, warm-up run counts, and memory watermark for the current thread.
115+
- The current thread's cache owns the `PerThreadContext`; new threads get independent contexts, and exited threads release their contexts automatically.
116+
- `CreateSyncStreamForDeviceImpl()` wraps the current thread's graph stream, so warm-up, capture, and replay all use the same stream for that thread.
117+
- `CudaGraphManager::CaptureBegin()` uses `cudaStreamCaptureModeThreadLocal`, allowing overlapping capture scopes on different threads.
118+
- ORT core recycles graph-enabled `DeviceStreamCollection` objects into a thread-affine session pool, so internal warm-up/capture recursion and later top-level `Run()` calls on the same thread reuse the same stream wrappers without cross-thread leakage.
119+
- `IsGraphCaptured()` and `ReplayGraph()` resolve the current thread's graph context. If a new thread runs a graph-enabled session for the first time, that thread performs its own warm-up and capture before replaying.
120+
121+
## Verification
122+
123+
1. Build and deploy the plugin using the instructions in [QUICK_START.md](QUICK_START.md#build-instructions) and [QUICK_START.md](QUICK_START.md#running-tests).
124+
2. Run `onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` as described in [QUICK_START.md](QUICK_START.md#running-tests).
125+
3. The CUDA graph tests in that script validate:
126+
- `test_cuda_graph_capture_and_replay` — warmup + capture + replay with default arena
127+
- `test_cuda_graph_replay_with_updated_input` — in-place input update after graph capture
128+
- `test_cuda_graph_with_mempool` — graph capture with `arena.use_cuda_mempool=1`
129+
- `test_cuda_graph_annotation_id` — multiple graphs via `gpu_graph_id` run config
130+
- `test_cuda_graph_add_model` — graph capture with Add op (arena-backed)
131+
132+
## Future Work
133+
134+
1. **Profiling integration**: CUDA graph replay currently bypasses the CUDA plugin EP profiler path because the CUDA plugin EP does not yet implement `OrtEp::CreateProfiler`. Wiring graph replay into that path is future work.

0 commit comments

Comments
 (0)