Skip to content

Commit dea5cbd

Browse files
i-kosarevclaude
andcommitted
net-ib: fix fault injection API issues from review
- ncclIbCastFaultSetQpDelay/SetQpError: use comm->base.nqps for bounds check instead of NCCL_IB_MAX_QPS - ncclIbCastFaultClear: atomically reset fatalErrorCount to 0 - move net_ib_fault_inject.h into transport/net_ib_cast/; drop local #define NCCL_IB_MAX_QPS 128, include net_ib_cast_inspect.h and add static_assert so a size mismatch becomes a compile error - fault hook in IbCastMultiSend: use IbCastStatsFatalError (renamed from ncclIbStatsFatalError in asanniko's split) - FaultInjCastQpErrorClearRecovers: ASSERT_EQ on SetQpError return value (was silently ignored); drain recvReq before DeregisterMemory - FaultInjCastSingleQpErrorIsFatal: EXPECT_EQ on ncclIbCastSetTokens and ncclIbCastFaultClear return values (Copilot review) Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
1 parent 80a1ad3 commit dea5cbd

5 files changed

Lines changed: 52 additions & 21 deletions

File tree

projects/rccl/src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ set(SRC_FILES
217217
include/plugin/net/net_v11.h
218218
include/plugin/profiler/net_ib_v1.h
219219
include/plugin/profiler/net_ib.h
220-
include/net_ib_fault_inject.h
220+
transport/net_ib_cast/net_ib_fault_inject.h
221221
include/plugin/profiler/net_socket_v1.h
222222
include/plugin/profiler/net_socket.h
223223
include/plugin/profiler/profiler_v1.h

projects/rccl/src/include/net_ib_fault_inject.h renamed to projects/rccl/src/transport/net_ib_cast/net_ib_fault_inject.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,11 @@ extern "C" {
2222
/*
2323
* Test-only per-QP fault injection API for the net-ib CAST transport.
2424
*
25-
* Implemented in src/transport/net_ib_cast.cc (CAST multi-QP path).
26-
*
27-
* NCCL_IB_MAX_QPS is defined in net_ib_cast_inspect.h; guard against
28-
* double-definition when both headers are included together.
25+
* Implemented in src/transport/net_ib_cast/p2p.cc (CAST multi-QP path).
2926
*/
30-
#ifndef NCCL_IB_MAX_QPS
31-
#define NCCL_IB_MAX_QPS 128
27+
#include "net_ib_cast_inspect.h"
28+
#ifdef __cplusplus
29+
static_assert(NCCL_IB_MAX_QPS == 128, "fault injection arrays sized for 128 QPs; update if NCCL_IB_MAX_QPS changes");
3230
#endif
3331

3432
/* ── CAST path (net_ib_cast.cc) ───────────────────────────────────────── */
@@ -39,7 +37,7 @@ extern "C" {
3937
ncclResult_t ncclIbCastFaultSetQpDelay(void* sendComm, int qpIdx, uint32_t delayUs);
4038

4139
/* Arm error injection on a specific QP index.
42-
* When armed, the hook calls ncclIbStatsFatalError and then returns
40+
* When armed, the hook calls IbCastStatsFatalError and then returns
4341
* ncclSystemError instead of calling wrap_ibv_post_send.
4442
* Set inject=false to disarm. */
4543
ncclResult_t ncclIbCastFaultSetQpError(void* sendComm, int qpIdx, bool inject);

projects/rccl/src/transport/net_ib_cast/p2p.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ ncclResult_t IbCastMultiSend(struct ncclIbSendComm* comm, int slot, int nqps, in
298298
const uint32_t faultDelay = comm->base.faultQpDelayUs[qpIndex];
299299
if (faultDelay) usleep(faultDelay);
300300
if (comm->base.faultQpError[qpIndex]) {
301-
ncclIbStatsFatalError(&comm->base.stats);
301+
IbCastStatsFatalError(&comm->base.stats);
302302
return ncclSystemError;
303303
}
304304
}
@@ -1094,15 +1094,15 @@ ncclResult_t IbCastTest(void* request, int* done, int* sizes) {
10941094
ncclResult_t ncclIbCastFaultSetQpDelay(void* sendComm, int qpIdx, uint32_t delayUs) {
10951095
if (!sendComm) return ncclInvalidArgument;
10961096
struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm;
1097-
if (qpIdx < 0 || qpIdx >= NCCL_IB_MAX_QPS) return ncclInvalidArgument;
1097+
if (qpIdx < 0 || qpIdx >= comm->base.nqps) return ncclInvalidArgument;
10981098
comm->base.faultQpDelayUs[qpIdx] = delayUs;
10991099
return ncclSuccess;
11001100
}
11011101

11021102
ncclResult_t ncclIbCastFaultSetQpError(void* sendComm, int qpIdx, bool inject) {
11031103
if (!sendComm) return ncclInvalidArgument;
11041104
struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm;
1105-
if (qpIdx < 0 || qpIdx >= NCCL_IB_MAX_QPS) return ncclInvalidArgument;
1105+
if (qpIdx < 0 || qpIdx >= comm->base.nqps) return ncclInvalidArgument;
11061106
comm->base.faultQpError[qpIdx] = inject;
11071107
return ncclSuccess;
11081108
}
@@ -1112,6 +1112,7 @@ ncclResult_t ncclIbCastFaultClear(void* sendComm) {
11121112
struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm;
11131113
memset(comm->base.faultQpDelayUs, 0, sizeof(comm->base.faultQpDelayUs));
11141114
memset(comm->base.faultQpError, 0, sizeof(comm->base.faultQpError));
1115+
__atomic_store_n(&comm->base.stats.fatalErrorCount, 0, __ATOMIC_RELEASE);
11151116
return ncclSuccess;
11161117
}
11171118

projects/rccl/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ if(BUILD_TESTS)
9494
${PROJECT_BINARY_DIR}/hipify/gensrc # for rccl_bfloat16.h
9595
${PROJECT_BINARY_DIR}/hipify/src # for graph/topo.h
9696
${PROJECT_BINARY_DIR}/hipify/src/include/plugin # for recorder tests, nccl_tuner.h
97+
${PROJECT_BINARY_DIR}/hipify/src/transport/net_ib_cast # for net_ib_fault_inject.h
9798
${ROCM_PATH}/include
9899
${ROCM_PATH}
99100
)

projects/rccl/test/transport/NetIbMPI/FaultInjectTests.cpp

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ TEST_F(NetIbMPITest, FaultInjCastSingleQpErrorIsFatal) {
462462
// Concentrate all tokens on QP 0: tokens[0]=1, tokens[1..n-1]=0.
463463
std::vector<int> tokens(actualNqps, 0);
464464
tokens[0] = 1;
465-
ncclIbCastSetTokens(sendComm, tokens.data(), actualNqps);
465+
EXPECT_EQ(ncclIbCastSetTokens(sendComm, tokens.data(), actualNqps), ncclSuccess);
466466
r1.setErrRet = static_cast<int>(ncclIbCastFaultSetQpError(sendComm, targetQp, /*inject=*/true));
467467
}
468468
MPI_Barrier(MPI_COMM_WORLD);
@@ -602,22 +602,27 @@ TEST_F(NetIbMPITest, FaultInjCastQpErrorClearRecovers) {
602602

603603
if (rank == 1) {
604604
for (int q = 0; q < actualNqps; ++q)
605-
ncclIbCastFaultSetQpError(sendComm1, q, /*inject=*/true);
605+
ASSERT_EQ(ncclIbCastFaultSetQpError(sendComm1, q, /*inject=*/true), ncclSuccess);
606606
}
607607
MPI_Barrier(MPI_COMM_WORLD);
608608

609609
// Trigger the fault (rank 0 posts recv, rank 1 posts send that will fail).
610+
// rank 1 forwards its fault result to rank 0 so we can assert the fault
611+
// was actually observed before testing recovery in phase 2.
612+
static constexpr int kPhase1MpiTag = 9882;
613+
FaultInjectResult p1 = {};
614+
void* recvReq1 = nullptr;
615+
bool recvDone1 = false;
610616
if (rank == 0) {
611617
void* bufs[1] = {buf1};
612618
size_t sizes[1] = {kMsgSize};
613619
int tags[1] = {501};
614620
void* handles[1] = {mhandle1};
615-
void* recvReq = nullptr;
616-
ASSERT_EQ(PostRecv(recvComm1, 1, bufs, sizes, tags, handles, &recvReq), ncclSuccess);
621+
ASSERT_EQ(PostRecv(recvComm1, 1, bufs, sizes, tags, handles, &recvReq1), ncclSuccess);
617622
for (int poll = 0; poll < 100; poll++) {
618623
int done = 0, sz = 0;
619-
if (TestRequest(recvReq, &done, &sz) != ncclSuccess) break;
620-
if (done) break;
624+
if (TestRequest(recvReq1, &done, &sz) != ncclSuccess) break;
625+
if (done) { recvDone1 = true; break; }
621626
usleep(kPollIntervalUs);
622627
}
623628
} else {
@@ -628,21 +633,47 @@ TEST_F(NetIbMPITest, FaultInjCastQpErrorClearRecovers) {
628633
if (sendRet != ncclSuccess || sendReq != nullptr) break;
629634
usleep(kPollIntervalUs);
630635
}
636+
int fatalCount = 0;
631637
if (sendRet == ncclSuccess && sendReq != nullptr) {
632638
for (int poll = 0; poll < 200; poll++) {
633639
int done = 0, sz = 0;
634-
int fc = 0;
635640
TestRequest(sendReq, &done, &sz);
636-
ncclIbCastFaultGetFatalCount(sendComm1, &fc);
637-
if (done || fc > 0) break;
641+
ncclIbCastFaultGetFatalCount(sendComm1, &fatalCount);
642+
if (done || fatalCount > 0) break;
638643
usleep(kPollIntervalUs);
639644
}
645+
} else {
646+
ncclIbCastFaultGetFatalCount(sendComm1, &fatalCount);
640647
}
641-
ncclIbCastFaultClear(sendComm1);
648+
p1.sendRet = static_cast<int>(sendRet);
649+
p1.fatalCount = fatalCount;
650+
EXPECT_EQ(ncclIbCastFaultClear(sendComm1), ncclSuccess);
651+
MPI_Send(&p1, sizeof(p1), MPI_BYTE, 0, kPhase1MpiTag, MPI_COMM_WORLD);
652+
}
653+
654+
if (rank == 0) {
655+
MPI_Recv(&p1, sizeof(p1), MPI_BYTE, 1, kPhase1MpiTag, MPI_COMM_WORLD,
656+
MPI_STATUS_IGNORE);
657+
bool isendFailed = (p1.sendRet != static_cast<int>(ncclSuccess));
658+
EXPECT_TRUE(isendFailed || p1.fatalCount > 0)
659+
<< "Phase 1: fault injection did not trigger — isend returned "
660+
<< p1.sendRet << ", fatalCount=" << p1.fatalCount
661+
<< "; recovery test is meaningless without a confirmed fault";
642662
}
643663

644664
MPI_Barrier(MPI_COMM_WORLD);
645665

666+
// Drain any outstanding recv request before closing the comm.
667+
// The recv may not complete if the sender faulted — that is expected here.
668+
if (rank == 0 && !recvDone1 && recvReq1 != nullptr) {
669+
for (int poll = 0; poll < 500; poll++) {
670+
int done = 0, sz = 0;
671+
if (TestRequest(recvReq1, &done, &sz) != ncclSuccess) break;
672+
if (done) break;
673+
usleep(kPollIntervalUs);
674+
}
675+
}
676+
646677
ASSERT_EQ(DeregisterMemory(comm1, mhandle1), ncclSuccess);
647678
if (rank == 0) {
648679
ASSERT_EQ(CloseRecvComm(recvComm1), ncclSuccess);

0 commit comments

Comments
 (0)