
[Cpp API Compatibility] Align cuda compat #78808

Merged
SigureMo merged 15 commits into PaddlePaddle:develop from youge325:cAlign-cuda-compat
May 1, 2026

Conversation

youge325 (Contributor) commented Apr 25, 2026

PR Category

Execute Infrastructure

PR Types

Bug fixes

Description

Split from #78707

Align the CUDA compatibility semantics of torch::cuda::synchronize and c10::cuda::CUDAGuard, and fix related build and platform issues:

  1. Rewrite torch::cuda::synchronize

    • Use c10::cuda::CUDAGuard instead of the previous direct cudaDeviceSynchronize / hipDeviceSynchronize calls.
    • Match PyTorch semantics: device_index == -1 means synchronize the current device; after synchronizing an explicitly given device, the modified current device must not leak.
  2. Fix build and link issues

    • Fix the CPU-only build failure (conditional compilation via #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)).
    • Fix a Windows link error: add PADDLE_API symbol exports to the public torch::cuda APIs.
  3. Refactor c10::cuda::CUDAGuard / OptionalCUDAGuard

    • Remove the dependency on paddle::platform::CUDADeviceGuard and call phi::backends::gpu::SetDeviceId directly.
    • Explicitly restore the original device in the destructor and in the set_device, set_index, and reset paths, ensuring behavior consistent with PyTorch's CUDAGuard.
  4. Add unit tests

    • Add 4 groups of tests in ATen_CUDAContext_test.cc, verifying that:
      • torch::cuda::synchronize does not leak the current device
      • CUDAGuard correctly restores the original device after multiple switches
      • OptionalCUDAGuard::reset correctly cleans up state

Does this cause precision changes

Copilot AI review requested due to automatic review settings April 25, 2026 07:58
paddle-bot Bot commented Apr 25, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot Bot added the contributor External developers label Apr 25, 2026
Copilot AI (Contributor) left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

youge325 (Contributor Author):

/re-run all-failed

codecov-commenter commented Apr 25, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@4919688). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #78808   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         2           
  Lines              ?        23           
  Branches           ?         0           
===========================================
  Hits               ?        23           
  Misses             ?         0           
  Partials           ?         0           

☔ View full report in Codecov by Sentry.
youge325 (Contributor Author):

/re-run all-failed

youge325 (Contributor Author):

/re-run all-failed

youge325 (Contributor Author):

/re-run all-failed

youge325 (Contributor Author):

/re-run all-failed

Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.



Comment thread paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp Outdated
Comment thread paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h Outdated
Comment thread paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h Outdated
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.



