Skip to content

Commit 55bef7d

Browse files
committed
fix: cross-platform torch compilation and version compatibility
- Iluvatar: set `CMAKE_CUDA_ARCHITECTURES` to `OFF` instead of `ivcore20` (CMake 4.3 rejects non-integer architecture names; the architecture is already passed via `CMAKE_CUDA_FLAGS`).
- MetaX/Moore: split torch operator headers into declaration-only `.h` files and `.cc` implementation files with explicit instantiations. Compile the `.cc` files with the system `g++` instead of the vendor compiler (`mxcc`/`mcc`), which cannot parse vendor-forked torch headers in C++ extension mode.
- Cambricon: guard `UInt16`/`UInt32`/`UInt64` scalar types in `ToAtenDtype()` with a `TORCH_VERSION` check (these types require PyTorch 2.4+; Cambricon ships torch 2.1).
- Wrapper generator: scan only `.h` files to avoid including `.cc` explicit-instantiation files in the generated `ops.cc`.
1 parent 55703d6 commit 55bef7d

8 files changed

Lines changed: 163 additions & 62 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ if(WITH_ILUVATAR)
166166
endif()
167167
set(CMAKE_CUDA_FLAGS "-x ivcore -std=c++17 --cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags")
168168
set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar")
169-
set(CMAKE_CUDA_ARCHITECTURES "${ILUVATAR_ARCH}" CACHE STRING "Iluvatar CUDA architectures")
169+
set(CMAKE_CUDA_ARCHITECTURES OFF CACHE STRING "Iluvatar CUDA architectures (passed via CMAKE_CUDA_FLAGS)")
170170
message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}")
171171
enable_language(CUDA)
172172
find_package(CUDAToolkit REQUIRED)

scripts/generate_wrappers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,8 +429,8 @@ def _get_all_ops(devices, with_torch=False):
429429

430430
ops[op_name] = []
431431

432-
for file_path in _SRC_DIR.rglob("*"):
433-
if not file_path.is_file() or file_path.parent.parent.name not in scan_dirs:
432+
for file_path in _SRC_DIR.rglob("*.h"):
433+
if file_path.parent.parent.name not in scan_dirs:
434434
continue
435435

436436
if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text():

src/CMakeLists.txt

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,17 +219,56 @@ if(WITH_ASCEND)
219219
endif()
220220

221221
if(WITH_TORCH)
222-
set(TORCH_PATTERNS
223-
"torch/*.cc"
224-
"torch/*.cpp"
225-
)
226-
227-
file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS ${TORCH_PATTERNS})
222+
file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS "torch/*.cc" "torch/*.cpp")
228223

