Fix AllreduceV with CUDA stream.#12171
Conversation
We use async NCCL to implement timeout. However, when NCCL is in async mode, it uses a thread pool, and cannot work with per-thread CUDA stream, which resolves to the wrong per-thread stream in its pool. The existing allreduce implementation works because we use a custom CUDA stream in the NCCL coll wrapper. This PR moves that stream into NCCLComm, to expose it to the allreduce V implementation.
There was a problem hiding this comment.
Pull request overview
This PR addresses incorrect CUDA stream usage when NCCL is configured in async/non-blocking mode (thread pool), by moving ownership of a dedicated CUDA stream into NCCLComm and updating AllreduceV to correctly synchronize between the caller’s per-thread stream and the NCCL stream.
Changes:
- Move the NCCL communication stream from
NCCLCollintoNCCLCommand expose it viaNCCLComm::Stream(). - Update NCCL collective launches to use the communicator-owned stream (removing the prior coll-owned stream plumbing).
- Fix
AllreduceVby introducing event-based stream bracketing between the user stream and the NCCL stream.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| src/collective/comm.cuh | Make NCCLComm own a curt::Stream and expose a StreamRef view. |
| src/collective/comm.cu | Update NCCLComm construction/destruction to use the owned stream and sync it on teardown. |
| src/collective/coll.cuh | Remove the coll-owned CUDA stream member. |
| src/collective/coll.cu | Route async NCCL launches through NCCLComm::Stream() and update call sites accordingly. |
| src/collective/allreduce_v.cuh | Add event bracketing to safely interop between user stream and NCCL stream; update AllreduceV signature. |
| src/collective/allreduce.h | Pass an explicit user stream into gpu_impl::AllreduceV (ctx-stream when available). |
Comments suppressed due to low confidence (1)
src/collective/comm.cu:57
stream_is constructed beforecurt::SetDevice(ctx->Ordinal())is called. Sincecurt::Streamcreates acudaStream_ton the current device, this can create the NCCL communicator stream on the wrong device when the caller thread isn’t already onctx->Ordinal(), leading to invalid-handle errors or silent mis-synchronization. Consider setting the CUDA device before constructingstream_(e.g., delay stream creation until afterSetDevice, or add a device-guard member that runs beforestream_is constructed).
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
StringView nccl_path)
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
root.TaskID()},
stream_{} {
this->world_ = root.World();
this->rank_ = root.Rank();
this->domain_ = root.Domain();
if (!root.IsDistributed()) {
return;
}
curt::SetDevice(ctx->Ordinal());
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 22 out of 22 changed files in this pull request and generated 4 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 22 out of 22 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
We use async NCCL to implement timeout. However, when NCCL is in async mode, it uses a thread pool, and cannot work with per-thread CUDA stream, which resolves to the wrong per-thread stream in its internal pool.
The existing allreduce implementation works because we use a custom CUDA stream in the NCCL coll wrapper.
This PR moves that stream into NCCLComm, to expose it to the allreduce V implementation. In addition, to correctly synchronize with call stream, we pass the Context object through the stack.
ref #12122