Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,11 @@ class IExecutionProvider {

/**
Run the instantiated graph.
@param sync If true, synchronize the device/stream after replay to ensure completion before returning.
If false, the caller is responsible for synchronization.
EPs that always replay synchronously may ignore this parameter.
*/
virtual common::Status ReplayGraph(int /*graph_annotation_id*/) {
virtual common::Status ReplayGraph(int /*graph_annotation_id*/, bool /*sync*/ = true) {
return Status::OK();
}

Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotatio
return cuda_graph_.IsGraphCaptured(graph_annotation_id);
}

Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) {
return cuda_graph_.Replay(graph_annotation_id);
Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t graph_annotation_id, bool sync) {
return cuda_graph_.Replay(graph_annotation_id, sync);
}

void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture(
Expand Down Expand Up @@ -524,7 +524,7 @@ Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunO
GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id);
// CUDA work issued to a capturing stream doesn’t actually run on the GPU,
// so run the captured graph here to actually execute the work.
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id));
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream));
} else {
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(cuda_graph_annotation_id);
}
Expand Down Expand Up @@ -559,8 +559,8 @@ bool CUDAExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
return GetPerThreadContext().IsGraphCaptured(graph_annotation_id);
}

Status CUDAExecutionProvider::ReplayGraph(int graph_annotation_id) {
return GetPerThreadContext().ReplayGraph(graph_annotation_id);
Status CUDAExecutionProvider::ReplayGraph(int graph_annotation_id, bool sync) {
return GetPerThreadContext().ReplayGraph(graph_annotation_id, sync);
}

namespace cuda {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class CUDAExecutionProvider : public IExecutionProvider {

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override;
Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override;
Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id, bool sync = true) override;
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
}
Expand Down Expand Up @@ -205,7 +205,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id);
bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const;
Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id);
Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync = true);
void IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id);

private:
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/cuda/plugin/cuda_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,10 @@ OrtStatus* ORT_API_CALL CudaEp::ReplayGraphImpl(OrtEp* this_ptr, int graph_annot
ORT_EP_FAIL, "ReplayGraph called but CUDA graph manager is not initialized");
}
PL_CUDA_CALL_THROW(cudaSetDevice(ep->config_.device_id));
return ep->GetPerThreadContext().cuda_graph.Replay(graph_annotation_id);
// Launch graph without sync. The caller (PluginExecutionProvider::ReplayGraph)
// handles synchronization based on disable_synchronize_execution_providers.
// This function is only called from that bridge code path.
return ep->GetPerThreadContext().cuda_graph.Replay(graph_annotation_id, /*sync=*/false);

EXCEPTION_TO_STATUS_END
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ namespace Dml
return m_impl->GraphCaptured(graph_annotation_id);
}

