Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "ep_profiling.h"
#include <algorithm>
#include <array>
#include <chrono>
#include <optional>
Expand Down Expand Up @@ -59,7 +60,8 @@ void EpEventManager::PushOrtEvent(uint64_t profiler_id) {
tls_profiling_state_.profiler_id = profiler_id;
}

void EpEventManager::PopOrtEvent(uint64_t profiler_id, const std::string& ort_event_name) {
void EpEventManager::PopOrtEvent(uint64_t profiler_id, const std::string& ort_event_name,
int64_t ort_event_start_us, int64_t ort_event_duration_us) {
std::lock_guard<std::mutex> lock(mutex_);

auto iter = profiler_state_.find(profiler_id);
Expand All @@ -77,6 +79,8 @@ void EpEventManager::PopOrtEvent(uint64_t profiler_id, const std::string& ort_ev

if (ep_event.thread_id == current_thread_id && ep_event.ort_event_name.empty()) {
ep_event.ort_event_name = ort_event_name;
ep_event.ort_event_start_us = ort_event_start_us;
ep_event.ort_event_duration_us = ort_event_duration_us;
}
}

Expand Down Expand Up @@ -169,9 +173,11 @@ OrtStatus* ORT_API_CALL ExampleKernelEpProfiler::StopEventImpl(OrtEpProfilerImpl

Ort::ConstProfilingEvent ort_event(c_ort_event);
const char* ort_event_name = ort_event.GetName();
const int64_t ort_event_start_us = ort_event.GetTimestampUs();
const int64_t ort_event_duration_us = ort_event.GetDurationUs();

// Annotate all EP events that were collected during this ORT event with metadata from the ORT event.
ep_event_manager.PopOrtEvent(self->profiler_id, ort_event_name);
ep_event_manager.PopOrtEvent(self->profiler_id, ort_event_name, ort_event_start_us, ort_event_duration_us);
return nullptr;
EXCEPTION_TO_RETURNED_STATUS_END
}
Expand Down Expand Up @@ -209,6 +215,22 @@ OrtStatus* ORT_API_CALL ExampleKernelEpProfiler::EndProfilingImpl(
raw_ep_event.end_time - raw_ep_event.start_time)
.count();

// The ORT-to-EP clock reconstruction can differ by a few microseconds on some platforms.
// Bound the EP event to the correlated ORT parent interval so the emitted child event remains
// properly nested without weakening the test's containment checks.
if (raw_ep_event.ort_event_start_us >= 0 && raw_ep_event.ort_event_duration_us >= 0) {
const int64_t parent_start_us = raw_ep_event.ort_event_start_us;
const int64_t parent_end_us = parent_start_us + raw_ep_event.ort_event_duration_us;

int64_t rel_end_us = rel_ts_us + std::max<int64_t>(dur_us, 0);
rel_ts_us = std::clamp(rel_ts_us, parent_start_us, parent_end_us);
rel_end_us = std::clamp(rel_end_us, parent_start_us, parent_end_us);
if (rel_end_us < rel_ts_us) {
rel_end_us = rel_ts_us;
}
dur_us = rel_end_us - rel_ts_us;
}

// Set parent_name as an event arg. The parent_name is just the name of the correlated ORT event.
std::unordered_map<std::string, std::string> args = {{"parent_name", raw_ep_event.ort_event_name.c_str()}};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ class EpEventManager {
std::chrono::high_resolution_clock::time_point start_time;
std::chrono::high_resolution_clock::time_point end_time;
std::string ort_event_name; // Set from the correlated ORT event
std::thread::id thread_id; // Thread that created this event
int64_t ort_event_start_us = -1;
int64_t ort_event_duration_us = -1;
std::thread::id thread_id; // Thread that created this event
};

static EpEventManager& GetInstance();
Expand All @@ -82,7 +84,8 @@ class EpEventManager {
void UnregisterProfiler(uint64_t profiler_id);

void PushOrtEvent(uint64_t profiler_id);
void PopOrtEvent(uint64_t profiler_id, const std::string& ort_event_name);
void PopOrtEvent(uint64_t profiler_id, const std::string& ort_event_name, int64_t ort_event_start_us,
int64_t ort_event_duration_us);

void AddEpEvent(uint64_t profiler_id, Event event);

Expand Down
Loading