Skip to content

Commit ffdce20

Browse files
committed
refactor: simplify CMakeLists changes and temporarily bypass the hardcoded distributed logic in main.cc
1 parent a140494 commit ffdce20

File tree

3 files changed

+22
-50
lines changed

3 files changed

+22
-50
lines changed

CMakeLists.txt

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,19 @@ option(USE_OMP "Use OpenMP as backend for Eigen" ON)
99
option(USE_NCCL "Build project for distributed running on CUDA using NCCL" ON)
1010
option(USE_MCCL "Build project for distributed running on MACA using MCCL" ON)
1111

12-
# ------------------------------------------------------------------------------
13-
# MACA toolchain override (must happen before project())
14-
# ------------------------------------------------------------------------------
15-
# When targeting MetaX MACA, the C/C++ compiler must be mxcc so that .maca
16-
# sources and device code can be compiled by the MACA toolchain.
12+
project(infini_train VERSION 0.5.0 LANGUAGES CXX)
13+
14+
# Switch to mxcc after project() so that third-party libs (glog, gflags) are
15+
# configured with the host compiler and their feature-detection checks pass.
1716
if(USE_MACA)
1817
set(MACA_PATH $ENV{MACA_PATH})
1918
if(NOT MACA_PATH)
20-
message(FATAL_ERROR "USE_MACA=ON but environment variable MACA_PATH is not set. "
21-
"Please export MACA_PATH (e.g. /opt/maca) before configuring.")
19+
message(FATAL_ERROR "USE_MACA=ON but environment variable MACA_PATH is not set.")
2220
endif()
2321
set(CMAKE_C_COMPILER "${MACA_PATH}/mxgpu_llvm/bin/mxcc")
2422
set(CMAKE_CXX_COMPILER "${MACA_PATH}/mxgpu_llvm/bin/mxcc")
2523
endif()
2624

27-
project(infini_train VERSION 0.5.0 LANGUAGES CXX)
28-
2925
set(CMAKE_CXX_STANDARD 20)
3026
set(CMAKE_CXX_STANDARD_REQUIRED ON)
3127
set(CMAKE_CXX_EXTENSIONS OFF)
@@ -45,45 +41,8 @@ include_directories(${gflags_SOURCE_DIR}/include)
4541
set(WITH_GFLAGS OFF CACHE BOOL "Disable glog finding system gflags" FORCE)
4642
set(WITH_GTEST OFF CACHE BOOL "Disable glog finding system gtest" FORCE)
4743
set(BUILD_TESTING OFF CACHE BOOL "Disable glog unit tests" FORCE)
48-
# Build glog as a static lib so its symbols are always visible at link time.
49-
# Under mxcc the default symbol visibility is hidden, which causes the shared
50-
# libglog.so to export no symbols and produces "undefined reference" errors.
5144
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build glog as static library" FORCE)
5245