Status ReplayGraph(int graph_annotation_id) override
// The sync parameter is ignored: DML EP always replays synchronously.
Status ReplayGraph(int graph_annotation_id, bool /*sync*/ = true) override
Comment thread
tianleiwu marked this conversation as resolved.
{
return m_impl->ReplayGraph(graph_annotation_id);
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,8 @@ bool JsExecutionProvider::IsGraphCaptured(int) const {
return is_graph_captured_;
}

Status JsExecutionProvider::ReplayGraph(int) {
Status JsExecutionProvider::ReplayGraph(int, bool /*sync*/) {
// The sync parameter is ignored: JS EP always replays synchronously.
ORT_ENFORCE(IsGraphCaptured(0));
EM_ASM({ Module.jsepReplay(); });
return Status::OK();
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/js_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class JsExecutionProvider : public IExecutionProvider {

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
Status ReplayGraph(int graph_annotation_id, bool sync = true) override;
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,8 @@ bool NvExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
return false;
}

Status NvExecutionProvider::ReplayGraph(int graph_annotation_id) {
Status NvExecutionProvider::ReplayGraph(int graph_annotation_id, bool /*sync*/) {
// The sync parameter is ignored: NV TRT RTX EP manages its own CUDA graph lifecycle and always replays synchronously.
// This is hardcoded to always return OK because we are not allowing the ORT framework to have the CUDA graph control.
(void)graph_annotation_id;
Comment thread
tianleiwu marked this conversation as resolved.
return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ class NvExecutionProvider : public IExecutionProvider {
// CUDA Graph support
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
Status ReplayGraph(int graph_annotation_id, bool sync = true) override;
void HandleCudaGraphStart(cudaStream_t stream, bool require_io_binding, CudaGraphAnnotation_t cuda_graph_annotation_id, bool& graph_replay_on_this_run, bool& should_start_capture);

static common::Status RefitEngine(std::string onnx_model_filename,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1915,7 +1915,8 @@ bool TensorrtExecutionProvider::IsGraphCaptured(int) const {
return is_graph_captured_;
}

Status TensorrtExecutionProvider::ReplayGraph(int) {
Status TensorrtExecutionProvider::ReplayGraph(int, bool /*sync*/) {
// The sync parameter is ignored: TRT EP always replays synchronously under a lock_guard in compute_func().
ORT_ENFORCE(IsGraphCaptured(0));
// Please note that CUDAGraph::Replay() is not thread safe.
// ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
Status ReplayGraph(int graph_annotation_id, bool sync = true) override;

static common::Status RefitEngine(std::string onnx_model_filename,
std::string& onnx_model_folder_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,8 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
return graph_annotation_id != -1 && captured_graph_ids_.contains(graph_annotation_id);
}

Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id, bool /*sync*/) {
// The sync parameter is ignored: WebGPU EP always replays synchronously.
ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
// TODO: enable profiling in run level
if (session_profiler_ && session_profiler_->Enabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
Status ReplayGraph(int graph_annotation_id, bool sync = true) override;
Status ReleaseCapturedGraph(int graph_annotation_id) override;
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3268,7 +3268,10 @@ Status InferenceSession::RunImpl(const RunOptions& run_options,
<< " with graph annotation id: " << graph_annotation_id;
// log evaluation start to trace logging provider
env.GetTelemetryProvider().LogEvaluationStart(session_id_);
ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id));
bool sync_graph_replay = run_options.config_options.GetConfigOrDefault(
Comment thread
tianleiwu marked this conversation as resolved.
kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id,
sync_graph_replay));
} else {
InlinedVector<IExecutionProvider*> exec_providers_to_stop;
exec_providers_to_stop.reserve(execution_providers_.NumProviders());
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -1109,9 +1109,9 @@ class InferenceSession {
return cached_execution_provider_for_graph_replay_ != nullptr && graph_annotation_id != kGraphAnnotationSkip;
}

Status ReplayGraph(int graph_annotation_id) {
Status ReplayGraph(int graph_annotation_id, bool sync = true) {
if (cached_execution_provider_for_graph_replay_) {
return cached_execution_provider_for_graph_replay_->ReplayGraph(graph_annotation_id);
return cached_execution_provider_for_graph_replay_->ReplayGraph(graph_annotation_id, sync);
}
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cached EP instance for graph replay is not set yet before calling ReplayGraph()");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1047,11 +1047,15 @@ bool PluginExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
return ort_ep_->IsGraphCaptured(ort_ep_.get(), graph_annotation_id);
}

Status PluginExecutionProvider::ReplayGraph(int graph_annotation_id) {
Status PluginExecutionProvider::ReplayGraph(int graph_annotation_id, bool sync) {
if (ort_ep_->ort_version_supported < 26 || ort_ep_->ReplayGraph == nullptr) {
return Base::ReplayGraph(graph_annotation_id);
return Base::ReplayGraph(graph_annotation_id, sync);
}
return ToStatusAndRelease(ort_ep_->ReplayGraph(ort_ep_.get(), graph_annotation_id));
ORT_RETURN_IF_ERROR(ToStatusAndRelease(ort_ep_->ReplayGraph(ort_ep_.get(), graph_annotation_id)));
if (sync) {
ORT_RETURN_IF_ERROR(Sync());
}
return Status::OK();
}

Status PluginExecutionProvider::ReleaseCapturedGraph(int graph_annotation_id) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class PluginExecutionProvider : public IExecutionProvider {

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
common::Status ReplayGraph(int graph_annotation_id) override;
common::Status ReplayGraph(int graph_annotation_id, bool sync = true) override;
common::Status ReleaseCapturedGraph(int graph_annotation_id) override;
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override;

Expand Down
Loading
Loading