229224
target_compile_definitions(infiniops PUBLIC WITH_TORCH=1)
230-
target_sources(infiniops PRIVATE ${TORCH_SOURCES})
231225
target_link_libraries(infiniops PUBLIC ${TORCH_LIBRARIES})
232226
target_include_directories(infiniops PUBLIC ${TORCH_INCLUDE_DIRS})
227+
228+
if(WITH_METAX OR WITH_MOORE)
229+
# Vendor compilers (mxcc/mcc) cannot compile vendor-forked torch
230+
# headers. Compile torch sources with the system C++ compiler instead.
231+
find_program(SYSTEM_CXX NAMES g++ c++)
232+
233+
if(NOT SYSTEM_CXX)
234+
message(FATAL_ERROR "Could not find system g++ for torch compilation.")
235+
endif()
236+
237+
set(_torch_include_flags "")
238+
foreach(_dir ${TORCH_INCLUDE_DIRS})
239+
list(APPEND _torch_include_flags "-isystem" "${_dir}")
240+
endforeach()
241+
242+
set(TORCH_OBJECT_DIR "${CMAKE_CURRENT_BINARY_DIR}/torch_objs")
243+
file(MAKE_DIRECTORY "${TORCH_OBJECT_DIR}")
244+
245+
set(TORCH_OBJECT_FILES)
246+
foreach(_src ${TORCH_SOURCES})
247+
file(RELATIVE_PATH _rel ${CMAKE_CURRENT_SOURCE_DIR} ${_src})
248+
string(REPLACE "/" "_" _obj_name "${_rel}")
249+
string(REPLACE ".cc" ".o" _obj_name "${_obj_name}")
250+
string(REPLACE ".cpp" ".o" _obj_name "${_obj_name}")
251+
set(_obj "${TORCH_OBJECT_DIR}/${_obj_name}")
252+
253+
add_custom_command(
254+
OUTPUT "${_obj}"
255+
COMMAND ${SYSTEM_CXX}
256+
-std=c++17 -fPIC -O2
257+
"-I${CMAKE_CURRENT_SOURCE_DIR}"
258+
${_torch_include_flags}
259+
-c "${_src}" -o "${_obj}"
260+
DEPENDS "${_src}"
261+
COMMENT "Compiling ${_rel} with system C++ compiler"
262+
)
263+
list(APPEND TORCH_OBJECT_FILES "${_obj}")
264+
endforeach()
265+
266+
set_source_files_properties(${TORCH_OBJECT_FILES}
267+
PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
268+
target_sources(infiniops PRIVATE ${TORCH_OBJECT_FILES})
269+
else()
270+
target_sources(infiniops PRIVATE ${TORCH_SOURCES})
271+
endif()
233272
endif()
234273

235274
target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

src/torch/add/add.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "torch/add/add.h"
2+
3+
#include "torch/tensor_.h"
4+
5+
namespace infini::ops {
6+
7+
template <Device::Type kDev>
8+
Operator<Add, kDev, 1>::Operator(const Tensor input, const Tensor other,
9+
Tensor out)
10+
: Add{input, other, out},
11+
device_index_{out.device().index()} {}
12+
13+
template <Device::Type kDev>
14+
void Operator<Add, kDev, 1>::operator()(const Tensor input, const Tensor other,
15+
Tensor out) const {
16+
auto at_input = ToAtenTensor<kDev>(
17+
const_cast<void*>(input.data()), input_shape_, input_strides_,
18+
input_type_, device_index_);
19+
auto at_other = ToAtenTensor<kDev>(
20+
const_cast<void*>(other.data()), other_shape_, other_strides_,
21+
other_type_, device_index_);
22+
auto at_out = ToAtenTensor<kDev>(
23+
out.data(), out_shape_, out_strides_, out_type_, device_index_);
24+
25+
at::add_out(at_out, at_input, at_other);
26+
}
27+
28+
template class Operator<Add, Device::Type::kCpu, 1>;
29+
template class Operator<Add, Device::Type::kNvidia, 1>;
30+
template class Operator<Add, Device::Type::kCambricon, 1>;
31+
template class Operator<Add, Device::Type::kAscend, 1>;
32+
template class Operator<Add, Device::Type::kMetax, 1>;
33+
template class Operator<Add, Device::Type::kMoore, 1>;
34+
template class Operator<Add, Device::Type::kIluvatar, 1>;
35+
template class Operator<Add, Device::Type::kKunlun, 1>;
36+
template class Operator<Add, Device::Type::kHygon, 1>;
37+
template class Operator<Add, Device::Type::kQy, 1>;
38+
39+
} // namespace infini::ops

src/torch/add/add.h

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,16 @@
33

44
#include "base/add.h"
55
#include "torch/add/registry.h"
6-
#include "torch/tensor_.h"
76

87
namespace infini::ops {
98

109
template <Device::Type kDev>
1110
class Operator<Add, kDev, 1> : public Add {
1211
public:
13-
Operator(const Tensor input, const Tensor other, Tensor out)
14-
: Add{input, other, out},
15-
device_index_{out.device().index()} {}
12+
Operator(const Tensor input, const Tensor other, Tensor out);
1613

1714
void operator()(const Tensor input, const Tensor other,
18-
Tensor out) const override {
19-
// Use base-class stored metadata (not parameter tensors, which may be
20-
// moved-from by the `call()` dispatch path).
21-
auto at_input = ToAtenTensor<kDev>(
22-
const_cast<void*>(input.data()), input_shape_, input_strides_,
23-
input_type_, device_index_);
24-
auto at_other = ToAtenTensor<kDev>(
25-
const_cast<void*>(other.data()), other_shape_, other_strides_,
26-
other_type_, device_index_);
27-
auto at_out = ToAtenTensor<kDev>(
28-
out.data(), out_shape_, out_strides_, out_type_, device_index_);
29-
30-
at::add_out(at_out, at_input, at_other);
31-
}
15+
Tensor out) const override;
3216

3317
private:
3418
int device_index_{0};

src/torch/gemm/gemm.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "torch/gemm/gemm.h"
2+
3+
#include "torch/tensor_.h"
4+
5+
namespace infini::ops {
6+
7+
template <Device::Type kDev>
8+
Operator<Gemm, kDev, 2>::Operator(const Tensor a, const Tensor b,
9+
std::optional<float> alpha,
10+
std::optional<float> beta,
11+
std::optional<int> trans_a,
12+
std::optional<int> trans_b, Tensor c)
13+
: Gemm{a, b, alpha, beta, trans_a, trans_b, c},
14+
a_shape_{a.shape()},
15+
b_shape_{b.shape()},
16+
c_shape_{c.shape()},
17+
device_index_{c.device().index()} {}
18+
19+
template <Device::Type kDev>
20+
void Operator<Gemm, kDev, 2>::operator()(
21+
const Tensor a, const Tensor b, std::optional<float> alpha,
22+
std::optional<float> beta, std::optional<int> trans_a,
23+
std::optional<int> trans_b, Tensor c) const {
24+
auto at_a = ToAtenTensor<kDev>(const_cast<void*>(a.data()), a_shape_,
25+
a_strides_, a_type_, device_index_);
26+
auto at_b = ToAtenTensor<kDev>(const_cast<void*>(b.data()), b_shape_,
27+
b_strides_, b_type_, device_index_);
28+
auto at_c = ToAtenTensor<kDev>(c.data(), c_shape_, c_strides_, c_type_,
29+
device_index_);
30+
31+
auto alpha_val = alpha.value_or(alpha_);
32+
auto beta_val = beta.value_or(beta_);
33+
34+
if (trans_a.value_or(trans_a_)) {
35+
at_a = at_a.transpose(-2, -1);
36+
}
37+
38+
if (trans_b.value_or(trans_b_)) {
39+
at_b = at_b.transpose(-2, -1);
40+
}
41+
42+
if (at_a.dim() == 2) {
43+
at::addmm_out(at_c, at_c, at_a, at_b, beta_val, alpha_val);
44+
} else {
45+
at::baddbmm_out(at_c, at_c, at_a, at_b, beta_val, alpha_val);
46+
}
47+
}
48+
49+
template class Operator<Gemm, Device::Type::kCpu, 2>;
50+
template class Operator<Gemm, Device::Type::kNvidia, 2>;
51+
template class Operator<Gemm, Device::Type::kCambricon, 2>;
52+
template class Operator<Gemm, Device::Type::kAscend, 2>;
53+
template class Operator<Gemm, Device::Type::kMetax, 2>;
54+
template class Operator<Gemm, Device::Type::kMoore, 2>;
55+
template class Operator<Gemm, Device::Type::kIluvatar, 2>;
56+
template class Operator<Gemm, Device::Type::kKunlun, 2>;
57+
template class Operator<Gemm, Device::Type::kHygon, 2>;
58+
template class Operator<Gemm, Device::Type::kQy, 2>;
59+
60+
} // namespace infini::ops

src/torch/gemm/gemm.h

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include "base/gemm.h"
55
#include "torch/gemm/registry.h"
6-
#include "torch/tensor_.h"
76

87
namespace infini::ops {
98

@@ -12,44 +11,13 @@ class Operator<Gemm, kDev, 2> : public Gemm {
1211
public:
1312
Operator(const Tensor a, const Tensor b, std::optional<float> alpha,
1413
std::optional<float> beta, std::optional<int> trans_a,
15-
std::optional<int> trans_b, Tensor c)
16-
: Gemm{a, b, alpha, beta, trans_a, trans_b, c},
17-
a_shape_{a.shape()},
18-
b_shape_{b.shape()},
19-
c_shape_{c.shape()},
20-
device_index_{c.device().index()} {}
14+
std::optional<int> trans_b, Tensor c);
2115

2216
using Gemm::operator();
2317

2418
void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
2519
std::optional<float> beta, std::optional<int> trans_a,
26-
std::optional<int> trans_b, Tensor c) const override {
27-
// Use stored metadata instead of parameter tensors, which may be
28-
// moved-from by the `call()` dispatch path (see `operator.h`).
29-
auto at_a = ToAtenTensor<kDev>(const_cast<void*>(a.data()), a_shape_,
30-
a_strides_, a_type_, device_index_);
31-
auto at_b = ToAtenTensor<kDev>(const_cast<void*>(b.data()), b_shape_,
32-
b_strides_, b_type_, device_index_);
33-
auto at_c = ToAtenTensor<kDev>(c.data(), c_shape_, c_strides_, c_type_,
34-
device_index_);
35-
36-
auto alpha_val = alpha.value_or(alpha_);
37-
auto beta_val = beta.value_or(beta_);
38-
39-
if (trans_a.value_or(trans_a_)) {
40-
at_a = at_a.transpose(-2, -1);
41-
}
42-
43-
if (trans_b.value_or(trans_b_)) {
44-
at_b = at_b.transpose(-2, -1);
45-
}
46-
47-
if (at_a.dim() == 2) {
48-
at::addmm_out(at_c, at_c, at_a, at_b, beta_val, alpha_val);
49-
} else {
50-
at::baddbmm_out(at_c, at_c, at_a, at_b, beta_val, alpha_val);
51-
}
52-
}
20+
std::optional<int> trans_b, Tensor c) const override;
5321

5422
private:
5523
Tensor::Shape a_shape_;

src/torch/tensor_.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define INFINI_OPS_TORCH_TENSOR__H_
33

44
#include <torch/torch.h>
5+
#include <torch/version.h>
56

67
#include "tensor.h"
78
#include "torch/device_.h"
@@ -21,11 +22,21 @@ inline at::ScalarType ToAtenDtype(DataType dtype) {
2122
case DataType::kUInt8:
2223
return at::kByte;
2324
case DataType::kUInt16:
25+
#if TORCH_VERSION_MAJOR > 2 || \
26+
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 4)
2427
return c10::ScalarType::UInt16;
2528
case DataType::kUInt32:
2629
return c10::ScalarType::UInt32;
2730
case DataType::kUInt64:
2831
return c10::ScalarType::UInt64;
32+
#else
33+
[[fallthrough]];
34+
case DataType::kUInt32:
35+
[[fallthrough]];
36+
case DataType::kUInt64:
37+
assert(false && "Unsigned integer types require PyTorch 2.4 or later.");
38+
return at::kFloat;
39+
#endif
2940
case DataType::kFloat16:
3041
return at::kHalf;
3142
case DataType::kBFloat16:

0 commit comments

Comments (0)