// Review header for the patch that follows. The patch text itself is unmodified.
//
// Purpose: replace the per-call stability-checker thread used by
// zelCheckIsLoaderInTearDown() with a single long-lived worker thread, and add
// sample/test coverage that calls zelCheckIsLoaderInTearDown() repeatedly.
//
// samples/zello_world/zello_world.cpp
//   - Adds a zelCheckIsLoaderInTearDown() call before each of zeContextDestroy,
//     zeCommandListDestroy, zeEventDestroy and zeEventPoolDestroy. The return
//     value is discarded at every call site.
//
// source/lib/ze_lib.cpp
//   - New macros ZEL_STABILITY_THREAD_STATE_RUNNING (1), _POLLING (0) and
//     _SHUTDOWN (-1), plus five pointer globals declared next to
//     delayContextDestruction: stabilityMutex, stabilityPromiseResult,
//     resultFutureResult, stabilityCheckThreadStatus and stabilityThread
//     (all initialised to nullptr).
//   - Init hunk (@@ -149): right after loaderDriverGet is resolved, allocates the
//     mutex, the promise, the future (from the promise's get_future()) and the
//     status atomic, which starts as ZEL_STABILITY_THREAD_STATE_POLLING.
//   - Teardown hunk (@@ -43; presumably the context destructor, since it ends by
//     setting ze_lib::destruction = true -- TODO confirm): after
//     FREE_DRIVER_LIBRARY(loader) it stores SHUTDOWN into the status atomic, joins
//     the worker if joinable (exceptions from join are swallowed), then deletes
//     the five objects and resets the pointers to nullptr.
//   - New macros ZEL_STABILITY_CHECK_THREAD_TIMEOUT (100, passed to
//     std::chrono::milliseconds) and ZEL_STABILITY_CHECK_THREAD_POLLING_INTERVAL
//     (100, passed to std::chrono::nanoseconds).
//   - stabilityCheck() changes from `void stabilityCheck(std::promise ...)` to
//     `int stabilityCheck()`: every former set_value(X); return; pair becomes
//     return X;. Debug messages drop the word "thread".
//   - zelCheckIsLoaderInTearDown(): returns true immediately once the
//     function-local static `unstable` has been set. Otherwise, inside the try
//     block and holding a std::lock_guard on *stabilityMutex, it replaces the
//     promise and future, stores RUNNING, creates the worker exactly once through
//     std::call_once, and waits on the future for at most
//     ZEL_STABILITY_CHECK_THREAD_TIMEOUT ms. A timeout is reported as
//     ZEL_STABILITY_CHECK_RESULT_EXCEPTION. Any result other than
//     ZEL_STABILITY_CHECK_RESULT_SUCCESS sets `unstable` and returns true.
//   - Worker loop: sleeps while the status pointer is non-null and the status is
//     POLLING; leaves the loop when ze_lib::destruction is set, ze_lib::context is
//     null, the status pointer is null, or the status is SHUTDOWN; otherwise it
//     stores POLLING, runs stabilityCheck() and publishes the result through
//     stabilityPromiseResult->set_value(). A failing result also ends the thread.
//
// test/CMakeLists.txt
//   - Adds seven zello_world tests (test_zello_world_legacy*). All pass
//     --enable_legacy_init --enable_null_driver and set
//     ZE_ENABLE_LOADER_DEBUG_TRACE=1; they differ in the intercept, validation and
//     tracing flags.
//
// test/loader_api.cpp
//   - Calls EXPECT_FALSE(zelCheckIsLoaderInTearDown()) three more times after the
//     existing call.
//
// NOTE(review): template arguments look missing in this copy (std::promise,
//   std::future, std::atomic, reinterpret_cast and std::lock_guard are written
//   without <...>) -- presumably lost in extraction; verify against the commit.
// NOTE(review): the worker stores POLLING unconditionally after leaving the wait
//   loop. If teardown stores SHUTDOWN between the worker's last load and that
//   store, the SHUTDOWN value appears to be overwritten and join() could block --
//   confirm, and consider a compare-exchange from RUNNING to POLLING.
// NOTE(review): SHUTDOWN is signalled only after FREE_DRIVER_LIBRARY(loader),
//   while the worker may still be inside loaderDriverGet -- confirm the ordering.
// NOTE(review): `unstable` is a plain static bool that is read before the lock is
//   taken and written after the try block (where the lock_guard lives) has ended
//   -- confirm concurrent callers are not expected.
diff --git a/samples/zello_world/zello_world.cpp b/samples/zello_world/zello_world.cpp index ba11bb09..ff5f9329 100644 --- a/samples/zello_world/zello_world.cpp +++ b/samples/zello_world/zello_world.cpp @@ -196,9 +196,13 @@ int main( int argc, char *argv[] ) zeEventHostSynchronize(event, UINT64_MAX ); std::cout << "Congratulations, the device completed execution!\n"; + zelCheckIsLoaderInTearDown(); zeContextDestroy(context); + zelCheckIsLoaderInTearDown(); zeCommandListDestroy(command_list); + zelCheckIsLoaderInTearDown(); zeEventDestroy(event); + zelCheckIsLoaderInTearDown(); zeEventPoolDestroy(event_pool); if (tracing_enabled) { diff --git a/source/lib/ze_lib.cpp b/source/lib/ze_lib.cpp index 50be7828..a3842eff 100644 --- a/source/lib/ze_lib.cpp +++ b/source/lib/ze_lib.cpp @@ -26,7 +26,15 @@ namespace ze_lib ze_lib::context = nullptr; } } + #define ZEL_STABILITY_THREAD_STATE_RUNNING 1 + #define ZEL_STABILITY_THREAD_STATE_POLLING 0 + #define ZEL_STABILITY_THREAD_STATE_SHUTDOWN -1 bool delayContextDestruction = false; + std::mutex *stabilityMutex = nullptr; + std::promise *stabilityPromiseResult = nullptr; + std::future *resultFutureResult = nullptr; + std::atomic *stabilityCheckThreadStatus = nullptr; + std::thread *stabilityThread = nullptr; #endif bool destruction = false; @@ -43,6 +51,35 @@ namespace ze_lib if (loader) { FREE_DRIVER_LIBRARY( loader ); } + if (ze_lib::stabilityCheckThreadStatus) + ze_lib::stabilityCheckThreadStatus->store(ZEL_STABILITY_THREAD_STATE_SHUTDOWN); + try { + if (stabilityThread && stabilityThread->joinable()) { + stabilityThread->join(); + } + } catch (...) 
{ + // Ignore any exceptions from thread join + } + if (stabilityThread) { + delete stabilityThread; + stabilityThread = nullptr; + } + if (stabilityMutex) { + delete stabilityMutex; + stabilityMutex = nullptr; + } + if (stabilityPromiseResult) { + delete stabilityPromiseResult; + stabilityPromiseResult = nullptr; + } + if (resultFutureResult) { + delete resultFutureResult; + resultFutureResult = nullptr; + } + if (stabilityCheckThreadStatus) { + delete stabilityCheckThreadStatus; + stabilityCheckThreadStatus = nullptr; + } #endif ze_lib::destruction = true; }; @@ -149,6 +186,10 @@ namespace ze_lib std::string version_message = "Loader API Version to be requested is v" + std::to_string(ZE_MAJOR_VERSION(version)) + "." + std::to_string(ZE_MINOR_VERSION(version)); debug_trace_message(version_message, ""); loaderDriverGet = reinterpret_cast(GET_FUNCTION_PTR(loader, "zeDriverGet")); + stabilityMutex = new std::mutex(); + stabilityPromiseResult = new std::promise(); + resultFutureResult = new std::future(stabilityPromiseResult->get_future()); + stabilityCheckThreadStatus = new std::atomic(ZEL_STABILITY_THREAD_STATE_POLLING); #else result = zeLoaderInit(); if( ZE_RESULT_SUCCESS == result ) { @@ -410,43 +451,40 @@ zelSetDelayLoaderContextTeardown() #define ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL 1 #define ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED 2 #define ZEL_STABILITY_CHECK_RESULT_EXCEPTION 3 +// The stability check thread timeout in milliseconds +#define ZEL_STABILITY_CHECK_THREAD_TIMEOUT 100 +// The stability check thread polling interval in nanoseconds +#define ZEL_STABILITY_CHECK_THREAD_POLLING_INTERVAL 100 /** * @brief Performs a stability check for the Level Zero loader. * - * This function checks the stability of the Level Zero loader by verifying - * the presence of the loader module, the validity of the `zeDriverGet` function - * pointer, and the ability to retrieve driver information. 
The result of the - * stability check is communicated through the provided promise. + * This function verifies the stability of the Level Zero loader by checking: + * - The presence of the loader module. + * - The validity of the `zeDriverGet` function pointer. + * - The ability to retrieve driver information. * - * @param stabilityPromise A promise object used to communicate the result of - * the stability check. The promise is set with one of - * the following values: - * - ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL: The - * `zeDriverGet` function pointer is invalid. - * - ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED: The - * loader failed to retrieve driver information. - * - ZEL_STABILITY_CHECK_RESULT_EXCEPTION: An - * exception occurred during the stability check. - * - ZEL_STABILITY_CHECK_RESULT_SUCCESS: The stability - * check was successful. + * The result of the stability check is returned as an integer, with the following possible values: + * - `ZEL_STABILITY_CHECK_RESULT_SUCCESS`: The stability check was successful. + * - `ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL`: The `zeDriverGet` function pointer is invalid. + * - `ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED`: The loader failed to retrieve driver information. + * - `ZEL_STABILITY_CHECK_RESULT_EXCEPTION`: An exception occurred during the stability check. * - * @note If debug tracing is enabled, debug messages are logged for each failure - * scenario. - * @note If the Loader is completely torn down, this thread is expected to be killed - * due to invalid memory access and the stability check will determine a failure. + * If debug tracing is enabled, debug messages are logged for each failure scenario. * - * @exception This function catches all exceptions internally and does not throw. + * @return An integer indicating the result of the stability check. + * + * @note If the loader is completely torn down, this function may fail due to invalid memory access. 
+ * @note This function catches all exceptions internally and does not throw. */ -void stabilityCheck(std::promise stabilityPromise) { +int stabilityCheck() { try { if (!ze_lib::context->loaderDriverGet) { if (ze_lib::context->debugTraceEnabled) { - std::string message = "LoaderDriverGet is a bad pointer. Exiting stability checker thread."; + std::string message = "LoaderDriverGet is a bad pointer. Exiting stability checker."; ze_lib::context->debug_trace_message(message, ""); } - stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL); - return; + return ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL; } uint32_t driverCount = 0; @@ -454,17 +492,14 @@ void stabilityCheck(std::promise stabilityPromise) { result = ze_lib::context->loaderDriverGet(&driverCount, nullptr); if (result != ZE_RESULT_SUCCESS || driverCount == 0) { if (ze_lib::context->debugTraceEnabled) { - std::string message = "Loader stability check failed. Exiting stability checker thread."; + std::string message = "Loader stability check failed. Exiting stability checker."; ze_lib::context->debug_trace_message(message, ""); } - stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED); - return; + return ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED; } - stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_SUCCESS); - return; + return ZEL_STABILITY_CHECK_RESULT_SUCCESS; } catch (...) 
{ - stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_EXCEPTION); - return; + return ZEL_STABILITY_CHECK_RESULT_EXCEPTION; } } #endif @@ -490,18 +525,60 @@ zelCheckIsLoaderInTearDown() { return true; } #ifdef DYNAMIC_LOAD_LOADER - std::promise stabilityPromise; - std::future resultFuture = stabilityPromise.get_future(); - int result = -1; + static bool unstable = false; + int threadResult = -1; + if (unstable) { + return true; + } try { - // Launch the stability checker thread - std::thread stabilityThread(stabilityCheck, std::move(stabilityPromise)); - result = resultFuture.get(); // Blocks until the result is available - if (ze_lib::context->debugTraceEnabled) { - std::string message = "Stability checker thread completed with result: " + std::to_string(result); - ze_lib::context->debug_trace_message(message, ""); + // Launch the stability checker thread on the first call + static std::once_flag stabilityThreadFlag; + std::lock_guard lock(*ze_lib::stabilityMutex); + *ze_lib::stabilityPromiseResult = std::promise(); + *ze_lib::resultFutureResult = ze_lib::stabilityPromiseResult->get_future(); + ze_lib::stabilityCheckThreadStatus->store(ZEL_STABILITY_THREAD_STATE_RUNNING); + std::call_once(stabilityThreadFlag, []() { + ze_lib::stabilityThread = new std::thread([]() { + while (true) { + while(ze_lib::stabilityCheckThreadStatus && ze_lib::stabilityCheckThreadStatus->load() == ZEL_STABILITY_THREAD_STATE_POLLING) { + std::this_thread::sleep_for(std::chrono::nanoseconds(ZEL_STABILITY_CHECK_THREAD_POLLING_INTERVAL)); + } + if (ze_lib::destruction || ze_lib::context == nullptr) { + break; + } + if (!ze_lib::stabilityCheckThreadStatus) { + break; + } + if (ze_lib::stabilityCheckThreadStatus->load() == ZEL_STABILITY_THREAD_STATE_SHUTDOWN) { + break; + } + ze_lib::stabilityCheckThreadStatus->store(ZEL_STABILITY_THREAD_STATE_POLLING); + int result = stabilityCheck(); + if (result != ZEL_STABILITY_CHECK_RESULT_SUCCESS) { + if (ze_lib::context->debugTraceEnabled) { + 
std::string message = "Loader stability check thread failed with result: " + std::to_string(result); + ze_lib::context->debug_trace_message(message, ""); + } + if (ze_lib::stabilityPromiseResult) { + ze_lib::stabilityPromiseResult->set_value(result); + } + break; // Exit the thread if stability check fails + } + if (ze_lib::stabilityPromiseResult) { + ze_lib::stabilityPromiseResult->set_value(result); + } + } + }); + }); + if (ze_lib::resultFutureResult->wait_for(std::chrono::milliseconds(ZEL_STABILITY_CHECK_THREAD_TIMEOUT)) == std::future_status::timeout) { + if (ze_lib::context->debugTraceEnabled) { + std::string message = "Stability Thread timeout, assuming thread has crashed"; + ze_lib::context->debug_trace_message(message, ""); + } + threadResult = ZEL_STABILITY_CHECK_RESULT_EXCEPTION; + } else { + threadResult = ze_lib::resultFutureResult->get(); } - stabilityThread.join(); } catch (const std::exception& e) { if (ze_lib::context->debugTraceEnabled) { std::string message = "Exception caught in parent thread: " + std::string(e.what()); @@ -513,11 +590,12 @@ zelCheckIsLoaderInTearDown() { ze_lib::context->debug_trace_message(message, ""); } } - if (result != ZEL_STABILITY_CHECK_RESULT_SUCCESS) { + if (threadResult != ZEL_STABILITY_CHECK_RESULT_SUCCESS) { if (ze_lib::context->debugTraceEnabled) { - std::string message = "Loader stability check failed with result: " + std::to_string(result); + std::string message = "Loader stability check failed with result: " + std::to_string(threadResult); ze_lib::context->debug_trace_message(message, ""); } + unstable = true; return true; } #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9596906b..e124e11b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -97,6 +97,27 @@ endif() add_test(NAME tests_loader_teardown_check COMMAND tests --gtest_filter=*GivenLoaderNotInDestructionStateWhenCallingzelCheckIsLoaderInTearDownThenFalseIsReturned) set_property(TEST tests_loader_teardown_check PROPERTY 
ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") +add_test(NAME test_zello_world_legacy COMMAND zello_world --enable_legacy_init --enable_null_driver --force_loader_intercepts --enable_validation_layer --enable_tracing_layer --enable_tracing_layer_runtime) +set_property(TEST test_zello_world_legacy PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1") + +add_test(NAME test_zello_world_legacy_intercept COMMAND zello_world --enable_legacy_init --enable_null_driver --force_loader_intercepts) +set_property(TEST test_zello_world_legacy_intercept PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1") + +add_test(NAME test_zello_world_legacy_validation_layer COMMAND zello_world --enable_legacy_init --enable_null_driver --enable_validation_layer) +set_property(TEST test_zello_world_legacy_validation_layer PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1") + +add_test(NAME test_zello_world_legacy_tracing COMMAND zello_world --enable_legacy_init --enable_null_driver --enable_tracing_layer) +set_property(TEST test_zello_world_legacy_tracing PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1") + +add_test(NAME test_zello_world_legacy_dynamic_tracing COMMAND zello_world --enable_legacy_init --enable_null_driver --enable_tracing_layer_runtime) +set_property(TEST test_zello_world_legacy_dynamic_tracing PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1") + +add_test(NAME test_zello_world_legacy_all_tracing COMMAND zello_world --enable_legacy_init --enable_null_driver --force_loader_intercepts --enable_validation_layer --enable_tracing_layer) +set_property(TEST test_zello_world_legacy_all_tracing PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1") + +add_test(NAME test_zello_world_legacy_all_tracing_dynamic COMMAND zello_world --enable_legacy_init --enable_null_driver --force_loader_intercepts --enable_validation_layer --enable_tracing_layer_runtime) +set_property(TEST test_zello_world_legacy_all_tracing_dynamic PROPERTY ENVIRONMENT 
"ZE_ENABLE_LOADER_DEBUG_TRACE=1") + # These tests are currently not supported on Windows. The reason is that the std::cerr is not being redirected to a pipe in Windows to be then checked against the expected output. if(NOT MSVC) add_test(NAME tests_event_deadlock COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingzeCommandListAppendMemoryCopyWithCircularDependencyOnEventsThenValidationLayerPrintsWarningOfDeadlock*) diff --git a/test/loader_api.cpp b/test/loader_api.cpp index 0d35fc2b..5ed18f0f 100644 --- a/test/loader_api.cpp +++ b/test/loader_api.cpp @@ -374,6 +374,9 @@ TEST( EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0)); EXPECT_FALSE(zelCheckIsLoaderInTearDown()); + EXPECT_FALSE(zelCheckIsLoaderInTearDown()); + EXPECT_FALSE(zelCheckIsLoaderInTearDown()); + EXPECT_FALSE(zelCheckIsLoaderInTearDown()); } class CaptureOutput {