Skip to content

Commit c29e914

Browse files
Revert "[Plugin EP] Port graph capture/replay APIs (#27958)"
This reverts commit 7afe4c2.
1 parent: 0e67485 · commit: c29e914

20 files changed

Lines changed: 107 additions & 708 deletions

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -292,15 +292,6 @@ class IExecutionProvider {
292292
return Status::OK();
293293
}
294294

295-
/**
296-
Get the node assignment validation policy for graph capture.
297-
When graph capture is enabled, ORT validates that nodes are assigned to EPs
298-
in a way compatible with graph capture. This tells ORT which policy to apply.
299-
*/
300-
virtual OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const {
301-
return OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP;
302-
}
303-
304295
/**
305296
Called when session creation is complete
306297
This provides an opportunity for execution providers to optionally synchronize and

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1389,10 +1389,6 @@ struct Env : detail::Base<OrtEnv> {
13891389
const std::vector<Value>& dst_tensors,
13901390
OrtSyncStream* stream) const; ///< Wraps OrtApi::CopyTensors
13911391

1392-
/// Wraps OrtApi::CopyTensors
1393-
/// Copies only one src tensor to another dst tensor.
1394-
Status CopyTensor(const OrtValue* src_tensor, OrtValue* dst_tensor, OrtSyncStream* stream) const;
1395-
13961392
/// \brief Wraps OrtApi::SetPerSessionThreadPoolCallbacks
13971393
/// Stores work callbacks on the Env for per-session thread pools.
13981394
/// Only affects sessions created after this call. Does not affect global thread pools.

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1055,11 +1055,6 @@ inline Status Env::CopyTensors(const std::vector<Value>& src_tensors,
10551055
return Status(status);
10561056
}
10571057

1058-
inline Status Env::CopyTensor(const OrtValue* src_tensor, OrtValue* dst_tensor, OrtSyncStream* stream) const {
1059-
OrtStatus* status = GetApi().CopyTensors(p_, &src_tensor, &dst_tensor, stream, 1);
1060-
return Status(status);
1061-
}
1062-
10631058
inline UnownedAllocator Env::CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type,
10641059
OrtAllocatorType allocator_type,
10651060
const OrtKeyValuePairs* allocator_options) {

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 0 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -2027,23 +2027,6 @@ typedef enum OrtEpDataLayout {
20272027
OrtEpDataLayout_Default = OrtEpDataLayout_NCHW,
20282028
} OrtEpDataLayout;
20292029

2030-
/**
2031-
* \brief Node assignment policies for graph capture validation.
2032-
*
2033-
* When graph capture is enabled, ORT validates that nodes are assigned to EPs in a way that is
2034-
* compatible with graph capture. An EP can specify which validation policy ORT should apply.
2035-
*
2036-
* \since Version 1.26.
2037-
*/
2038-
typedef enum OrtGraphCaptureNodeAssignmentPolicy {
2039-
/** All nodes in the main graph must be assigned to this EP. No CPU fallback is allowed. */
2040-
OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP = 0,
2041-
2042-
/** Compute nodes must be on this EP. CPU nodes are allowed for shape computation as long as
2043-
* no memory copy nodes exist. */
2044-
OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES = 1,
2045-
} OrtGraphCaptureNodeAssignmentPolicy;
2046-
20472030
/**
20482031
* \brief The OrtEp struct provides functions to implement for an execution provider.
20492032
* \since Version 1.22.
@@ -2363,101 +2346,6 @@ struct OrtEp {
23632346
*/
23642347
ORT_API2_STATUS(CreateProfiler, _In_ OrtEp* this_ptr,
23652348
_Outptr_result_maybenull_ OrtEpProfilerImpl** profiler);
2366-
2367-
/** \brief Indicate whether the graph capturing mode (e.g., CUDA graph) is enabled for the provider.
2368-
*
2369-
* Graph capture allows an EP to record a sequence of device (e.g., GPU) operations during an initial run and replay
2370-
* them on subsequent runs, bypassing per-kernel CPU launch overhead.
2371-
*
2372-
* Applications enable graph capture via EP-specific provider options (e.g., `enable_cuda_graph=1`
2373-
* for the CUDA EP). An EP should return true from this function if it has been configured to enable
2374-
* graph capture/replay.
2375-
*
2376-
* **ORT graph capture/replay summary:**
2377-
* During OrtSession initialization, ORT calls OrtEp::IsGraphCaptureEnabled() on each EP in the order specified during
2378-
* provider registration with the session. If an EP returns true, ORT validates that the graph is suitable for
2379-
* graph capture, and if so, caches the EP for graph capture during the next run. The graph validation ensures
2380-
* that there are no control flow nodes and that node-to-EP assignments are compatible with the policy specified
2381-
* by the EP via OrtEp::GetGraphCaptureNodeAssignmentPolicy().
2382-
* Note that an OrtSession only supports graph capture for one EP (i.e., the first EP to claim support).
2383-
*
2384-
* During the first call to OrtApi::Run() for the OrtSession, ORT performs multiple internal runs of the model
2385-
* until the EP indicates that the graph has been captured by returning `true` from `OrtEp::IsGraphCaptured()`.
2386-
* If the EP is unable to capture the graph within 8 runs, the call to OrtApi::Run() returns an error OrtStatus.
2387-
* Each internal run invokes `OrtEp::OnRunStart()`, normal execution, and `OrtEp::OnRunEnd()`. EPs should use
2388-
* these run callbacks to track the number of necessary warm-up runs and begin/end graph capture when ready.
2389-
*
2390-
* After successful graph capture, subsequent calls to OrtApi::Run() skip normal execution and ORT instead calls
2391-
* `OrtEp::ReplayGraph()` directly.
2392-
*
2393-
* Applications can capture and replay multiple graphs (e.g., one per distinct input shape) by setting the
2394-
* `"gpu_graph_id"` run config entry via `OrtApi::AddRunConfigEntry()` to different integer values. ORT passes
2395-
* the value as the `graph_annotation_id` parameter to `OrtEp::IsGraphCaptured()` and `OrtEp::ReplayGraph()`.
2396-
*
2397-
* \param[in] this_ptr The OrtEp instance.
2398-
* \return true if graph capture mode is enabled, false otherwise.
2399-
*
2400-
* \note Implementation of this function is optional. If set to NULL, ORT assumes graph capture is not enabled.
2401-
* \note If this function returns true, `OrtEp::IsGraphCaptured` and `OrtEp::ReplayGraph` must also be implemented.
2402-
* If either is NULL, ORT will log a warning and ignore this EP for graph capture.
2403-
*
2404-
* \since Version 1.26.
2405-
*/
2406-
ORT_API_T(bool, IsGraphCaptureEnabled, _In_ const OrtEp* this_ptr);
2407-
2408-
/** \brief Indicate whether a graph has been captured and instantiated.
2409-
*
2410-
* ORT calls this before each `Session::Run()`. If true, ORT calls `ReplayGraph()` instead of
2411-
* normal execution. After a run where this returns false, ORT automatically retries until it
2412-
* returns true (handling warm-up runs transparently).
2413-
*
2414-
* \param[in] this_ptr The OrtEp instance.
2415-
* \param[in] graph_annotation_id Identifies which captured graph to query.
2416-
* Applications can set this value via `OrtApi::AddRunConfigEntry()` with the key `"gpu_graph_id"`.
2417-
* The default value is 0 when the run config entry is not set.
2418-
* Setting different IDs allows the EP to capture and manage multiple graphs (e.g., one per
2419-
* distinct input shape). A value of -1 means graph capture/replay should be skipped for this run.
2420-
* \return true if the graph has been captured, false otherwise.
2421-
*
2422-
* \note This function must be implemented if `OrtEp::IsGraphCaptureEnabled` is implemented and may return true.
2423-
*
2424-
* \since Version 1.26.
2425-
*/
2426-
ORT_API_T(bool, IsGraphCaptured, _In_ const OrtEp* this_ptr, _In_ int graph_annotation_id);
2427-
2428-
/** \brief Run the instantiated (captured) graph.
2429-
*
2430-
* Called by ORT instead of normal execution when `IsGraphCaptured()` returns true.
2431-
*
2432-
* \param[in] this_ptr The OrtEp instance.
2433-
* \param[in] graph_annotation_id Identifies which captured graph to replay.
2434-
* Applications can set this value via `OrtApi::AddRunConfigEntry()` with the key `"gpu_graph_id"`.
2435-
* The default value is 0 when the run config entry is not set.
2436-
* A value of -1 means graph replay should be skipped for this run.
2437-
*
2438-
* \snippet{doc} snippets.dox OrtStatus Return Value
2439-
*
2440-
* \note This function must be implemented if `OrtEp::IsGraphCaptureEnabled` is implemented and may return true.
2441-
*
2442-
* \since Version 1.26.
2443-
*/
2444-
ORT_API2_STATUS(ReplayGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);
2445-
2446-
/** \brief Get the node assignment validation policy for graph capture.
2447-
*
2448-
* When graph capture is enabled, ORT validates that nodes are assigned to EPs in a way that is
2449-
* compatible with graph capture. This function tells ORT which validation policy to apply.
2450-
*
2451-
* \param[in] this_ptr The OrtEp instance.
2452-
* \return The node assignment policy for graph capture.
2453-
*
2454-
* \note Implementation of this function is optional. If set to NULL, ORT uses
2455-
* OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP (strictest validation).
2456-
*
2457-
* \since Version 1.26.
2458-
*/
2459-
ORT_API_T(OrtGraphCaptureNodeAssignmentPolicy, GetGraphCaptureNodeAssignmentPolicy,
2460-
_In_ const OrtEp* this_ptr);
24612349
};
24622350

