Skip to content

Commit 1f10a97

Browse files
committed
fix(maca): harden multi-thread DDP+TP init on gpt2
- Move the MACA/MCCL P2P_DISABLE setenv calls into the MacaGuardImpl ctor and parse --tensor_parallel from /proc/self/cmdline, so both variables are set before mcInit(0) (setenv from main() was too late, because the ctor runs during static init). - Also set MCCL_P2P_DISABLE=1 when TP>1: MACA_P2P_DISABLE alone still lets MCCL establish its own P2P buffers, which deadlocks multi-PG init on TP+SP / TP+SP+PP+VPP. - gpt2 main: defer ProcessGroup creation until after model->To(device), serialize the upload under a mutex + barrier across DP threads. MCCL init otherwise leaves stale read-only P2P mappings in the VA ranges mcMalloc later returns, racing with concurrent model uploads. - Drop the now-redundant setenv blocks from gpt2/llama3 main().
1 parent 905bee0 commit 1f10a97

3 files changed

Lines changed: 119 additions & 28 deletions

File tree

example/gpt2/main.cc

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
#include <algorithm>
2+
#include <barrier>
13
#include <chrono>
24
#include <cstdlib>
35
#include <format>
6+
#include <iterator>
47
#include <memory>
8+
#include <mutex>
59
#include <optional>
10+
#include <thread>
611
#include <unordered_map>
712
#include <unordered_set>
813

@@ -148,31 +153,33 @@ void Train(const nn::parallel::Rank &rank) {
148153
const ProcessGroup *tp_pg = nullptr;
149154
const ProcessGroup *pp_pg = nullptr;
150155

156+
auto rank_in_group = [&](const std::vector<int> &group_ranks) {
157+
auto it = std::find(group_ranks.begin(), group_ranks.end(), rank.GlobalRank());
158+
CHECK(it != group_ranks.end());
159+
return static_cast<int>(std::distance(group_ranks.begin(), it));
160+
};
161+
151162
if (rank.IsParallel()) {
152163
auto parallel_device_type
153164
= (FLAGS_device == kDeviceMACA) ? Device::DeviceType::kMACA : Device::DeviceType::kCUDA;
154165
device = Device(parallel_device_type, rank.thread_rank());
155-
auto *pg_factory = ProcessGroupFactory::Instance(device.type());
156166

167+
// NOTE(dcj): On MACA, defer ProcessGroup creation until AFTER the model
168+
// has been uploaded to the device. MCCL init registers internal P2P
169+
// buffers that leave stale read-only mappings in the address ranges
170+
// mcMalloc later hands out; allocating the model first keeps it in a
171+
// P2P-clean region of the VA space and avoids the init-time race on
172+
// multi-thread DDP+TP. Mirrors the llama3 fix combo.
157173
if (ddp_world_size > 1) {
158-
ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
159-
GetDataParallelGroupRanks(rank.GlobalRank()));
160-
ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
174+
ddp_rank = rank_in_group(GetDataParallelGroupRanks(rank.GlobalRank()));
161175
}
162-
163176
if (tp_world_size > 1) {
164-
tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
165-
GetTensorParallelGroupRanks(rank.GlobalRank()));
166-
tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
177+
tp_rank = rank_in_group(GetTensorParallelGroupRanks(rank.GlobalRank()));
167178
// NOTE(zbl): Reserved for VocabParallelEmbedding
168179
nn::parallel::tp_rank = tp_rank;
169180
}
170-
171181
if (pp_world_size > 1) {
172-
pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
173-
GetPipelineParallelGroupRanks(rank.GlobalRank()));
174-
pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
175-
182+
pp_rank = rank_in_group(GetPipelineParallelGroupRanks(rank.GlobalRank()));
176183
nn::parallel::pp_rank = pp_rank;
177184
}
178185
} else {
@@ -206,7 +213,46 @@ void Train(const nn::parallel::Rank &rank) {
206213
model = std::make_shared<nn::TransformerModel>(model_config);
207214
}
208215

209-
model->To(device);
216+
// On MACA, parallel mcMalloc/mcMemcpy across threads still races even with
217+
// an mcMalloc mutex, because the runtime auto-maps allocations P2P-readonly
218+
// between sibling devices. Serialize the entire model upload so each
219+
// thread's allocations land before any peer thread starts touching the
220+
// address space.
221+
if (FLAGS_device == kDeviceMACA && rank.IsParallel() && FLAGS_nthread_per_process > 1) {
222+
static std::mutex model_to_mutex;
223+
std::lock_guard<std::mutex> lock(model_to_mutex);
224+
model->To(device);
225+
auto upload_impl = core::GetDeviceGuardImpl(device.type());
226+
upload_impl->SynchronizeDevice(device);
227+
} else {
228+
model->To(device);
229+
}
230+
231+
// Synchronize model upload across all DP threads before any MCCL init runs.
232+
// The barrier ensures no thread enters mcclCommInitAll while peer threads
233+
// are still mid-mcMemcpyAsync.
234+
if (FLAGS_device == kDeviceMACA && rank.IsParallel() && FLAGS_nthread_per_process > 1) {
235+
auto pre_pg_impl = core::GetDeviceGuardImpl(device.type());
236+
pre_pg_impl->SynchronizeDevice(device);
237+
static std::barrier pre_pg_barrier(FLAGS_nthread_per_process);
238+
pre_pg_barrier.arrive_and_wait();
239+
}
240+
241+
if (rank.IsParallel()) {
242+
auto *pg_factory = ProcessGroupFactory::Instance(device.type());
243+
if (ddp_world_size > 1) {
244+
ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
245+
GetDataParallelGroupRanks(rank.GlobalRank()));
246+
}
247+
if (tp_world_size > 1) {
248+
tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
249+
GetTensorParallelGroupRanks(rank.GlobalRank()));
250+
}
251+
if (pp_world_size > 1) {
252+
pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
253+
GetPipelineParallelGroupRanks(rank.GlobalRank()));
254+
}
255+
}
210256