53-
# Under MACA/mxcc, cmake's feature-detection test compilations do not find
54-
# standard POSIX system headers (mxcc has a non-standard sysroot probe path).
55-
# Pre-set glog's HAVE_* cache variables so that glog skips its fallback type /
56-
# symbol definitions, which would otherwise conflict with the real system
57-
# headers during the actual build.
58-
if(USE_MACA)
59-
set(HAVE_SYS_TYPES_H 1 CACHE INTERNAL "")
60-
set(HAVE_UNISTD_H 1 CACHE INTERNAL "")
61-
set(HAVE_DLFCN_H 1 CACHE INTERNAL "")
62-
set(HAVE_GLOB_H 1 CACHE INTERNAL "")
63-
set(HAVE_PWD_H 1 CACHE INTERNAL "")
64-
set(HAVE_SYS_TIME_H 1 CACHE INTERNAL "")
65-
set(HAVE_SYS_UTSNAME_H 1 CACHE INTERNAL "")
66-
set(HAVE_SYS_WAIT_H 1 CACHE INTERNAL "")
67-
set(HAVE_SYS_SYSCALL_H 1 CACHE INTERNAL "")
68-
set(HAVE_SYSLOG_H 1 CACHE INTERNAL "")
69-
set(HAVE_UCONTEXT_H 1 CACHE INTERNAL "")
70-
# check_type_size() uses two internal variables: the size value and a sentinel
71-
# "HAVE_HAVE_<VAR>" that marks the check as done. Pre-setting only the value
72-
# is insufficient — the sentinel must also be set so the check skips entirely.
73-
set(HAVE_MODE_T 4 CACHE INTERNAL "") # 4 bytes on Linux
74-
set(HAVE_HAVE_MODE_T TRUE CACHE INTERNAL "")
75-
set(HAVE_SSIZE_T 8 CACHE INTERNAL "") # 8 bytes on 64-bit Linux
76-
set(HAVE_HAVE_SSIZE_T TRUE CACHE INTERNAL "")
77-
set(HAVE_PREAD 1 CACHE INTERNAL "")
78-
set(HAVE_PWRITE 1 CACHE INTERNAL "")
79-
set(HAVE_POSIX_FADVISE 1 CACHE INTERNAL "")
80-
set(HAVE_SIGACTION 1 CACHE INTERNAL "")
81-
set(HAVE_SIGALTSTACK 1 CACHE INTERNAL "")
82-
set(HAVE_FCNTL 1 CACHE INTERNAL "")
83-
set(HAVE_DLADDR 1 CACHE INTERNAL "")
84-
set(HAVE___CXA_DEMANGLE 1 CACHE INTERNAL "")
85-
endif()
86-
8746
add_subdirectory(third_party/glog)
8847
include_directories(${glog_SOURCE_DIR}/src)
8948

example/gpt2/main.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ void Train(const nn::parallel::Rank &rank) {
146146
const ProcessGroup *pp_pg = nullptr;
147147

148148
if (rank.IsParallel()) {
149-
device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
149+
auto parallel_device_type =
150+
(FLAGS_device == kDeviceMACA) ? Device::DeviceType::kMACA : Device::DeviceType::kCUDA;
151+
device = Device(parallel_device_type, rank.thread_rank());
150152
auto *pg_factory = ProcessGroupFactory::Instance(device.type());
151153

152154
if (ddp_world_size > 1) {

example/llama3/main.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,16 @@ namespace {
9393
const std::unordered_set<std::string> kSupportedModels = {"llama3"};
9494
constexpr char kDeviceCPU[] = "cpu";
9595
constexpr char kDeviceCUDA[] = "cuda";
96+
constexpr char kDeviceMACA[] = "maca";
9697
constexpr char kDtypeFP32[] = "float32";
9798
constexpr char kDtypeBF16[] = "bfloat16";
9899
} // namespace
99100

100101
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
101102
DEFINE_validator(device,
102-
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
103+
[](const char *, const std::string &value) {
104+
return value == kDeviceCPU || value == kDeviceCUDA || value == kDeviceMACA;
105+
});
103106

104107
void Train(const nn::parallel::Rank &rank) {
105108
using namespace nn::parallel;
@@ -129,7 +132,9 @@ void Train(const nn::parallel::Rank &rank) {
129132
const ProcessGroup *pp_pg = nullptr;
130133

131134
if (rank.IsParallel()) {
132-
device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
135+
auto parallel_device_type =
136+
(FLAGS_device == kDeviceMACA) ? Device::DeviceType::kMACA : Device::DeviceType::kCUDA;
137+
device = Device(parallel_device_type, rank.thread_rank());
133138
auto *pg_factory = ProcessGroupFactory::Instance(device.type());
134139

135140
if (ddp_world_size > 1) {
@@ -154,7 +159,13 @@ void Train(const nn::parallel::Rank &rank) {
154159
nn::parallel::pp_rank = pp_rank;
155160
}
156161
} else {
157-
device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0);
162+
if (FLAGS_device == kDeviceCPU) {
163+
device = Device();
164+
} else if (FLAGS_device == kDeviceMACA) {
165+
device = Device(Device::DeviceType::kMACA, 0);
166+
} else {
167+
device = Device(Device::DeviceType::kCUDA, 0);
168+
}
158169
}
159170

160171
// calculate gradient accumulation from the desired total batch size and the current run configuration

0 commit comments

Comments (0)