24632351
/** \brief The function signature that ORT will call to create OrtEpFactory instances.

onnxruntime/core/providers/cuda/cuda_execution_provider.h

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -124,9 +124,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
124124
bool IsGraphCaptureEnabled() const override;
125125
bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override;
126126
Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override;
127-
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
128-
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
129-
}
130127
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
131128
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
132129
std::vector<AllocatorPtr> CreatePreferredAllocators() override;

onnxruntime/core/providers/cuda/plugin/cuda_ep.cc

Lines changed: 0 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -58,12 +58,6 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo
5858
Compile = nullptr;
5959
ReleaseNodeComputeInfos = nullptr;
6060

61-
// Graph capture/replay
62-
IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl;
63-
IsGraphCaptured = IsGraphCapturedImpl;
64-
ReplayGraph = ReplayGraphImpl;
65-
GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl;
66-
6761
const OrtApi& ort_api = factory_.GetOrtApi();
6862
Ort::Status log_status(ort_api.Logger_LogMessage(&logger_, ORT_LOGGING_LEVEL_INFO,
6963
"CUDA Plugin EP created",
@@ -310,30 +304,5 @@ OrtStatus* ORT_API_CALL CudaEp::SyncImpl(OrtEp* this_ptr) noexcept {
310304
EXCEPTION_TO_STATUS_END
311305
}
312306

313-
bool ORT_API_CALL CudaEp::IsGraphCaptureEnabledImpl(const OrtEp* /*this_ptr*/) noexcept {
314-
// TODO: forward to EpImpl()->IsGraphCaptureEnabled()
315-
return false;
316-
}
317-
318-
/*static*/
319-
bool ORT_API_CALL CudaEp::IsGraphCapturedImpl(const OrtEp* /*this_ptr*/, int /*graph_annotation_id*/) noexcept {
320-
// TODO: forward to EpImpl()->IsGraphCaptured(graph_annotation_id)
321-
return false;
322-
}
323-
324-
/*static*/
325-
OrtStatus* ORT_API_CALL CudaEp::ReplayGraphImpl(OrtEp* /*this_ptr*/, int /*graph_annotation_id*/) noexcept {
326-
// TODO: forward to EpImpl()->ReplayGraph(graph_annotation_id)
327-
return Ort::GetApi().CreateStatus(ORT_NOT_IMPLEMENTED,
328-
"Graph capture replay is not yet supported in the CUDA plugin EP.");
329-
}
330-
331-
/*static*/
332-
OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL CudaEp::GetGraphCaptureNodeAssignmentPolicyImpl(
333-
const OrtEp* /*this_ptr*/) noexcept {
334-
// TODO: forward to EpImpl()->GetGraphCaptureNodeAssignmentPolicy()
335-
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
336-
}
337-
338307
} // namespace cuda_plugin
339308
} // namespace onnxruntime

