Skip to content

Commit a140494

Browse files
committed
fix: resolve MACA build and runtime issues to enable GPT-2 training
CMakeLists.txt:
- Pre-set HAVE_MODE_T/HAVE_SSIZE_T and their sentinel variables (HAVE_HAVE_MODE_T/HAVE_HAVE_SSIZE_T) before add_subdirectory(glog), since mxcc cmake feature-detection probes cannot find standard POSIX headers; without the sentinels check_type_size re-runs and overwrites the pre-set values, causing glog to emit conflicting fallback typedefs
- Add BUILD_TESTING=OFF to skip glog unit tests (-fPIE unsupported by mxcc)
- Add BUILD_SHARED_LIBS=OFF to build glog as a static library; mxcc defaults to hidden symbol visibility, making libglog.so export nothing

datatype.h:
- Add is_bfloat16<T> and is_fp16<T> type traits with USE_CUDA/USE_MACA specializations, needed by common_cpu.h Cast and init.cc ARANGE_CASE

common/cpu/common_cpu.h:
- Route fp16/bf16 destinations through float in Cast<T>(), avoiding ambiguous integer→__half/__maca_bfloat16 conversion on MACA

kernels/maca/{stack,concat,slice,transform,elementwise,split,gather}.maca:
- Add reinterpret_cast<void **> to all mcMallocAsync(&ptr, ...) calls; MACA's mcMallocAsync requires void** but typed pointers were passed
- Fix mcDevAttrMultiProcessorCount → mcDeviceAttributeMultiProcessorCount in elementwise.maca (correct MACA enum name)

optimizer.cc:
- Change Fill<T>(0) → Fill<T>(0.f) for Adam m/v initialization; __half(0) is ambiguous on MACA (only float/double ctors available)

nn/init.cc:
- Replace std::iota + static_cast<TYPE>(start) in ARANGE_CASE with an explicit loop via static_cast<float> to avoid ambiguous integer→fp16/bf16 conversion for kBFLOAT16/kFLOAT16 cases

example/gpt2/main.cc:
- Add kDeviceMACA constant, update --device validator to accept "maca", and add Device::DeviceType::kMACA branch in device selection
1 parent eacbaba commit a140494

File tree

13 files changed

+92
-23
lines changed

13 files changed

+92
-23
lines changed

CMakeLists.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,46 @@ include_directories(${gflags_SOURCE_DIR}/include)
4444
# glog
4545
set(WITH_GFLAGS OFF CACHE BOOL "Disable glog finding system gflags" FORCE)
4646
set(WITH_GTEST OFF CACHE BOOL "Disable glog finding system gtest" FORCE)
47+
set(BUILD_TESTING OFF CACHE BOOL "Disable glog unit tests" FORCE)
48+
# Build glog as a static lib so its symbols are always visible at link time.
49+
# Under mxcc the default symbol visibility is hidden, which causes the shared
50+
# libglog.so to export no symbols and produces "undefined reference" errors.
51+
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build glog as static library" FORCE)
52+
53+
# Under MACA/mxcc, cmake's feature-detection test compilations do not find
54+
# standard POSIX system headers (mxcc has a non-standard sysroot probe path).
55+
# Pre-set glog's HAVE_* cache variables so that glog skips its fallback type /
56+
# symbol definitions, which would otherwise conflict with the real system
57+
# headers during the actual build.
58+
if(USE_MACA)
59+
set(HAVE_SYS_TYPES_H 1 CACHE INTERNAL "")
60+
set(HAVE_UNISTD_H 1 CACHE INTERNAL "")
61+
set(HAVE_DLFCN_H 1 CACHE INTERNAL "")
62+
set(HAVE_GLOB_H 1 CACHE INTERNAL "")
63+
set(HAVE_PWD_H 1 CACHE INTERNAL "")
64+
set(HAVE_SYS_TIME_H 1 CACHE INTERNAL "")
65+
set(HAVE_SYS_UTSNAME_H 1 CACHE INTERNAL "")
66+
set(HAVE_SYS_WAIT_H 1 CACHE INTERNAL "")
67+
set(HAVE_SYS_SYSCALL_H 1 CACHE INTERNAL "")
68+
set(HAVE_SYSLOG_H 1 CACHE INTERNAL "")
69+
set(HAVE_UCONTEXT_H 1 CACHE INTERNAL "")
70+
# check_type_size() uses two internal variables: the size value and a sentinel
71+
# "HAVE_HAVE_<VAR>" that marks the check as done. Pre-setting only the value
72+
# is insufficient — the sentinel must also be set so the check skips entirely.
73+
set(HAVE_MODE_T 4 CACHE INTERNAL "") # 4 bytes on Linux
74+
set(HAVE_HAVE_MODE_T TRUE CACHE INTERNAL "")
75+
set(HAVE_SSIZE_T 8 CACHE INTERNAL "") # 8 bytes on 64-bit Linux
76+
set(HAVE_HAVE_SSIZE_T TRUE CACHE INTERNAL "")
77+
set(HAVE_PREAD 1 CACHE INTERNAL "")
78+
set(HAVE_PWRITE 1 CACHE INTERNAL "")
79+
set(HAVE_POSIX_FADVISE 1 CACHE INTERNAL "")
80+
set(HAVE_SIGACTION 1 CACHE INTERNAL "")
81+
set(HAVE_SIGALTSTACK 1 CACHE INTERNAL "")
82+
set(HAVE_FCNTL 1 CACHE INTERNAL "")
83+
set(HAVE_DLADDR 1 CACHE INTERNAL "")
84+
set(HAVE___CXA_DEMANGLE 1 CACHE INTERNAL "")
85+
endif()
86+
4787
add_subdirectory(third_party/glog)
4888
include_directories(${glog_SOURCE_DIR}/src)
4989

