Merged
49 changes: 31 additions & 18 deletions paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h
@@ -23,7 +23,7 @@

#include <optional>

#include "paddle/phi/core/platform/cuda_device_guard.h"
#include "paddle/phi/backends/gpu/gpu_info.h"

namespace c10::cuda {

@@ -46,15 +46,13 @@ struct CUDAGuard {

explicit CUDAGuard(DeviceIndex device_index)
: original_device_(detail::current_cuda_device()),
current_device_(original_device_),
guard_() {
current_device_(original_device_) {
set_index(device_index);
}

explicit CUDAGuard(Device device)
: original_device_(detail::current_cuda_device()),
current_device_(original_device_),
guard_() {
current_device_(original_device_) {
set_device(device);
}

@@ -63,18 +61,27 @@

CUDAGuard(CUDAGuard&& other) = delete;
CUDAGuard& operator=(CUDAGuard&& other) = delete;
~CUDAGuard() = default;
~CUDAGuard() {
// Always restore to original_device_ to handle cases where the device
// was changed outside of this guard, matching PyTorch semantics.
phi::backends::gpu::SetDeviceId(static_cast<int>(original_device_.index()));
}

void set_device(Device device) {
current_device_ = detail::normalize_cuda_device(device);
guard_.SetDevice(current_device_._PD_GetInner());
const Device normalized = detail::normalize_cuda_device(device);
if (normalized.index() != current_device_.index()) {
phi::backends::gpu::SetDeviceId(static_cast<int>(normalized.index()));
current_device_ = normalized;
}
}

void reset_device(Device device) { set_device(device); }

void set_index(DeviceIndex device_index) {
current_device_ = Device(kCUDA, device_index);
guard_.SetDeviceIndex(device_index);
if (current_device_.index() != device_index) {
phi::backends::gpu::SetDeviceId(static_cast<int>(device_index));
current_device_ = Device(kCUDA, device_index);
}
}

Device original_device() const { return original_device_; }
Expand All @@ -84,7 +91,6 @@ struct CUDAGuard {
private:
Device original_device_;
Device current_device_;
paddle::platform::CUDADeviceGuard guard_;
Member:
Removes the dependency on paddle::platform::CUDADeviceGuard, switching to direct calls to phi::backends::gpu::SetDeviceId.

By the way, is there actually some problem with Paddle's guard?

Contributor Author:
Found during the codex review of #78707 (review). The unit tests were not thorough enough, or rather the test design at the time failed to surface the problem.

Member:
Oh, it's that issue.

Could this be handled by fixing the guard inside Paddle instead, though? It looks like it is simply a bug in the Paddle guard, and fixing it there shouldn't risk breaking any existing behavior, right?

Contributor Author:
It should be fixable. The problem is that SetDeviceIndex is called every time a SetDevice is needed, so prev_id_ keeps getting overwritten. prev_id_ should really be set only once, when the CUDADeviceGuard is constructed, and that logic ought to be split out. It probably never misbehaved before because existing usage only constructed a CUDADeviceGuard once, without repeated SetDevice or SetDeviceIndex calls.
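A minimal sketch of that failure mode, assuming the guard roughly follows this shape (a hypothetical simplification, not the actual Paddle source; the two free functions stand in for the phi::backends::gpu device calls):

int GetCurrentDeviceId();      // assumed: wraps cudaGetDevice
void SetDeviceId(int dev_id);  // assumed: wraps cudaSetDevice

// Hypothetical sketch of the overwrite bug: prev_id_ is saved on every
// SetDeviceIndex call instead of once at construction.
class GuardSketch {
 public:
  void SetDeviceIndex(int dev_id) {
    int prev_id = GetCurrentDeviceId();
    if (prev_id != dev_id) {
      prev_id_ = prev_id;  // overwritten on each switch; the device that was
                           // current at construction time is lost
      SetDeviceId(dev_id);
    }
  }
  ~GuardSketch() {
    if (prev_id_ != -1) SetDeviceId(prev_id_);  // restores only the *last*
                                                // previous device
  }

 private:
  int prev_id_ = -1;
};

// Starting on device 0:
//   GuardSketch g;
//   g.SetDeviceIndex(1);  // prev_id_ = 0, current device = 1
//   g.SetDeviceIndex(0);  // prev_id_ = 1, current device = 0
//   // g's destructor now restores device 1, not the original device 0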

Contributor Author (@youge325, May 1, 2026):
#55498 suggests this may be intentional: the guard deliberately does not restore the original device on destruction. Let me think about it some more.

Member:
Ah... if historical reasons mean the Paddle guard can't simply be reused here, the current approach is fine. This doesn't look like an easy problem to solve.

};

struct OptionalCUDAGuard {
@@ -107,20 +113,24 @@ struct OptionalCUDAGuard {

OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete;
OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
~OptionalCUDAGuard() = default;
~OptionalCUDAGuard() { reset(); }

void set_device(Device device) {
const Device normalized = detail::normalize_cuda_device(device);
init_if_needed();
guard_->SetDevice(normalized._PD_GetInner());
if (normalized.index() != current_device_->index()) {
phi::backends::gpu::SetDeviceId(static_cast<int>(normalized.index()));
}
current_device_ = normalized;
}

void reset_device(Device device) { set_device(device); }

void set_index(DeviceIndex device_index) {
init_if_needed();
guard_->SetDeviceIndex(device_index);
if (device_index != current_device_->index()) {
phi::backends::gpu::SetDeviceId(static_cast<int>(device_index));
}
current_device_ = Device(kCUDA, device_index);
}

Expand All @@ -129,23 +139,26 @@ struct OptionalCUDAGuard {
std::optional<Device> current_device() const { return current_device_; }

void reset() {
guard_.reset();
if (original_device_.has_value()) {
// Always restore to original_device_ to handle external device changes.
// This matches PyTorch OptionalDeviceGuard semantics.
phi::backends::gpu::SetDeviceId(
static_cast<int>(original_device_->index()));
}
Comment on lines 116 to +147
Copilot AI (Apr 30, 2026):
OptionalCUDAGuard::~OptionalCUDAGuard() calls reset(), and reset() always restores original_device_ via SetDeviceId whenever original_device_ is set. This can cause the same unintended first-call cudaSetDevice side effects even when no device switch occurred (e.g., set_device() called with the current device). Consider restoring only if the current device differs from original_device_ and/or if the guard actually changed the device, then clear the optionals.

Contributor Author:
Thanks for the careful read. The behavior matches PyTorch's InlineDeviceGuard: PyTorch's destructor also unconditionally calls impl_.uncheckedSetDevice(original_device_), and the no-op short-circuit happens one layer down inside c10::cuda::SetDevice, which compares cur_device == device via cudaGetDevice before calling cudaSetDevice. Our compat path keeps the same shape: phi::backends::gpu::SetDeviceId performs the same cudaGetDevice-then-compare short-circuit internally, so calling reset() on a never-mutated guard does not actually trigger a cudaSetDevice. Keeping reset() always restore original_device_ is therefore the right call for parity with PyTorch.
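For reference, the short-circuit being relied on looks roughly like this (a sketch of the pattern under standard CUDA runtime semantics, not the literal c10 or phi source):

#include <cuda_runtime.h>

// Get-then-compare short-circuit: restoring an unchanged device is a no-op at
// the CUDA runtime level, so reset() on a never-mutated guard never reaches
// cudaSetDevice.
void SetDeviceIfChanged(int device) {
  int cur_device = -1;
  cudaGetDevice(&cur_device);  // cheap query of the thread-local current device
  if (cur_device != device) {
    cudaSetDevice(device);     // side effects only on an actual switch
  }
}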

original_device_.reset();
current_device_.reset();
}

private:
void init_if_needed() {
if (!guard_.has_value()) {
if (!original_device_.has_value()) {
original_device_ = detail::current_cuda_device();
current_device_ = original_device_;
guard_.emplace();
}
}

std::optional<Device> original_device_;
std::optional<Device> current_device_;
std::optional<paddle::platform::CUDADeviceGuard> guard_;
};

} // namespace c10::cuda
35 changes: 20 additions & 15 deletions paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp
Member:
Just to confirm: this PR doesn't affect ABI compatibility, right?

Contributor Author:
Let me verify again. The .cpp files were changed last time, so I'll rerun the ABI Check.

Member:
Re "rerun the ABI Check": consider landing that PR first, so this doesn't have to be done manually every time, and add it as a review check. @SigureMo @BingooYang

Contributor Author:
OK.

Member:
static-check passed; the related changes in this PR can be reverted now.

@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <c10/cuda/CUDAFunctions.h>
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include <c10/cuda/CUDAGuard.h>
#endif
#include <c10/util/Exception.h>
#include <torch/cuda.h>

@@ -25,8 +29,10 @@ c10::DeviceIndex device_count() {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return phi::backends::gpu::GetGPUDeviceCount();
#else
PADDLE_THROW(common::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit device count."));
// Match PyTorch c10::cuda::device_count(): return 0 in CPU-only builds so
// that is_available() and the pre-checks of synchronize() degrade gracefully
// through a single, consistent "No CUDA GPUs are available" error path.
return 0;
#endif
}

@@ -35,21 +41,20 @@ bool is_available() { return cuda::device_count() > 0; }
void synchronize(int64_t device_index) {
TORCH_CHECK(is_available(), "No CUDA GPUs are available");
auto num_gpus = cuda::device_count();
TORCH_CHECK(device_index < 0 || device_index < num_gpus,
"Device index out of range: ",
device_index);
// TODO(yongqiang) need using DeviceGuard
TORCH_CHECK(
device_index == -1 || (device_index >= 0 && device_index < num_gpus),
"Device index out of range: ",
device_index);
Comment on lines 41 to +47
Copilot AI (Apr 30, 2026):
In CPU-only builds, torch::cuda::is_available() currently calls torch::cuda::device_count(), but device_count() throws when CUDA/HIP is not compiled in. This makes the #else branch in synchronize() unreachable and causes synchronize() to throw from device_count() instead of the intended "not compiled with CUDA"/"no GPUs" error path. Consider making device_count() return 0 (or gating is_available() / the pre-checks behind #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)), so is_available() is a non-throwing query and synchronize() reports a consistent error on CPU-only builds.
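Condensed into a compilable sketch, the CPU-only chain this comment asks for (and which the committed diff implements) looks like the following; TORCH_CHECK is replaced by a plain throw to keep the sketch self-contained:

#include <cstdint>
#include <stdexcept>

namespace sketch {

// CPU-only build: device_count() is a non-throwing query.
int64_t device_count() { return 0; }

// is_available() therefore never throws either.
bool is_available() { return device_count() > 0; }

// synchronize() fails at the single, consistent pre-check.
void synchronize(int64_t /*device_index*/ = -1) {
  if (!is_available()) {
    throw std::runtime_error("No CUDA GPUs are available");
  }
  // GPU builds would install a CUDAGuard and call device_synchronize() here.
}

}  // namespace sketch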

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle::platform::SetDeviceId(device_index);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#endif
#else
PADDLE_THROW(common::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit device synchronize."));
// Match PyTorch semantics:
// 1. `device_index == -1` means "current CUDA device".
// 2. Explicit device synchronization must not leak a changed current device
// to the caller after returning.
const c10::cuda::CUDAGuard device_guard(c10::Device(
c10::DeviceType::CUDA, static_cast<c10::DeviceIndex>(device_index)));
c10::cuda::device_synchronize();
#endif
// CPU-only builds are already rejected above by the is_available() check.
}

} // namespace torch::cuda
@@ -15,16 +15,16 @@
#pragma once

#include <c10/core/Device.h>

#include <cstdint>
#include "paddle/common/macros.h"

namespace torch::cuda {

c10::DeviceIndex device_count();
PADDLE_API c10::DeviceIndex device_count();

bool is_available();
PADDLE_API bool is_available();

void synchronize(int64_t device_index = -1);
PADDLE_API void synchronize(int64_t device_index = -1);

} // namespace torch::cuda
namespace at::cuda {
154 changes: 154 additions & 0 deletions test/cpp/compat/ATen_CUDAContext_test.cc
@@ -15,10 +15,12 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <torch/cuda.h>

#include "gtest/gtest.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "paddle/phi/backends/gpu/gpu_info.h"
#endif
@@ -57,6 +59,27 @@ TEST(CUDAFunctionsTest, DeviceSynchronize) {
#endif
}

// CPU-only: torch::cuda::synchronize must report "No CUDA GPUs are available"
// rather than the older "Cannot visit device count" produced by device_count().
// Matches PyTorch behavior where device_count() returns 0 in CPU-only builds
// and the synchronize() pre-check is the single source of the GPU-missing
// error message.
TEST(CUDAFunctionsTest, SynchronizeReportsNoGpuMessageInCpuOnly) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Only relevant in CPU-only builds
return;
#else
try {
torch::cuda::synchronize();
FAIL() << "expected exception";
} catch (const std::exception& e) {
const std::string msg = e.what();
EXPECT_NE(msg.find("No CUDA GPUs are available"), std::string::npos) << msg;
EXPECT_EQ(msg.find("Cannot visit device count"), std::string::npos) << msg;
}
#endif
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(CUDAFunctionsTest, StreamSynchronize) {
if (!at::cuda::is_available()) {
@@ -78,6 +101,110 @@ TEST(CUDAFunctionsTest, AtNamespaceAliases) {
auto stream = c10::cuda::getCurrentCUDAStream();
ASSERT_NO_THROW(at::cuda::stream_synchronize(stream));
}

TEST(CUDAFunctionsTest, TorchSynchronizePreservesCurrentDevice) {
if (!torch::cuda::is_available()) {
return;
}
if (torch::cuda::device_count() < 2) {
return;
}

constexpr int current_device = 0;
constexpr int other_device = 1;
c10::cuda::CUDAGuard guard(static_cast<c10::DeviceIndex>(current_device));
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), current_device);

ASSERT_NO_THROW(torch::cuda::synchronize(other_device));
EXPECT_EQ(phi::backends::gpu::GetCurrentDeviceId(), current_device);
}

TEST(CUDAFunctionsTest, SynchronizeRejectsInvalidNegativeDevice) {
if (!torch::cuda::is_available()) {
return;
}
ASSERT_THROW(torch::cuda::synchronize(-2), std::exception);
}

TEST(CUDAFunctionsTest, CUDAGuardRestoresOriginalDeviceAfterMultipleSwitches) {
if (!torch::cuda::is_available()) {
return;
}
if (torch::cuda::device_count() < 2) {
return;
}

constexpr int original_device = 0;
constexpr int intermediate_device = 1;
phi::backends::gpu::SetDeviceId(original_device);
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);

{
c10::cuda::CUDAGuard guard(
static_cast<c10::DeviceIndex>(intermediate_device));
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), intermediate_device);
guard.set_index(static_cast<c10::DeviceIndex>(original_device));
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
guard.set_index(static_cast<c10::DeviceIndex>(intermediate_device));
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), intermediate_device);
}

EXPECT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
}

TEST(CUDAFunctionsTest,
CUDAGuardRestoresOriginalDeviceAfterReturnToOriginalThenExit) {
if (!torch::cuda::is_available()) {
return;
}
if (torch::cuda::device_count() < 2) {
return;
}

constexpr int original_device = 0;
constexpr int intermediate_device = 1;
phi::backends::gpu::SetDeviceId(original_device);
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);

{
c10::cuda::CUDAGuard guard(
static_cast<c10::DeviceIndex>(intermediate_device));
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), intermediate_device);

guard.set_index(static_cast<c10::DeviceIndex>(original_device));
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
}

EXPECT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
}

TEST(CUDAFunctionsTest,
OptionalCUDAGuardResetRestoresOriginalDeviceAfterReturnToOriginal) {
if (!torch::cuda::is_available()) {
return;
}
if (torch::cuda::device_count() < 2) {
return;
}

constexpr int original_device = 0;
constexpr int intermediate_device = 1;
phi::backends::gpu::SetDeviceId(original_device);
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);

c10::cuda::OptionalCUDAGuard guard;
guard.set_index(static_cast<c10::DeviceIndex>(intermediate_device));
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), intermediate_device);

guard.set_index(static_cast<c10::DeviceIndex>(original_device));
ASSERT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);

guard.reset();

EXPECT_EQ(phi::backends::gpu::GetCurrentDeviceId(), original_device);
EXPECT_FALSE(guard.original_device().has_value());
EXPECT_FALSE(guard.current_device().has_value());
}
#endif

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -111,6 +238,33 @@ TEST(CUDAContextLightTest, GetNumGPUs) {
#endif
}

// CPU-only: device_count() must return 0 instead of throwing, matching the
// PyTorch contract that device_count() is a non-throwing query.
TEST(CUDAContextLightTest, DeviceCountReturnsZeroInCpuOnly) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Only relevant in CPU-only builds
return;
#else
ASSERT_NO_THROW({
EXPECT_EQ(c10::cuda::device_count(), 0);
EXPECT_EQ(torch::cuda::device_count(), 0);
});
#endif
}

// CPU-only: is_available() must be false and not throw, matching PyTorch.
TEST(CUDAContextLightTest, IsAvailableFalseAndNoThrowInCpuOnly) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Only relevant in CPU-only builds
return;
#else
ASSERT_NO_THROW({
EXPECT_FALSE(at::cuda::is_available());
EXPECT_FALSE(torch::cuda::is_available());
});
#endif
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

// The following tests require CUDA runtime and can only run in CUDA builds