Comment on lines 116 to +147
@@ -129,23 +139,26 @@ struct OptionalCUDAGuard {
std::optional<Device> current_device() const { return current_device_; }

void reset() {
guard_.reset();
if (original_device_.has_value()) {
// Always restore to original_device_ to handle external device changes.
// This matches PyTorch OptionalDeviceGuard semantics.
phi::backends::gpu::SetDeviceId(
static_cast<int>(original_device_->index()));
}

Copilot AI Apr 30, 2026


OptionalCUDAGuard::~OptionalCUDAGuard() calls reset(), and reset() always restores original_device_ via SetDeviceId whenever original_device_ is set. This can cause the same unintended first-call cudaSetDevice side effects even when no device switch occurred (e.g., set_device() called with the current device). Consider restoring only if the current device differs from original_device_ and/or if the guard actually changed the device, then clear the optionals.

youge325 (Contributor Author) replied:


Thanks for the careful read. The behavior matches PyTorch's InlineDeviceGuard: PyTorch's destructor also unconditionally calls impl_.uncheckedSetDevice(original_device_), and the no-op short-circuit happens one layer down inside c10::cuda::SetDevice, which compares cur_device == device via cudaGetDevice before calling cudaSetDevice. Our compat path keeps the same shape: phi::backends::gpu::SetDeviceId performs the same cudaGetDevice-then-compare short-circuit internally, so calling reset() on a never-mutated guard does not actually trigger a cudaSetDevice. Keeping reset() always restore original_device_ is therefore the right call for parity with PyTorch.

Comment on lines 39 to +45
void synchronize(int64_t device_index) {
TORCH_CHECK(is_available(), "No CUDA GPUs are available");
auto num_gpus = cuda::device_count();
TORCH_CHECK(device_index < 0 || device_index < num_gpus,
"Device index out of range: ",
device_index);
// TODO(yongqiang) need using DeviceGuard
TORCH_CHECK(
device_index == -1 || (device_index >= 0 && device_index < num_gpus),
"Device index out of range: ",
device_index);

Copilot AI Apr 30, 2026


In CPU-only builds, torch::cuda::is_available() currently calls torch::cuda::device_count(), but device_count() throws when CUDA/HIP is not compiled in. This makes the #else branch in synchronize() unreachable and causes synchronize() to throw from device_count() instead of the intended "not compiled with CUDA"/"no GPUs" error path. Consider making device_count() return 0 (or gating is_available() / the pre-checks behind #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)), so is_available() is a non-throwing query and synchronize() reports a consistent error on CPU-only builds.

Member:

Just to confirm: this PR does not affect ABI compatibility, right?

Contributor Author:

Let me verify again; the last change touched a .cpp file, so I'll run the ABI check one more time.

Member:

"run the ABI check one more time"

Consider handling that other PR first, so this doesn't have to be done separately every time; add it as a review check. @SigureMo @BingooYang

Contributor Author:

ok

Member:

static-check passed; the related changes in this PR can be reverted now.

youge325 added 2 commits May 1, 2026 21:18
In CPU-only builds, c10::cuda::device_count() / torch::cuda::device_count()
previously threw "Cannot visit device count" via PADDLE_THROW. This made
is_available() unsafe to call and caused synchronize() to surface the
wrong error message.

Match PyTorch semantics: return 0 in CPU-only builds so that
is_available() returns false and synchronize() falls through the existing
TORCH_CHECK(is_available(), "No CUDA GPUs are available") guard. The
unreachable #else PADDLE_THROW branch in synchronize() is removed.

Adds three CPU-only regression tests:
- DeviceCountReturnsZeroInCpuOnly
- IsAvailableFalseAndNoThrowInCpuOnly
- SynchronizeReportsNoGpuMessageInCpuOnly

Addresses Copilot review comment 3168115261.
SigureMo (Member) left a comment


LGTMeow 🐾

private:
Device original_device_;
Device current_device_;
paddle::platform::CUDADeviceGuard guard_;
Member:

Remove the dependency on paddle::platform::CUDADeviceGuard and call phi::backends::gpu::SetDeviceId directly

By the way, is there some problem with Paddle's own guard?

Contributor Author:

It was found during the codex review of #78707 (review); the unit tests were not thorough enough, or rather the test design at the time failed to expose the problem.

Member:

Oh, it's that problem.

But could this be done by fixing the guard inside Paddle instead? It looks like it's simply a bug in the Paddle guard? Fixing it shouldn't risk breaking previous behavior either, right?

Contributor Author:

(screenshot of the CUDADeviceGuard implementation)

It should be fixable: every time SetDevice is needed, SetDeviceIndex gets called, so prev_id_ is overwritten repeatedly. prev_id_ should really be set only once, when the CUDADeviceGuard is constructed, and that logic should be split out. It probably never went wrong before because the only usage pattern was a one-shot CUDADeviceGuard, with no repeated SetDevice / SetDeviceIndex calls.

youge325 (Contributor Author), May 1, 2026:

#55498 — it looks like this was intentional by design: the guard does not switch back to the original device on destruction. Let me think about this some more.

Member:

Ah... if there are historical reasons it can't be reused directly, then the current approach is fine; it doesn't look like an easy problem to solve.

@SigureMo SigureMo merged commit de1a390 into PaddlePaddle:develop May 1, 2026
83 of 85 checks passed
@youge325 youge325 deleted the cAlign-cuda-compat branch May 2, 2026 03:23

Labels

contributor External developers
