diff --git a/paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h b/paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h
index 130e2fe6b12727..e9087aab1d87a8 100644
--- a/paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h
+++ b/paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h
@@ -23,7 +23,7 @@
 
 #include
 
-#include "paddle/phi/core/platform/cuda_device_guard.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
 
 namespace c10::cuda {
@@ -46,15 +46,13 @@
 struct CUDAGuard {
   explicit CUDAGuard(DeviceIndex device_index)
       : original_device_(detail::current_cuda_device()),
-        current_device_(original_device_),
-        guard_() {
+        current_device_(original_device_) {
     set_index(device_index);
   }
 
   explicit CUDAGuard(Device device)
       : original_device_(detail::current_cuda_device()),
-        current_device_(original_device_),
-        guard_() {
+        current_device_(original_device_) {
     set_device(device);
   }
 
@@ -63,18 +61,27 @@
   CUDAGuard(CUDAGuard&& other) = delete;
   CUDAGuard& operator=(CUDAGuard&& other) = delete;
 
-  ~CUDAGuard() = default;
+  ~CUDAGuard() {
+    // Always restore to original_device_ to handle cases where the device
+    // was changed outside of this guard, matching PyTorch semantics.
+    phi::backends::gpu::SetDeviceId(static_cast<int>(original_device_.index()));
+  }
 
   void set_device(Device device) {
-    current_device_ = detail::normalize_cuda_device(device);
-    guard_.SetDevice(current_device_._PD_GetInner());
+    const Device normalized = detail::normalize_cuda_device(device);
+    if (normalized.index() != current_device_.index()) {
+      phi::backends::gpu::SetDeviceId(static_cast<int>(normalized.index()));
+      current_device_ = normalized;
+    }
   }
 
   void reset_device(Device device) { set_device(device); }
 
   void set_index(DeviceIndex device_index) {
-    current_device_ = Device(kCUDA, device_index);
-    guard_.SetDeviceIndex(device_index);
+    if (current_device_.index() != device_index) {
+      phi::backends::gpu::SetDeviceId(static_cast<int>(device_index));
+      current_device_ = Device(kCUDA, device_index);
+    }
   }
 
   Device original_device() const { return original_device_; }
@@ -84,7 +91,6 @@ struct CUDAGuard {
  private:
   Device original_device_;
   Device current_device_;
-  paddle::platform::CUDADeviceGuard guard_;
 };
 
 struct OptionalCUDAGuard {
@@ -107,12 +113,14 @@
   OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete;
   OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
 
-  ~OptionalCUDAGuard() = default;
+  ~OptionalCUDAGuard() { reset(); }
 
   void set_device(Device device) {
     const Device normalized = detail::normalize_cuda_device(device);
     init_if_needed();
-    guard_->SetDevice(normalized._PD_GetInner());
+    if (normalized.index() != current_device_->index()) {
+      phi::backends::gpu::SetDeviceId(static_cast<int>(normalized.index()));
+    }
     current_device_ = normalized;
   }
 
@@ -120,7 +128,9 @@
 
   void set_index(DeviceIndex device_index) {
     init_if_needed();
-    guard_->SetDeviceIndex(device_index);
+    if (device_index != current_device_->index()) {
+      phi::backends::gpu::SetDeviceId(static_cast<int>(device_index));
+    }
     current_device_ = Device(kCUDA, device_index);
   }
 
@@ -129,23 +139,26 @@
   std::optional<Device> current_device() const { return current_device_; }
 
   void reset() {
-    guard_.reset();
+    if (original_device_.has_value()) {
+      // Always restore to original_device_ to handle external device changes.
+      // This matches PyTorch OptionalDeviceGuard semantics.
+      phi::backends::gpu::SetDeviceId(
+          static_cast<int>(original_device_->index()));
+    }
     original_device_.reset();
     current_device_.reset();
   }
 
  private:
   void init_if_needed() {
-    if (!guard_.has_value()) {
+    if (!original_device_.has_value()) {
       original_device_ = detail::current_cuda_device();
       current_device_ = original_device_;
-      guard_.emplace();
     }
   }
 
   std::optional<Device> original_device_;
   std::optional<Device> current_device_;
-  std::optional<paddle::platform::CUDADeviceGuard> guard_;
 };
 
 }  // namespace c10::cuda
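Annotation (not part of the patch): the unconditional destructor restore is the behavioral core of this file. Below is a minimal sketch of the contract it documents, assuming at least three visible GPUs and that the compat header is reachable as <c10/cuda/CUDAGuard.h>; the function name is illustrative:

    #include <c10/cuda/CUDAGuard.h>                // assumed compat include path
    #include "paddle/phi/backends/gpu/gpu_info.h"

    void sketch_guard_vs_external_switch() {       // illustrative name
      phi::backends::gpu::SetDeviceId(0);
      {
        c10::cuda::CUDAGuard guard(static_cast<c10::DeviceIndex>(1));
        // Code outside the guard's control switches devices again; the
        // guard's current_device_ bookkeeping never observes this change.
        phi::backends::gpu::SetDeviceId(2);
      }  // ~CUDAGuard() re-asserts original_device_ unconditionally.
      // The current device is 0 again here, not 2.
    }

The early-out in set_device()/set_index() only skips the redundant driver call when the requested index already matches the guard's bookkeeping; the destructor deliberately does not carry that optimization, for exactly the external-switch case above.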
diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp
index e13f017e35c88a..95efa28a63ad7a 100644
--- a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp
+++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp
@@ -12,6 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include
+#endif
 #include
 #include
 
@@ -25,8 +29,10 @@
 c10::DeviceIndex device_count() {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   return phi::backends::gpu::GetGPUDeviceCount();
 #else
-  PADDLE_THROW(common::errors::Unavailable(
-      "Paddle is not compiled with CUDA. Cannot visit device count."));
+  // Match PyTorch c10::cuda::device_count(): return 0 in CPU-only builds so
+  // that is_available() and the pre-checks of synchronize() degrade gracefully
+  // through a single, consistent "No CUDA GPUs are available" error path.
+  return 0;
 #endif
 }
 
@@ -35,21 +41,20 @@
 bool is_available() { return cuda::device_count() > 0; }
 
 void synchronize(int64_t device_index) {
   TORCH_CHECK(is_available(), "No CUDA GPUs are available");
   auto num_gpus = cuda::device_count();
-  TORCH_CHECK(device_index < 0 || device_index < num_gpus,
-              "Device index out of range: ",
-              device_index);
-// TODO(yongqiang) need using DeviceGuard
+  TORCH_CHECK(
+      device_index == -1 || (device_index >= 0 && device_index < num_gpus),
+      "Device index out of range: ",
+      device_index);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  paddle::platform::SetDeviceId(device_index);
-#ifdef PADDLE_WITH_HIP
-  PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
-#else
-  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
-#endif
-#else
-  PADDLE_THROW(common::errors::Unavailable(
-      "Paddle is not compiled with CUDA. Cannot visit device synchronize."));
+  // Match PyTorch semantics:
+  // 1. `device_index == -1` means "current CUDA device".
+  // 2. Explicit device synchronization must not leak a changed current device
+  //    to the caller after returning.
+  const c10::cuda::CUDAGuard device_guard(c10::Device(
+      c10::DeviceType::CUDA, static_cast<c10::DeviceIndex>(device_index)));
+  c10::cuda::device_synchronize();
 #endif
+  // CPU-only builds are already rejected above by the is_available() check.
 }
 
 }  // namespace torch::cuda
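Annotation: with device_count() returning 0 instead of throwing, a CPU-only caller degrades as sketched below; this is an illustrative standalone program, and the include path is assumed from the compat tree layout:

    #include <torch/csrc/api/include/torch/cuda.h>  // assumed include path
    #include <iostream>

    int main() {
      // CPU-only build: device_count() == 0, is_available() == false, no throw.
      if (!torch::cuda::is_available()) {
        std::cout << "device_count = "
                  << static_cast<int>(torch::cuda::device_count()) << "\n";
        return 0;
      }
      // GPU build: synchronize() waits on the current device (default index -1).
      torch::cuda::synchronize();
      return 0;
    }

Calling torch::cuda::synchronize() in the CPU-only branch would instead throw with the message "No CUDA GPUs are available", the single error path this change funnels everything through.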
diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h
index 4eb38ceecc681f..5d45d82a21dc77 100644
--- a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h
+++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h
@@ -15,16 +15,16 @@
 #pragma once
 #include
-
 #include
+#include "paddle/common/macros.h"
 
 namespace torch::cuda {
 
-c10::DeviceIndex device_count();
+PADDLE_API c10::DeviceIndex device_count();
 
-bool is_available();
+PADDLE_API bool is_available();
 
-void synchronize(int64_t device_index = -1);
+PADDLE_API void synchronize(int64_t device_index = -1);
 
 }  // namespace torch::cuda
 
 namespace at::cuda {
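Annotation: taken together, the now-exported declarations give callers the contract below. This is a sketch assuming at least two visible GPUs; the error text is the one added in cuda.cpp above, and the function name is illustrative:

    #include <torch/csrc/api/include/torch/cuda.h>  // assumed include path
    #include "paddle/phi/backends/gpu/gpu_info.h"

    void sketch_synchronize_contract() {            // illustrative name
      phi::backends::gpu::SetDeviceId(0);
      torch::cuda::synchronize(1);  // waits on device 1 behind an internal guard
      // Still on device 0 here: synchronize() no longer leaks the switch.
      torch::cuda::synchronize();   // default -1 targets the current device
      // torch::cuda::synchronize(-2);  // would throw "Device index out of range"
    }

The tests below exercise exactly these properties: device preservation across synchronize(), rejection of invalid negative indices, and guard restoration after multiple switches.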
diff --git a/test/cpp/compat/ATen_CUDAContext_test.cc b/test/cpp/compat/ATen_CUDAContext_test.cc
index f4c8a58dd7084b..642698cac2c923 100644
--- a/test/cpp/compat/ATen_CUDAContext_test.cc
+++ b/test/cpp/compat/ATen_CUDAContext_test.cc
@@ -15,10 +15,12 @@
 #include
 #include
 #include
+#include
 
 #include "gtest/gtest.h"
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include
 #include
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #endif
@@ -57,6 +59,27 @@ TEST(CUDAFunctionsTest, DeviceSynchronize) {
 #endif
 }
 
+// CPU-only: torch::cuda::synchronize must report "No CUDA GPUs are available"
+// rather than the older "Cannot visit device count" produced by device_count().
+// Matches PyTorch behavior where device_count() returns 0 in CPU-only builds
+// and the synchronize() pre-check is the single source of the GPU-missing
+// error message.
+TEST(CUDAFunctionsTest, SynchronizeReportsNoGpuMessageInCpuOnly) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  // Only relevant in CPU-only builds
+  return;
+#else
+  try {
+    torch::cuda::synchronize();
+    FAIL() << "expected exception";
+  } catch (const std::exception& e) {
+    const std::string msg = e.what();
+    EXPECT_NE(msg.find("No CUDA GPUs are available"), std::string::npos) << msg;
+    EXPECT_EQ(msg.find("Cannot visit device count"), std::string::npos) << msg;
+  }
+#endif
+}
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(CUDAFunctionsTest, StreamSynchronize) {
   if (!at::cuda::is_available()) {
@@ -78,6 +101,110 @@ TEST(CUDAFunctionsTest, AtNamespaceAliases) {
   auto stream = c10::cuda::getCurrentCUDAStream();
   ASSERT_NO_THROW(at::cuda::stream_synchronize(stream));
 }
+
+TEST(CUDAFunctionsTest, TorchSynchronizePreservesCurrentDevice) {
+  if (!torch::cuda::is_available()) {
+    return;
+  }
+  if (torch::cuda::device_count() < 2) {
+    return;
+  }
+
+  constexpr int current_device = 0;
+  constexpr int other_device = 1;
+  c10::cuda::CUDAGuard guard(static_cast<c10::DeviceIndex>(current_device));
+  ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), current_device);
+
+  ASSERT_NO_THROW(torch::cuda::synchronize(other_device));
+  EXPECT_EQ(phi::backends::gpu::GetCurrentDeviceId(), current_device);
+}
+
+TEST(CUDAFunctionsTest, SynchronizeRejectsInvalidNegativeDevice) {
+  if (!torch::cuda::is_available()) {
+    return;
+  }
+  ASSERT_THROW(torch::cuda::synchronize(-2), std::exception);
+}
+
+TEST(CUDAFunctionsTest, CUDAGuardRestoresOriginalDeviceAfterMultipleSwitches) {
+  if (!torch::cuda::is_available()) {
+    return;
+  }
+  if (torch::cuda::device_count() < 2) {
+    return;
+  }
+
+  constexpr int original_device = 0;
+  constexpr int intermediate_device = 1;
+  phi::backends::gpu::SetDeviceId(original_device);
+  ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+
+  {
+    c10::cuda::CUDAGuard guard(
+        static_cast<c10::DeviceIndex>(intermediate_device));
+    ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), intermediate_device);
+    guard.set_index(static_cast<c10::DeviceIndex>(original_device));
+    ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+    guard.set_index(static_cast<c10::DeviceIndex>(intermediate_device));
+    ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), intermediate_device);
+  }
+
+  EXPECT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+}
+
+TEST(CUDAFunctionsTest,
+     CUDAGuardRestoresOriginalDeviceAfterReturnToOriginalThenExit) {
+  if (!torch::cuda::is_available()) {
+    return;
+  }
+  if (torch::cuda::device_count() < 2) {
+    return;
+  }
+
+  constexpr int original_device = 0;
+  constexpr int intermediate_device = 1;
+  phi::backends::gpu::SetDeviceId(original_device);
+  ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+
+  {
+    c10::cuda::CUDAGuard guard(
+        static_cast<c10::DeviceIndex>(intermediate_device));
+    ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), intermediate_device);
+
+    guard.set_index(static_cast<c10::DeviceIndex>(original_device));
+    ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+  }
+
+  EXPECT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+}
+
+TEST(CUDAFunctionsTest,
+     OptionalCUDAGuardResetRestoresOriginalDeviceAfterReturnToOriginal) {
+  if (!torch::cuda::is_available()) {
+    return;
+  }
+  if (torch::cuda::device_count() < 2) {
+    return;
+  }
+
+  constexpr int original_device = 0;
+  constexpr int intermediate_device = 1;
+  phi::backends::gpu::SetDeviceId(original_device);
+  ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+
+  c10::cuda::OptionalCUDAGuard guard;
+  guard.set_index(static_cast<c10::DeviceIndex>(intermediate_device));
+  ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), intermediate_device);
+
+  guard.set_index(static_cast<c10::DeviceIndex>(original_device));
+  ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+
+  guard.reset();
+
+  EXPECT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
+  EXPECT_FALSE(guard.original_device().has_value());
+  EXPECT_FALSE(guard.current_device().has_value());
+}
 #endif
 
 // ---------------------------------------------------------------------------
@@ -111,6 +238,33 @@ TEST(CUDAContextLightTest, GetNumGPUs) {
 #endif
 }
 
+// CPU-only: device_count() must return 0 instead of throwing, matching the
+// PyTorch contract that device_count() is a non-throwing query.
+TEST(CUDAContextLightTest, DeviceCountReturnsZeroInCpuOnly) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  // Only relevant in CPU-only builds
+  return;
+#else
+  ASSERT_NO_THROW({
+    EXPECT_EQ(c10::cuda::device_count(), 0);
+    EXPECT_EQ(torch::cuda::device_count(), 0);
+  });
+#endif
+}
+
+// CPU-only: is_available() must be false and not throw, matching PyTorch.
+TEST(CUDAContextLightTest, IsAvailableFalseAndNoThrowInCpuOnly) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  // Only relevant in CPU-only builds
+  return;
+#else
+  ASSERT_NO_THROW({
+    EXPECT_FALSE(at::cuda::is_available());
+    EXPECT_FALSE(torch::cuda::is_available());
+  });
+#endif
+}
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 // The following tests require CUDA runtime and can only run in CUDA builds