onnxruntime/core/providers/cuda/plugin/cuda_ep.h

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -61,17 +61,6 @@ class CudaEp : public onnxruntime::ep::adapter::Ep {
6161

6262
static OrtStatus* ORT_API_CALL SyncImpl(OrtEp* this_ptr) noexcept;
6363

64-
static bool ORT_API_CALL IsGraphCaptureEnabledImpl(const OrtEp* this_ptr) noexcept;
65-
66-
static bool ORT_API_CALL IsGraphCapturedImpl(const OrtEp* this_ptr,
67-
int graph_annotation_id) noexcept;
68-
69-
static OrtStatus* ORT_API_CALL ReplayGraphImpl(OrtEp* this_ptr,
70-
int graph_annotation_id) noexcept;
71-
72-
static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl(
73-
const OrtEp* this_ptr) noexcept;
74-
7564
CudaEpFactory& factory_;
7665
std::string name_;
7766
Config config_;

onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -356,11 +356,6 @@ namespace Dml
356356
return m_impl->ReplayGraph(graph_annotation_id);
357357
}
358358

359-
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override
360-
{
361-
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
362-
}
363-
364359
private:
365360
ComPtr<ExecutionProviderImpl> m_impl;
366361
};

onnxruntime/core/providers/js/js_execution_provider.h

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -72,9 +72,6 @@ class JsExecutionProvider : public IExecutionProvider {
7272
bool IsGraphCaptureEnabled() const override;
7373
bool IsGraphCaptured(int graph_annotation_id) const override;
7474
Status ReplayGraph(int graph_annotation_id) override;
75-
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
76-
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
77-
}
7875

7976
private:
8077
bool IsGraphCaptureAllowed() const;

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1280,9 +1280,7 @@ void NvExecutionProvider::HandleCudaGraphStart(cudaStream_t stream, bool require
12801280
}
12811281

12821282
bool NvExecutionProvider::IsGraphCaptureEnabled() const {
1283-
// Return false so that ORT's framework does not cache this EP for ORT-managed graph capture/replay.
1284-
// NvTensorRTRTX manages CUDA graph capture/replay internally.
1285-
return false;
1283+
return cuda_graph_enable_;
12861284
}
12871285

12881286
bool NvExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {

0 commit comments

Comments
 (0)