Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions samples/zello_world/zello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,13 @@ int main( int argc, char *argv[] )
zeEventHostSynchronize(event, UINT64_MAX );
std::cout << "Congratulations, the device completed execution!\n";

zelCheckIsLoaderInTearDown();
zeContextDestroy(context);
zelCheckIsLoaderInTearDown();
zeCommandListDestroy(command_list);
zelCheckIsLoaderInTearDown();
zeEventDestroy(event);
zelCheckIsLoaderInTearDown();
zeEventPoolDestroy(event_pool);

if (tracing_enabled) {
Expand Down
164 changes: 121 additions & 43 deletions source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ namespace ze_lib
ze_lib::context = nullptr;
}
}
#define ZEL_STABILITY_THREAD_STATE_RUNNING 1
#define ZEL_STABILITY_THREAD_STATE_POLLING 0
#define ZEL_STABILITY_THREAD_STATE_SHUTDOWN -1
bool delayContextDestruction = false;
std::mutex *stabilityMutex = nullptr;
std::promise<int> *stabilityPromiseResult = nullptr;
std::future<int> *resultFutureResult = nullptr;
std::atomic<int> *stabilityCheckThreadStatus = nullptr;
std::thread *stabilityThread = nullptr;
#endif
bool destruction = false;

Expand All @@ -43,6 +51,35 @@ namespace ze_lib
if (loader) {
FREE_DRIVER_LIBRARY( loader );
}
if (ze_lib::stabilityCheckThreadStatus)
ze_lib::stabilityCheckThreadStatus->store(ZEL_STABILITY_THREAD_STATE_SHUTDOWN);
try {
if (stabilityThread && stabilityThread->joinable()) {
stabilityThread->join();
}
} catch (...) {
// Ignore any exceptions from thread join
}
if (stabilityThread) {
delete stabilityThread;
stabilityThread = nullptr;
}
if (stabilityMutex) {
delete stabilityMutex;
stabilityMutex = nullptr;
}
if (stabilityPromiseResult) {
delete stabilityPromiseResult;
stabilityPromiseResult = nullptr;
}
if (resultFutureResult) {
delete resultFutureResult;
resultFutureResult = nullptr;
}
if (stabilityCheckThreadStatus) {
delete stabilityCheckThreadStatus;
stabilityCheckThreadStatus = nullptr;
}
#endif
ze_lib::destruction = true;
};
Expand Down Expand Up @@ -149,6 +186,10 @@ namespace ze_lib
std::string version_message = "Loader API Version to be requested is v" + std::to_string(ZE_MAJOR_VERSION(version)) + "." + std::to_string(ZE_MINOR_VERSION(version));
debug_trace_message(version_message, "");
loaderDriverGet = reinterpret_cast<ze_pfnDriverGet_t>(GET_FUNCTION_PTR(loader, "zeDriverGet"));
stabilityMutex = new std::mutex();
stabilityPromiseResult = new std::promise<int>();
resultFutureResult = new std::future<int>(stabilityPromiseResult->get_future());
stabilityCheckThreadStatus = new std::atomic<int>(ZEL_STABILITY_THREAD_STATE_POLLING);
#else
result = zeLoaderInit();
if( ZE_RESULT_SUCCESS == result ) {
Expand Down Expand Up @@ -410,61 +451,55 @@ zelSetDelayLoaderContextTeardown()
#define ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL 1
#define ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED 2
#define ZEL_STABILITY_CHECK_RESULT_EXCEPTION 3
// The stability check thread timeout in milliseconds
#define ZEL_STABILITY_CHECK_THREAD_TIMEOUT 100
// The stability check thread polling interval in nanoseconds
#define ZEL_STABILITY_CHECK_THREAD_POLLING_INTERVAL 100

/**
* @brief Performs a stability check for the Level Zero loader.
*
* This function checks the stability of the Level Zero loader by verifying
* the presence of the loader module, the validity of the `zeDriverGet` function
* pointer, and the ability to retrieve driver information. The result of the
* stability check is communicated through the provided promise.
* This function verifies the stability of the Level Zero loader by checking:
* - The presence of the loader module.
* - The validity of the `zeDriverGet` function pointer.
* - The ability to retrieve driver information.
*
* @param stabilityPromise A promise object used to communicate the result of
* the stability check. The promise is set with one of
* the following values:
* - ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL: The
* `zeDriverGet` function pointer is invalid.
* - ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED: The
* loader failed to retrieve driver information.
* - ZEL_STABILITY_CHECK_RESULT_EXCEPTION: An
* exception occurred during the stability check.
* - ZEL_STABILITY_CHECK_RESULT_SUCCESS: The stability
* check was successful.
* The result of the stability check is returned as an integer, with the following possible values:
* - `ZEL_STABILITY_CHECK_RESULT_SUCCESS`: The stability check was successful.
* - `ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL`: The `zeDriverGet` function pointer is invalid.
* - `ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED`: The loader failed to retrieve driver information.
* - `ZEL_STABILITY_CHECK_RESULT_EXCEPTION`: An exception occurred during the stability check.
*
* @note If debug tracing is enabled, debug messages are logged for each failure
* scenario.
* @note If the Loader is completely torn down, this thread is expected to be killed
* due to invalid memory access and the stability check will determine a failure.
* If debug tracing is enabled, debug messages are logged for each failure scenario.
*
* @exception This function catches all exceptions internally and does not throw.
* @return An integer indicating the result of the stability check.
*
* @note If the loader is completely torn down, this function may fail due to invalid memory access.
* @note This function catches all exceptions internally and does not throw.
*/
void stabilityCheck(std::promise<int> stabilityPromise) {
int stabilityCheck() {
try {
if (!ze_lib::context->loaderDriverGet) {
if (ze_lib::context->debugTraceEnabled) {
std::string message = "LoaderDriverGet is a bad pointer. Exiting stability checker thread.";
std::string message = "LoaderDriverGet is a bad pointer. Exiting stability checker.";
ze_lib::context->debug_trace_message(message, "");
}
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL);
return;
return ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL;
}

uint32_t driverCount = 0;
ze_result_t result = ZE_RESULT_ERROR_UNINITIALIZED;
result = ze_lib::context->loaderDriverGet(&driverCount, nullptr);
if (result != ZE_RESULT_SUCCESS || driverCount == 0) {
if (ze_lib::context->debugTraceEnabled) {
std::string message = "Loader stability check failed. Exiting stability checker thread.";
std::string message = "Loader stability check failed. Exiting stability checker.";
ze_lib::context->debug_trace_message(message, "");
}
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED);
return;
return ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED;
}
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_SUCCESS);
return;
return ZEL_STABILITY_CHECK_RESULT_SUCCESS;
} catch (...) {
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_EXCEPTION);
return;
return ZEL_STABILITY_CHECK_RESULT_EXCEPTION;
}
}
#endif
Expand All @@ -490,18 +525,60 @@ zelCheckIsLoaderInTearDown() {
return true;
}
#ifdef DYNAMIC_LOAD_LOADER
std::promise<int> stabilityPromise;
std::future<int> resultFuture = stabilityPromise.get_future();
int result = -1;
static bool unstable = false;
int threadResult = -1;
if (unstable) {
return true;
}
try {
// Launch the stability checker thread
std::thread stabilityThread(stabilityCheck, std::move(stabilityPromise));
result = resultFuture.get(); // Blocks until the result is available
if (ze_lib::context->debugTraceEnabled) {
std::string message = "Stability checker thread completed with result: " + std::to_string(result);
ze_lib::context->debug_trace_message(message, "");
// Launch the stability checker thread on the first call
static std::once_flag stabilityThreadFlag;
std::lock_guard<std::mutex> lock(*ze_lib::stabilityMutex);
*ze_lib::stabilityPromiseResult = std::promise<int>();
*ze_lib::resultFutureResult = ze_lib::stabilityPromiseResult->get_future();
ze_lib::stabilityCheckThreadStatus->store(ZEL_STABILITY_THREAD_STATE_RUNNING);
std::call_once(stabilityThreadFlag, []() {
ze_lib::stabilityThread = new std::thread([]() {
while (true) {
while(ze_lib::stabilityCheckThreadStatus && ze_lib::stabilityCheckThreadStatus->load() == ZEL_STABILITY_THREAD_STATE_POLLING) {
std::this_thread::sleep_for(std::chrono::nanoseconds(ZEL_STABILITY_CHECK_THREAD_POLLING_INTERVAL));
}
if (ze_lib::destruction || ze_lib::context == nullptr) {
break;
}
if (!ze_lib::stabilityCheckThreadStatus) {
break;
}
if (ze_lib::stabilityCheckThreadStatus->load() == ZEL_STABILITY_THREAD_STATE_SHUTDOWN) {
break;
}
ze_lib::stabilityCheckThreadStatus->store(ZEL_STABILITY_THREAD_STATE_POLLING);
int result = stabilityCheck();
if (result != ZEL_STABILITY_CHECK_RESULT_SUCCESS) {
if (ze_lib::context->debugTraceEnabled) {
std::string message = "Loader stability check thread failed with result: " + std::to_string(result);
ze_lib::context->debug_trace_message(message, "");
}
if (ze_lib::stabilityPromiseResult) {
ze_lib::stabilityPromiseResult->set_value(result);
}
break; // Exit the thread if stability check fails
}
if (ze_lib::stabilityPromiseResult) {
ze_lib::stabilityPromiseResult->set_value(result);
}
}
});
});
if (ze_lib::resultFutureResult->wait_for(std::chrono::milliseconds(ZEL_STABILITY_CHECK_THREAD_TIMEOUT)) == std::future_status::timeout) {
if (ze_lib::context->debugTraceEnabled) {
std::string message = "Stability Thread timeout, assuming thread has crashed";
ze_lib::context->debug_trace_message(message, "");
}
threadResult = ZEL_STABILITY_CHECK_RESULT_EXCEPTION;
} else {
threadResult = ze_lib::resultFutureResult->get();
}
stabilityThread.join();
} catch (const std::exception& e) {
if (ze_lib::context->debugTraceEnabled) {
std::string message = "Exception caught in parent thread: " + std::string(e.what());
Expand All @@ -513,11 +590,12 @@ zelCheckIsLoaderInTearDown() {
ze_lib::context->debug_trace_message(message, "");
}
}
if (result != ZEL_STABILITY_CHECK_RESULT_SUCCESS) {
if (threadResult != ZEL_STABILITY_CHECK_RESULT_SUCCESS) {
if (ze_lib::context->debugTraceEnabled) {
std::string message = "Loader stability check failed with result: " + std::to_string(result);
std::string message = "Loader stability check failed with result: " + std::to_string(threadResult);
ze_lib::context->debug_trace_message(message, "");
}
unstable = true;
return true;
}
#endif
Expand Down
21 changes: 21 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,27 @@ endif()
add_test(NAME tests_loader_teardown_check COMMAND tests --gtest_filter=*GivenLoaderNotInDestructionStateWhenCallingzelCheckIsLoaderInTearDownThenFalseIsReturned)
set_property(TEST tests_loader_teardown_check PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1")

add_test(NAME test_zello_world_legacy COMMAND zello_world --enable_legacy_init --enable_null_driver --force_loader_intercepts --enable_validation_layer --enable_tracing_layer --enable_tracing_layer_runtime)
set_property(TEST test_zello_world_legacy PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1")

add_test(NAME test_zello_world_legacy_intercept COMMAND zello_world --enable_legacy_init --enable_null_driver --force_loader_intercepts)
set_property(TEST test_zello_world_legacy_intercept PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1")

add_test(NAME test_zello_world_legacy_validation_layer COMMAND zello_world --enable_legacy_init --enable_null_driver --enable_validation_layer)
set_property(TEST test_zello_world_legacy_validation_layer PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1")

add_test(NAME test_zello_world_legacy_tracing COMMAND zello_world --enable_legacy_init --enable_null_driver --enable_tracing_layer)
set_property(TEST test_zello_world_legacy_tracing PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1")

add_test(NAME test_zello_world_legacy_dynamic_tracing COMMAND zello_world --enable_legacy_init --enable_null_driver --enable_tracing_layer_runtime)
set_property(TEST test_zello_world_legacy_dynamic_tracing PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1")

add_test(NAME test_zello_world_legacy_all_tracing COMMAND zello_world --enable_legacy_init --enable_null_driver --force_loader_intercepts --enable_validation_layer --enable_tracing_layer)
set_property(TEST test_zello_world_legacy_all_tracing PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1")

add_test(NAME test_zello_world_legacy_all_tracing_dynamic COMMAND zello_world --enable_legacy_init --enable_null_driver --force_loader_intercepts --enable_validation_layer --enable_tracing_layer_runtime)
set_property(TEST test_zello_world_legacy_all_tracing_dynamic PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1")

# These tests are currently not supported on Windows because std::cerr is not redirected to a pipe there, so its output cannot be checked against the expected output.
if(NOT MSVC)
add_test(NAME tests_event_deadlock COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingzeCommandListAppendMemoryCopyWithCircularDependencyOnEventsThenValidationLayerPrintsWarningOfDeadlock*)
Expand Down
3 changes: 3 additions & 0 deletions test/loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ TEST(

EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0));
EXPECT_FALSE(zelCheckIsLoaderInTearDown());
EXPECT_FALSE(zelCheckIsLoaderInTearDown());
EXPECT_FALSE(zelCheckIsLoaderInTearDown());
EXPECT_FALSE(zelCheckIsLoaderInTearDown());
}

class CaptureOutput {
Expand Down
Loading