|
| 1 | +#include <algorithm> |
| 2 | +#include <barrier> |
1 | 3 | #include <chrono> |
2 | 4 | #include <cstdlib> |
3 | 5 | #include <format> |
| 6 | +#include <iterator> |
4 | 7 | #include <memory> |
| 8 | +#include <mutex> |
5 | 9 | #include <optional> |
| 10 | +#include <thread> |
6 | 11 | #include <unordered_map> |
7 | 12 | #include <unordered_set> |
8 | 13 |
|
@@ -148,31 +153,33 @@ void Train(const nn::parallel::Rank &rank) { |
148 | 153 | const ProcessGroup *tp_pg = nullptr; |
149 | 154 | const ProcessGroup *pp_pg = nullptr; |
150 | 155 |
|
| 156 | + auto rank_in_group = [&](const std::vector<int> &group_ranks) { |
| 157 | + auto it = std::find(group_ranks.begin(), group_ranks.end(), rank.GlobalRank()); |
| 158 | + CHECK(it != group_ranks.end()); |
| 159 | + return static_cast<int>(std::distance(group_ranks.begin(), it)); |
| 160 | + }; |
| 161 | + |
151 | 162 | if (rank.IsParallel()) { |
152 | 163 | auto parallel_device_type |
153 | 164 | = (FLAGS_device == kDeviceMACA) ? Device::DeviceType::kMACA : Device::DeviceType::kCUDA; |
154 | 165 | device = Device(parallel_device_type, rank.thread_rank()); |
155 | | - auto *pg_factory = ProcessGroupFactory::Instance(device.type()); |
156 | 166 |
|
| 167 | + // NOTE(dcj): Only the group-local ranks are computed here. ProcessGroup creation |
| 168 | + // is deferred (for every device type) until AFTER the model has been uploaded to |
| 169 | + // the device. The motivation is MACA: MCCL init registers internal P2P buffers |
| 170 | + // that leave stale read-only mappings in the address ranges mcMalloc later hands |
| 171 | + // out; allocating the model first keeps it in a P2P-clean region of the VA space |
| 172 | + // and avoids the init-time race on multi-thread DDP+TP. Mirrors the llama3 fix combo. |
157 | 173 | if (ddp_world_size > 1) { |
158 | | - ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), |
159 | | - GetDataParallelGroupRanks(rank.GlobalRank())); |
160 | | - ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank()); |
| 174 | + ddp_rank = rank_in_group(GetDataParallelGroupRanks(rank.GlobalRank())); |
161 | 175 | } |
162 | | - |
163 | 176 | if (tp_world_size > 1) { |
164 | | - tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), |
165 | | - GetTensorParallelGroupRanks(rank.GlobalRank())); |
166 | | - tp_rank = tp_pg->GetGroupRank(rank.GlobalRank()); |
| 177 | + tp_rank = rank_in_group(GetTensorParallelGroupRanks(rank.GlobalRank())); |
167 | 178 | // NOTE(zbl): Reserved for VocabParallelEmbedding |
168 | 179 | nn::parallel::tp_rank = tp_rank; |
169 | 180 | } |
170 | | - |
171 | 181 | if (pp_world_size > 1) { |
172 | | - pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), |
173 | | - GetPipelineParallelGroupRanks(rank.GlobalRank())); |
174 | | - pp_rank = pp_pg->GetGroupRank(rank.GlobalRank()); |
175 | | - |
| 182 | + pp_rank = rank_in_group(GetPipelineParallelGroupRanks(rank.GlobalRank())); |
176 | 183 | nn::parallel::pp_rank = pp_rank; |
177 | 184 | } |
178 | 185 | } else { |
@@ -206,7 +213,46 @@ void Train(const nn::parallel::Rank &rank) { |
206 | 213 | model = std::make_shared<nn::TransformerModel>(model_config); |
207 | 214 | } |
208 | 215 |
|
209 | | - model->To(device); |
| 216 | + // On MACA, parallel mcMalloc/mcMemcpy across threads still races even with |
| 217 | + // an mcMalloc mutex, because the runtime auto-maps allocations P2P-readonly |
| 218 | + // between sibling devices. Serialize the entire model upload so each |
| 219 | + // thread's allocations land before any peer thread starts touching the |
| 220 | + // address space. |
| 221 | + if (FLAGS_device == kDeviceMACA && rank.IsParallel() && FLAGS_nthread_per_process > 1) { |
| 222 | + static std::mutex model_to_mutex; |
| 223 | + std::lock_guard<std::mutex> lock(model_to_mutex); |
| 224 | + model->To(device); |
| 225 | + auto upload_impl = core::GetDeviceGuardImpl(device.type()); |
| 226 | + upload_impl->SynchronizeDevice(device); |
| 227 | + } else { |
| 228 | + model->To(device); |
| 229 | + } |
| 230 | + |
| 231 | + // Synchronize model upload across all threads of this process (the barrier is |
| 232 | + // sized by FLAGS_nthread_per_process) before any MCCL init runs. The barrier ensures |
| 233 | + // no thread enters mcclCommInitAll while peer threads are still mid-mcMemcpyAsync. |
| 234 | + if (FLAGS_device == kDeviceMACA && rank.IsParallel() && FLAGS_nthread_per_process > 1) { |
| 235 | + auto pre_pg_impl = core::GetDeviceGuardImpl(device.type()); |
| 236 | + pre_pg_impl->SynchronizeDevice(device); |
| 237 | + static std::barrier pre_pg_barrier(FLAGS_nthread_per_process); |
| 238 | + pre_pg_barrier.arrive_and_wait(); |
| 239 | + } |
| 240 | + |
| 241 | + if (rank.IsParallel()) { |
| 242 | + auto *pg_factory = ProcessGroupFactory::Instance(device.type()); |
| 243 | + if (ddp_world_size > 1) { |
| 244 | + ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), |
| 245 | + GetDataParallelGroupRanks(rank.GlobalRank())); |
| 246 | + } |
| 247 | + if (tp_world_size > 1) { |
| 248 | + tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), |
| 249 | + GetTensorParallelGroupRanks(rank.GlobalRank())); |
| 250 | + } |
| 251 | + if (pp_world_size > 1) { |
| 252 | + pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), |
| 253 | + GetPipelineParallelGroupRanks(rank.GlobalRank())); |
| 254 | + } |
| 255 | + } |
210 | 256 |
|
211 | 257 | utils::PrecisionChecker::BuildNameMap(model.get()); |
212 | 258 |
|
@@ -470,13 +516,6 @@ int main(int argc, char *argv[]) { |
470 | 516 | gflags::ParseCommandLineFlags(&argc, &argv, true); |
471 | 517 | google::InitGoogleLogging(argv[0]); |
472 | 518 |
|
473 | | - // On MACA, when TP > 1 disable P2P to prevent MCCL communication-ordering |
474 | | - // deadlocks and P2P teardown crashes. Must be set before any mcclCommInitAll |
475 | | - // call (i.e. before threads that create ProcessGroups are spawned). |
476 | | - if (FLAGS_device == kDeviceMACA && FLAGS_tensor_parallel > 1) { |
477 | | - setenv("MACA_P2P_DISABLE", "1", 1); |
478 | | - } |
479 | | - |
480 | 519 | auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check); |
481 | 520 | nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, |
482 | 521 | FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); |
|
0 commit comments