example/gpt2/main.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ const std::unordered_set<std::string> kSupportedModels
9898
= {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", "d12", "d24", "d36", "d48"};
9999
constexpr char kDeviceCPU[] = "cpu";
100100
constexpr char kDeviceCUDA[] = "cuda";
101+
constexpr char kDeviceMACA[] = "maca";
101102
constexpr char kDtypeFP32[] = "float32";
102103
constexpr char kDtypeBF16[] = "bfloat16";
103104

@@ -112,8 +113,9 @@ const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
112113
} // namespace
113114

114115
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
115-
DEFINE_validator(device,
116-
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
116+
DEFINE_validator(device, [](const char *, const std::string &value) {
117+
return value == kDeviceCPU || value == kDeviceCUDA || value == kDeviceMACA;
118+
});
117119

118120
void Train(const nn::parallel::Rank &rank) {
119121
using namespace nn::parallel;
@@ -169,7 +171,13 @@ void Train(const nn::parallel::Rank &rank) {
169171
nn::parallel::pp_rank = pp_rank;
170172
}
171173
} else {
172-
device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0);
174+
if (FLAGS_device == kDeviceCPU) {
175+
device = Device();
176+
} else if (FLAGS_device == kDeviceMACA) {
177+
device = Device(Device::DeviceType::kMACA, 0);
178+
} else {
179+
device = Device(Device::DeviceType::kCUDA, 0);
180+
}
173181
}
174182

175183
// calculate gradient accumulation from the desired total batch size and the current run configuration

infini_train/include/common/cpu/common_cpu.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include <type_traits>
44
#include <utility>
55