211257
utils::PrecisionChecker::BuildNameMap(model.get());
212258

@@ -470,13 +516,6 @@ int main(int argc, char *argv[]) {
470516
gflags::ParseCommandLineFlags(&argc, &argv, true);
471517
google::InitGoogleLogging(argv[0]);
472518

473-
// On MACA, when TP > 1 disable P2P to prevent MCCL communication-ordering
474-
// deadlocks and P2P teardown crashes. Must be set before any mcclCommInitAll
475-
// call (i.e. before threads that create ProcessGroups are spawned).
476-
if (FLAGS_device == kDeviceMACA && FLAGS_tensor_parallel > 1) {
477-
setenv("MACA_P2P_DISABLE", "1", 1);
478-
}
479-
480519
auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
481520
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
482521
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

example/llama3/main.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -492,13 +492,6 @@ int main(int argc, char *argv[]) {
492492
gflags::ParseCommandLineFlags(&argc, &argv, true);
493493
google::InitGoogleLogging(argv[0]);
494494

495-
// On MACA, when TP > 1 disable P2P to prevent MCCL communication-ordering
496-
// deadlocks and P2P teardown crashes. Must be set before any mcclCommInitAll
497-
// call (i.e. before threads that create ProcessGroups are spawned).
498-
if (FLAGS_device == kDeviceMACA && FLAGS_tensor_parallel > 1) {
499-
setenv("MACA_P2P_DISABLE", "1", 1);
500-
}
501-
502495
auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
503496
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
504497
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

infini_train/src/core/runtime/maca/maca_guard_impl.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22

33
#include <array>
44
#include <cstdlib>
5+
#include <fstream>
56
#include <memory>
67
#include <mutex>
8+
#include <sstream>
9+
#include <string>
10+
#include <vector>
711

812
#include "infini_train/include/common/maca/common_maca.h"
913
#include "infini_train/include/core/runtime/runtime_common.h"
@@ -15,6 +19,47 @@ namespace infini_train::core::maca {
1519
namespace {
1620
constexpr int kMaxGpus = 8;
1721

22+
// Scans an argv-style argument list for the tensor_parallel flag and returns
// its integer value.
//
// Accepted spellings mirror what gflags accepts on the command line:
//   --tensor_parallel=N   -tensor_parallel=N
//   --tensor_parallel N   -tensor_parallel N
//
// Returns 1 when the flag is absent, when it is the last argument and has no
// value, or when the value cannot be parsed as an integer. The first matching
// occurrence decides the result.
int ParseTensorParallelFromArgs(const std::vector<std::string> &args) {
    const std::string flag_name = "tensor_parallel";

    for (size_t i = 0; i < args.size(); ++i) {
        const std::string &arg = args[i];

        // gflags treats one and two leading dashes identically, so accept both.
        size_t name_pos = 0;
        if (arg.compare(0, 2, "--") == 0) {
            name_pos = 2;
        } else if (arg.compare(0, 1, "-") == 0) {
            name_pos = 1;
        } else {
            continue;
        }
        if (arg.compare(name_pos, flag_name.size(), flag_name) != 0) {
            continue;
        }

        const size_t after_name = name_pos + flag_name.size();
        std::string value;
        if (after_name == arg.size()) {
            // "--tensor_parallel N": the value is the following argument.
            if (i + 1 >= args.size()) {
                continue;
            }
            value = args[i + 1];
        } else if (arg[after_name] == '=') {
            // "--tensor_parallel=N": the value follows the '='.
            value = arg.substr(after_name + 1);
        } else {
            // A different, longer flag that merely shares the prefix.
            continue;
        }

        try {
            return std::stoi(value);
        } catch (...) { return 1; }
    }
    return 1;
}

// Read /proc/self/cmdline and return the tensor_parallel flag value, or 1 if
// the file cannot be opened or the flag is absent / unparseable. Must be
// callable from static init (before main runs), so we cannot use gflags here.
int ReadTensorParallelFromCmdline() {
    std::ifstream in("/proc/self/cmdline", std::ios::binary);
    if (!in) {
        return 1;
    }

    // Arguments in /proc/self/cmdline are separated by NUL bytes. Empty
    // tokens are dropped; a trailing token without a terminating NUL is kept.
    std::vector<std::string> args;
    std::string token;
    while (std::getline(in, token, '\0')) {
        if (!token.empty()) {
            args.push_back(std::move(token));
        }
    }

    return ParseTensorParallelFromArgs(args);
}
62+
1863
static std::array<std::unique_ptr<MacaStream>, kMaxGpus> maca_streams;
1964
static std::array<std::unique_ptr<MacaBlasHandle>, kMaxGpus> maca_blas_handles;
2065

@@ -84,6 +129,20 @@ MacaGuardImpl::MacaGuardImpl() {
84129
// themselves before launch.
85130
setenv("MACA_LAUNCH_BLOCKING", "1", 0);
86131

132+
// When TP > 1 on MACA, disable both the MACA runtime P2P mapping and the
133+
// MCCL-level P2P path to prevent multi-PG init deadlocks (threads
134+
// concurrently creating both DP and TP comms hang in mcclCommInitAll).
135+
// MACA_P2P_DISABLE alone is not sufficient for TP+SP / TP+SP+PP+VPP
136+
// configurations — MCCL still establishes its own P2P buffers during init,
137+
// so we must disable that too. Both must be set before mcInit(0); setenv
138+
// from main() is too late because this ctor runs at static init. Peek at
139+
// /proc/self/cmdline to keep single-card / DP-only / PP-only runs on the
140+
// P2P fast path.
141+
if (ReadTensorParallelFromCmdline() > 1) {
142+
setenv("MACA_P2P_DISABLE", "1", 0);
143+
setenv("MCCL_P2P_DISABLE", "1", 0);
144+
}
145+
87146
// The MACA runtime requires an explicit mcInit(0) before any other call.
88147
// CUDA has no equivalent; mirroring the DeviceManager ctor from 87390cd.
89148
MACA_CHECK(mcInit(0));

0 commit comments

Comments
 (0)