6+
#include "infini_train/include/datatype.h"
7+
68
namespace infini_train::common::cpu {
79
/**
810
* Converts a value between arbitrary types. This offers perfect
@@ -16,7 +18,12 @@ namespace infini_train::common::cpu {
1618
template <typename DST, typename SRC> DST Cast(SRC &&x) {
1719
static_assert(!std::is_reference_v<DST>, "Cast cannot return reference types");
1820

19-
// TODO(lzm): add cpu-version fp16 and bf16
20-
return (DST)(std::forward<SRC>(x));
21+
using Dst = std::remove_cv_t<std::remove_reference_t<DST>>;
22+
if constexpr (is_bfloat16<Dst>::value || is_fp16<Dst>::value) {
23+
// TODO(lzm): add cpu-version fp16 and bf16
24+
return Dst(static_cast<float>(std::forward<SRC>(x)));
25+
} else {
26+
return static_cast<DST>(std::forward<SRC>(x));
27+
}
2128
}
2229
} // namespace infini_train::common::cpu

infini_train/include/datatype.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,20 @@ template <> struct TypeMap<DataType::kFLOAT16> {
103103
#endif
104104
#undef DEFINE_DATA_TYPE_MAPPING
105105

106+
template <typename T> struct is_bfloat16 : std::false_type {};
107+
#if defined(USE_CUDA)
108+
template <> struct is_bfloat16<nv_bfloat16> : std::true_type {};
109+
#elif defined(USE_MACA)
110+
template <> struct is_bfloat16<__maca_bfloat16> : std::true_type {};
111+
#endif
112+
113+
template <typename T> struct is_fp16 : std::false_type {};
114+
#if defined(USE_CUDA)
115+
template <> struct is_fp16<half> : std::true_type {};
116+
#elif defined(USE_MACA)
117+
template <> struct is_fp16<__half> : std::true_type {};
118+
#endif
119+
106120
// Extends std::is_floating_point to support CUDA floating-point types.
107121
template <typename T> struct is_floating_point_ext : std::is_floating_point<T> {};
108122

infini_train/src/kernels/maca/concat.maca

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,11 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
112112
const T **device_input_ptrs = nullptr;
113113
int64_t *device_offsets = nullptr;
114114

115-
MACA_CHECK(mcMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream));
115+
MACA_CHECK(mcMallocAsync(reinterpret_cast<void **>(&device_input_ptrs), sizeof(T *) * num_inputs, stream));
116116
MACA_CHECK(mcMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs,
117117
mcMemcpyHostToDevice, stream));
118118

119-
MACA_CHECK(mcMallocAsync(&device_offsets, sizeof(int64_t) * (num_inputs + 1), stream));
119+
MACA_CHECK(mcMallocAsync(reinterpret_cast<void **>(&device_offsets), sizeof(int64_t) * (num_inputs + 1), stream));
120120
MACA_CHECK(mcMemcpyAsync(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
121121
mcMemcpyHostToDevice, stream));
122122

@@ -218,11 +218,11 @@ std::vector<std::shared_ptr<Tensor>> ConcatBackward(const std::shared_ptr<Tensor
218218
T **device_ptrs = nullptr;
219219
int64_t *device_offsets = nullptr;
220220

221-
MACA_CHECK(mcMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream));
221+
MACA_CHECK(mcMallocAsync(reinterpret_cast<void **>(&device_ptrs), sizeof(T *) * num_inputs, stream));
222222
MACA_CHECK(mcMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, mcMemcpyHostToDevice,
223223
stream));
224224

225-
MACA_CHECK(mcMallocAsync(&device_offsets, sizeof(int64_t) * (num_inputs + 1), stream));
225+
MACA_CHECK(mcMallocAsync(reinterpret_cast<void **>(&device_offsets), sizeof(int64_t) * (num_inputs + 1), stream));
226226
MACA_CHECK(mcMemcpyAsync(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
227227
mcMemcpyHostToDevice, stream));
228228

infini_train/src/kernels/maca/elementwise.maca

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ void BinaryBackwardBhistLaunch(FuncA fn_a, FuncB fn_b, T *outA, T *outB, const T
427427

428428
// Workspace layout: [grid, K] floats.
429429
float *work = nullptr;
430-
MACA_CHECK(mcMallocAsync(&work, static_cast<size_t>(grid) * static_cast<size_t>(K) * sizeof(float), stream));
430+
MACA_CHECK(mcMallocAsync(reinterpret_cast<void **>(&work), static_cast<size_t>(grid) * static_cast<size_t>(K) * sizeof(float), stream));
431431

432432
// Pass 1: per-block histogram accumulation.
433433
const size_t smem_bytes = static_cast<size_t>(K + (K >> 5)) * sizeof(float);
@@ -439,7 +439,7 @@ void BinaryBackwardBhistLaunch(FuncA fn_a, FuncB fn_b, T *outA, T *outB, const T
439439
int dev = 0;
440440
int sm_count = 0;
441441
MACA_CHECK(mcGetDevice(&dev));
442-
MACA_CHECK(mcDeviceGetAttribute(&sm_count, mcDevAttrMultiProcessorCount, dev));
442+
MACA_CHECK(mcDeviceGetAttribute(&sm_count, mcDeviceAttributeMultiProcessorCount, dev));
443443

444444
const int RED_THREADS = 256;
445445
const int oneD_blocks = (K + RED_THREADS - 1) / RED_THREADS;
@@ -457,7 +457,7 @@ void BinaryBackwardBhistLaunch(FuncA fn_a, FuncB fn_b, T *outA, T *outB, const T
457457
// 2D tiling path: slice the workspace and accumulate using float atomics.
458458
constexpr int kTileHeight = 128; // rows per CTA; tune between 128 and 256 if needed
459459
float *outB_accum = nullptr;
460-
MACA_CHECK(mcMallocAsync(&outB_accum, static_cast<size_t>(K) * sizeof(float), stream));
460+
MACA_CHECK(mcMallocAsync(reinterpret_cast<void **>(&outB_accum), static_cast<size_t>(K) * sizeof(float), stream));
461461
MACA_CHECK(mcMemsetAsync(outB_accum, 0, static_cast<size_t>(K) * sizeof(float), stream));
462462

463463
const dim3 rblock(RED_THREADS, 1, 1);

infini_train/src/kernels/maca/gather.maca

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ std::shared_ptr<Tensor> IndexGatherForward(const std::shared_ptr<Tensor> &input,
8484
const int64_t gather_dim_size = in_dims[dim];
8585

8686
int64_t *dev_buf = nullptr;
87-
MACA_CHECK(mcMallocAsync(&dev_buf, (3 * num_dims) * sizeof(int64_t), stream));
87+
MACA_CHECK(mcMallocAsync(reinterpret_cast<void **>(&dev_buf), (3 * num_dims) * sizeof(int64_t), stream));
8888
int64_t *out_dims_dev = dev_buf + 0 * num_dims;
8989
int64_t *in_strides_dev = dev_buf + 1 * num_dims;
9090
int64_t *out_strides_dev = dev_buf + 2 * num_dims;
@@ -193,7 +193,7 @@ std::shared_ptr<Tensor> IndexGatherBackward(const std::shared_ptr<Tensor> &grad_
193193
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
194194
->maca_stream();
195195

196-
MACA_CHECK(mcMallocAsync(&dev_buf, total_i64 * sizeof(int64_t), stream));
196+
MACA_CHECK(mcMallocAsync(reinterpret_cast<void **>(&dev_buf), total_i64 * sizeof(int64_t), stream));
197197
int64_t *out_dims_dev = dev_buf;
198198
int64_t *in_strides_dev = out_dims_dev + n_out;
199199
int64_t *out_strides_dev = in_strides_dev + n_in_strides;

infini_train/src/kernels/maca/slice.maca

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ std::shared_ptr<Tensor> SliceForward(const std::shared_ptr<Tensor> &input, const
7373
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
7474
->maca_stream();
7575

76-
mcMallocAsync(&new_dims_dev,
76+
mcMallocAsync(reinterpret_cast<void **>(&new_dims_dev),
7777
(ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t),
7878
stream);
7979
starts_dev = new_dims_dev + ends.size();
@@ -167,7 +167,7 @@ std::shared_ptr<Tensor> SliceBackward(const std::shared_ptr<Tensor> &grad_output
167167
const auto &stream = dynamic_cast<infini_train::core::maca::MacaStream *>(
168168
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
169169
->maca_stream();
170-
mcMallocAsync(&new_dims_dev,
170+
mcMallocAsync(reinterpret_cast<void **>(&new_dims_dev),
171171
(ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t),
172172
stream);
173173
starts_dev = new_dims_dev + ends.size();

infini_train/src/kernels/maca/split.maca

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ std::shared_ptr<Tensor> LaunchSplitBackward(const std::vector<int64_t> &input_di
133133
void *device_ptr;
134134
const T **device_grad_output_ptrs;
135135
int64_t *device_H_outs;
136-
mcMallocAsync(&device_ptr, (sizeof(T *) + sizeof(int64_t)) * num_splits, stream);
136+
mcMallocAsync(reinterpret_cast<void **>(&device_ptr), (sizeof(T *) + sizeof(int64_t)) * num_splits, stream);
137137
device_grad_output_ptrs = (const T **)(device_ptr);
138138
device_H_outs = reinterpret_cast<int64_t *>(device_grad_output_ptrs + num_splits);
139139

infini_train/src/kernels/maca/stack.maca

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ std::shared_ptr<Tensor> StackForward(const std::vector<std::shared_ptr<Tensor>>
6767
for (const auto &t : inputs) { host_input_ptrs.push_back(static_cast<const T *>(t->DataPtr())); }
6868

6969
const T **device_input_ptrs;
70-
mcMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream);
70+
mcMallocAsync(reinterpret_cast<void **>(&device_input_ptrs), sizeof(T *) * num_inputs, stream);
7171
mcMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, mcMemcpyHostToDevice,
7272
stream);
7373

@@ -136,7 +136,7 @@ std::vector<std::shared_ptr<Tensor>> StackBackward(const std::vector<int64_t> &i
136136
for (auto &t : grads) { host_ptrs.push_back(static_cast<T *>(t->DataPtr())); }
137137

138138
T **device_ptrs;
139-
mcMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream);
139+
mcMallocAsync(reinterpret_cast<void **>(&device_ptrs), sizeof(T *) * num_inputs, stream);
140140
mcMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, mcMemcpyHostToDevice, stream);
141141

142142
StackBackwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(

0 commit comments

Comments
 (0)