diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f4c5cb4..e88cc20c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,10 @@ option(WITH_CAMBRICON "Enable Cambricon backend" OFF) option(WITH_MOORE "Enable Moore backend" OFF) option(WITH_ASCEND "Enable Ascend backend" OFF) +option(WITH_TORCH "Enable PyTorch C++ backend" OFF) + option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) +option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) if(AUTO_DETECT_DEVICES) @@ -79,6 +82,72 @@ if(AUTO_DETECT_DEVICES) endif() endif() +if(AUTO_DETECT_BACKENDS) + message(STATUS "Auto-detecting available backends...") + + find_package(Python COMPONENTS Interpreter QUIET) + + if(Python_FOUND) + execute_process( + COMMAND ${Python_EXECUTABLE} -c "import torch" + RESULT_VARIABLE _torch_import_result + OUTPUT_QUIET + ERROR_QUIET + ) + + if(_torch_import_result EQUAL 0) + set(WITH_TORCH ON) + message(STATUS "Auto-detected PyTorch.") + endif() + endif() +endif() + +if(WITH_TORCH) + find_package(Python COMPONENTS Interpreter REQUIRED) + + # Query `torch` paths directly instead of using `find_package(Torch)`, + # which pulls in Caffe2's CMake config and may fail on platforms with + # non-standard CUDA toolchains. 
+ execute_process( + COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))" + OUTPUT_VARIABLE TORCH_INCLUDE_DIRS + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _torch_result + ) + + if(NOT _torch_result EQUAL 0) + message(FATAL_ERROR "`WITH_TORCH` is `ON` but `torch` is not installed.") + endif() + + execute_process( + COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))" + OUTPUT_VARIABLE _torch_lib_dirs + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + + find_library(TORCH_LIB torch HINTS ${_torch_lib_dirs} REQUIRED) + find_library(TORCH_CPU_LIB torch_cpu HINTS ${_torch_lib_dirs} REQUIRED) + find_library(C10_LIB c10 HINTS ${_torch_lib_dirs} REQUIRED) + set(TORCH_LIBRARIES ${TORCH_LIB} ${TORCH_CPU_LIB} ${C10_LIB}) + + # Query the `CXX11` ABI setting that `torch` was compiled with. + # A mismatch causes linker errors (e.g. undefined reference to + # `c10::Device::Device(std::string const&)`). + execute_process( + COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))" + OUTPUT_VARIABLE TORCH_CXX11_ABI + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _torch_abi_result + ) + + if(_torch_abi_result EQUAL 0) + add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}) + message(STATUS "PyTorch CXX11 ABI: ${TORCH_CXX11_ABI}") + endif() + + message(STATUS "Found PyTorch: ${TORCH_INCLUDE_DIRS}") +endif() + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) # Only one CUDA-like GPU backend can be enabled at a time. 
@@ -110,11 +179,23 @@ if(WITH_ILUVATAR) else() set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "Iluvatar CUDA compiler (clang++)") endif() - set(CMAKE_CUDA_FLAGS "-x ivcore -std=c++17 --cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags") + # `-x ivcore` must not be in `CMAKE_CUDA_FLAGS` — CMake passes those flags + # to both compile and link steps. During linking, `-x ivcore` causes + # `clang++` to re-parse `.o` files as source code. + set(CMAKE_CUDA_FLAGS "--cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags") set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar") + set(CMAKE_CUDA_ARCHITECTURES OFF CACHE STRING "Iluvatar CUDA architectures (passed via CMAKE_CUDA_FLAGS)") + # Iluvatar does not ship `libcudadevrt`, which CMake's compiler test + # tries to link. Disable automatic CUDA runtime linking and link + # manually via `find_package(CUDAToolkit)` instead. + set(CMAKE_CUDA_RUNTIME_LIBRARY NONE) message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}") enable_language(CUDA) find_package(CUDAToolkit REQUIRED) + # Add `-x ivcore` as a compile-only flag so it is not passed during + # linking, where it would cause `clang++` to re-parse `.o` files as + # source. 
+ add_compile_options($<$<COMPILE_LANGUAGE:CUDA>:-x$<SEMICOLON>ivcore>) endif() if(WITH_METAX) diff --git a/pyproject.toml b/pyproject.toml index 3dbc186d..58740166 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ install-dir = "infini" [tool.scikit-build.cmake.define] AUTO_DETECT_DEVICES = "ON" +AUTO_DETECT_BACKENDS = "ON" GENERATE_PYTHON_BINDINGS = "ON" [tool.pytest.ini_options] diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 0023c7e9..a2795e24 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -413,7 +413,12 @@ def _snake_to_pascal(snake_str): return "".join(word.capitalize() for word in snake_str.split("_")) -def _get_all_ops(devices): +def _get_all_ops(devices, with_torch=False): + scan_dirs = set(devices) + + if with_torch: + scan_dirs.add("torch") + ops = {} for file_path in _BASE_DIR.iterdir(): @@ -424,8 +429,8 @@ def _get_all_ops(devices): ops[op_name] = [] - for file_path in _SRC_DIR.rglob("*"): - if not file_path.is_file() or file_path.parent.parent.name not in devices: + for file_path in _SRC_DIR.rglob("*.h"): + if file_path.parent.parent.name not in scan_dirs: continue if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text(): @@ -445,6 +450,12 @@ def _get_all_ops(devices): help="Devices to use. Please pick from `cpu`, `nvidia`, `cambricon`, `ascend`, `metax`, `moore`, `iluvatar`, `kunlun`, `hygon`, and `qy`. 
(default: `cpu`)", ) + parser.add_argument( + "--with-torch", + action="store_true", + help="Include PyTorch C++ backend implementations.", + ) + args = parser.parse_args() _BINDINGS_DIR.mkdir(parents=True, exist_ok=True) @@ -456,7 +467,7 @@ def _get_all_ops(devices): if ops_json.exists(): ops = json.loads(ops_json.read_text()) else: - ops = _get_all_ops(args.devices) + ops = _get_all_ops(args.devices, with_torch=args.with_torch) header_paths = [] bind_func_names = [] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2eb2591d..95f5da4f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,7 +16,7 @@ if(WITH_CPU) target_compile_definitions(infiniops PUBLIC WITH_CPU=1) - find_package(OpenMP REQUIRED) + find_package(OpenMP REQUIRED COMPONENTS CXX) target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX) list(APPEND DEVICE_LIST "cpu") @@ -218,6 +218,69 @@ if(WITH_ASCEND) list(APPEND DEVICE_LIST "ascend") endif() +if(WITH_TORCH) + file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS "torch/*.cc" "torch/*.cpp") + + target_compile_definitions(infiniops PUBLIC WITH_TORCH=1) + target_link_libraries(infiniops PUBLIC ${TORCH_LIBRARIES}) + target_include_directories(infiniops PUBLIC ${TORCH_INCLUDE_DIRS}) + + if(WITH_METAX OR WITH_MOORE) + # Vendor compilers (`mxcc`/`mcc`) cannot compile vendor-forked `torch` + # headers. Compile `torch` sources with the system C++ compiler instead. + find_program(SYSTEM_CXX NAMES g++ c++) + + if(NOT SYSTEM_CXX) + message(FATAL_ERROR "Could not find system `g++` for `torch` compilation.") + endif() + + set(_torch_include_flags "") + foreach(_dir ${TORCH_INCLUDE_DIRS}) + list(APPEND _torch_include_flags "-isystem" "${_dir}") + endforeach() + + # Vendor-specific defines required by forked `torch` headers. 
+ set(_torch_extra_flags "") + if(WITH_METAX) + list(APPEND _torch_extra_flags "-DUSE_MACA=1") + endif() + if(DEFINED TORCH_CXX11_ABI) + list(APPEND _torch_extra_flags "-D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}") + endif() + + set(TORCH_OBJECT_DIR "${CMAKE_CURRENT_BINARY_DIR}/torch_objs") + file(MAKE_DIRECTORY "${TORCH_OBJECT_DIR}") + + set(TORCH_OBJECT_FILES) + foreach(_src ${TORCH_SOURCES}) + file(RELATIVE_PATH _rel ${CMAKE_CURRENT_SOURCE_DIR} ${_src}) + string(REPLACE "/" "_" _obj_name "${_rel}") + string(REPLACE ".cc" ".o" _obj_name "${_obj_name}") + string(REPLACE ".cpp" ".o" _obj_name "${_obj_name}") + set(_obj "${TORCH_OBJECT_DIR}/${_obj_name}") + + add_custom_command( + OUTPUT "${_obj}" + COMMAND ${SYSTEM_CXX} + -std=c++17 -fPIC -O2 + "-I${CMAKE_CURRENT_SOURCE_DIR}" + ${_torch_include_flags} + ${_torch_extra_flags} + -c "${_src}" -o "${_obj}" + DEPENDS "${_src}" + COMMENT "Compiling ${_rel} with system C++ compiler" + ) + list(APPEND TORCH_OBJECT_FILES "${_obj}") + endforeach() + + set_source_files_properties(${TORCH_OBJECT_FILES} + PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE) + target_sources(infiniops PRIVATE ${TORCH_OBJECT_FILES}) + else() + target_sources(infiniops PRIVATE ${TORCH_SOURCES}) + endif() +endif() + target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) if(GENERATE_PYTHON_BINDINGS) @@ -226,8 +289,14 @@ if(GENERATE_PYTHON_BINDINGS) # active device list. Stale generated files (e.g., committed for one # platform) would omit specializations for other enabled backends, # causing link-time or runtime failures. 
+ + set(GENERATOR_ARGS --devices ${DEVICE_LIST}) + if(WITH_TORCH) + list(APPEND GENERATOR_ARGS --with-torch) + endif() + execute_process( - COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST} + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS} WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} RESULT_VARIABLE script_result ) @@ -246,6 +315,20 @@ if(GENERATE_PYTHON_BINDINGS) endif() find_package(Python COMPONENTS Interpreter Development) + + if(NOT pybind11_DIR) + execute_process( + COMMAND ${Python_EXECUTABLE} -m pybind11 --cmakedir + OUTPUT_VARIABLE _pybind11_cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _pybind11_result + ) + + if(_pybind11_result EQUAL 0) + set(pybind11_DIR "${_pybind11_cmake_dir}" CACHE PATH "pybind11 CMake directory") + endif() + endif() + find_package(pybind11 CONFIG) if(PYBIND11_ENABLE_EXTRAS) diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index 35bdd77a..189da177 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -3,7 +3,6 @@ #include "cuda/gemm/blas.h" #include "nvidia/blas.h" -#include "nvidia/gemm/registry.h" namespace infini::ops { diff --git a/src/nvidia/gemm/cublaslt.h b/src/nvidia/gemm/cublaslt.h index 38de8507..4033b472 100644 --- a/src/nvidia/gemm/cublaslt.h +++ b/src/nvidia/gemm/cublaslt.h @@ -10,7 +10,6 @@ #include "base/gemm.h" #include "nvidia/blas_utils.h" -#include "nvidia/gemm/registry.h" #include "nvidia/runtime_.h" namespace infini::ops { diff --git a/src/nvidia/gemm/registry.h b/src/nvidia/gemm/registry.h deleted file mode 100644 index a13591dc..00000000 --- a/src/nvidia/gemm/registry.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef INFINI_OPS_NVIDIA_GEMM_REGISTRY_H_ -#define INFINI_OPS_NVIDIA_GEMM_REGISTRY_H_ - -#include "base/gemm.h" - -namespace infini::ops { - -template <> -struct ActiveImplementationsImpl { - using type = List<0, 1>; -}; - -} // namespace infini::ops - -#endif 
diff --git a/src/operator.h b/src/operator.h index dbe92d7d..8a71a334 100644 --- a/src/operator.h +++ b/src/operator.h @@ -84,14 +84,9 @@ struct std::equal_to { namespace infini::ops { +// Forward declaration — defined after `Operator` using SFINAE auto-detection. template -struct ActiveImplementationsImpl { - using type = List<0>; -}; - -template -using ActiveImplementations = - typename ActiveImplementationsImpl::type; +struct ActiveImplementations; class OperatorBase { public: @@ -153,7 +148,7 @@ class Operator : public OperatorBase { } }, "Operator::make(implementation_index)", - ActiveImplementations{}); + typename ActiveImplementations::type{}); }, "Operator::make"); @@ -200,7 +195,8 @@ class Operator : public OperatorBase { dev_type, [&](auto device_tag) { constexpr Device::Type kDev = decltype(device_tag)::value; - result = detail::ListToVector(ActiveImplementations{}); + result = detail::ListToVector( + typename ActiveImplementations::type{}); }, "Operator::active_implementation_indices"); return result; @@ -227,6 +223,45 @@ class Operator : public OperatorBase { static constexpr std::size_t implementation_index_{implementation_index}; }; +// Maximum number of implementation slots per (operator, device) pair. +// Increase this value when adding operators with more implementations. +constexpr std::size_t kMaxImplementations = 16; + +// SFINAE-based implementation detection. A partial specialization +// `Operator` inherits from `Key` (the operator base class), +// while the unspecialized primary template inherits only from `OperatorBase`. +// `std::is_base_of` distinguishes the two at compile time, eliminating the +// need for manual `registry.h` files. 
+template >> +struct ActiveImplementationsImpl { + using type = List<>; +}; + +template +struct ActiveImplementationsImpl { + using type = List; +}; + +namespace detail { + +template +struct ActiveImplementationsHelper; + +template +struct ActiveImplementationsHelper> { + using type = typename Flatten< + typename ActiveImplementationsImpl::type...>::type; +}; + +} // namespace detail + +template +struct ActiveImplementations { + using type = typename detail::ActiveImplementationsHelper< + Key, kDev, std::make_index_sequence>::type; +}; + } // namespace infini::ops #endif diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 766b6eab..b595836c 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -5,6 +5,7 @@ #include #include "tensor.h" +#include "torch/device_.h" namespace py = pybind11; @@ -12,59 +13,6 @@ namespace infini::ops { namespace detail { -template -struct TorchDeviceName; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"cpu"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"cuda"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"cuda"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"cuda"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"cuda"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"cuda"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"cuda"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"mlu"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"npu"}; -}; - -template <> -struct TorchDeviceName { - static constexpr std::string_view kValue{"musa"}; -}; - template std::unordered_map BuildTorchNameMap( List) { diff --git a/src/torch/add/add.cc b/src/torch/add/add.cc 
new file mode 100644 index 00000000..cd2f2cc9 --- /dev/null +++ b/src/torch/add/add.cc @@ -0,0 +1,38 @@ +#include "torch/add/add.h" + +#include "torch/tensor_.h" + +namespace infini::ops { + +template +Operator::Operator(const Tensor input, const Tensor other, + Tensor out) + : Add{input, other, out}, device_index_{out.device().index()} {} + +template +void Operator::operator()(const Tensor input, const Tensor other, + Tensor out) const { + auto at_input = + ToAtenTensor(const_cast(input.data()), input_shape_, + input_strides_, input_type_, device_index_); + auto at_other = + ToAtenTensor(const_cast(other.data()), other_shape_, + other_strides_, other_type_, device_index_); + auto at_out = ToAtenTensor(out.data(), out_shape_, out_strides_, + out_type_, device_index_); + + at::add_out(at_out, at_input, at_other); +} + +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; + +} // namespace infini::ops diff --git a/src/torch/add/add.h b/src/torch/add/add.h new file mode 100644 index 00000000..5a7ea230 --- /dev/null +++ b/src/torch/add/add.h @@ -0,0 +1,22 @@ +#ifndef INFINI_OPS_TORCH_ADD_H_ +#define INFINI_OPS_TORCH_ADD_H_ + +#include "base/add.h" + +namespace infini::ops { + +template +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out); + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override; + + private: + int device_index_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/torch/device_.h b/src/torch/device_.h new file mode 100644 index 00000000..f83fc8a2 --- /dev/null +++ b/src/torch/device_.h @@ -0,0 +1,65 @@ +#ifndef INFINI_OPS_TORCH_DEVICE__H_ +#define INFINI_OPS_TORCH_DEVICE__H_ + +#include + +#include "device.h" + +namespace infini::ops::detail { + +template 
+struct TorchDeviceName; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cpu"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"mlu"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"npu"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"musa"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +} // namespace infini::ops::detail + +#endif diff --git a/src/torch/gemm/gemm.cc b/src/torch/gemm/gemm.cc new file mode 100644 index 00000000..d17353f1 --- /dev/null +++ b/src/torch/gemm/gemm.cc @@ -0,0 +1,62 @@ +#include "torch/gemm/gemm.h" + +#include "torch/tensor_.h" + +namespace infini::ops { + +template +Operator::Operator(const Tensor a, const Tensor b, + std::optional alpha, + std::optional beta, + std::optional trans_a, + std::optional trans_b, Tensor c) + : Gemm{a, b, alpha, beta, trans_a, trans_b, c}, + a_shape_{a.shape()}, + b_shape_{b.shape()}, + c_shape_{c.shape()}, + device_index_{c.device().index()} {} + +template +void Operator::operator()(const Tensor a, const Tensor b, + std::optional alpha, + std::optional beta, + std::optional trans_a, + std::optional trans_b, + Tensor c) const { + auto at_a = ToAtenTensor(const_cast(a.data()), a_shape_, + a_strides_, a_type_, device_index_); + auto at_b = ToAtenTensor(const_cast(b.data()), b_shape_, + b_strides_, b_type_, 
device_index_); + auto at_c = ToAtenTensor(c.data(), c_shape_, c_strides_, c_type_, + device_index_); + + auto alpha_val = alpha.value_or(alpha_); + auto beta_val = beta.value_or(beta_); + + if (trans_a.value_or(trans_a_)) { + at_a = at_a.transpose(-2, -1); + } + + if (trans_b.value_or(trans_b_)) { + at_b = at_b.transpose(-2, -1); + } + + if (at_a.dim() == 2) { + at::addmm_out(at_c, at_c, at_a, at_b, beta_val, alpha_val); + } else { + at::baddbmm_out(at_c, at_c, at_a, at_b, beta_val, alpha_val); + } +} + +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; +template class Operator; + +} // namespace infini::ops diff --git a/src/torch/gemm/gemm.h b/src/torch/gemm/gemm.h new file mode 100644 index 00000000..4fd22ff3 --- /dev/null +++ b/src/torch/gemm/gemm.h @@ -0,0 +1,33 @@ +#ifndef INFINI_OPS_TORCH_GEMM_H_ +#define INFINI_OPS_TORCH_GEMM_H_ + +#include "base/gemm.h" + +namespace infini::ops { + +template +class Operator : public Gemm { + public: + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c); + + using Gemm::operator(); + + void operator()(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) const override; + + private: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape c_shape_; + + int device_index_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/torch/tensor_.h b/src/torch/tensor_.h new file mode 100644 index 00000000..d568c05a --- /dev/null +++ b/src/torch/tensor_.h @@ -0,0 +1,105 @@ +#ifndef INFINI_OPS_TORCH_TENSOR__H_ +#define INFINI_OPS_TORCH_TENSOR__H_ + +#include +#include + +#include "tensor.h" +#include "torch/device_.h" + +namespace infini::ops { + +namespace detail 
{ + +// Introduces a dependent type alias for `c10::ScalarType`, allowing +// `if constexpr` to discard branches that reference enum values absent +// in older PyTorch versions. Without this indirection the enum member +// names (e.g. `UInt16`) are non-dependent and looked up at definition +// time, causing a hard error on PyTorch < 2.4. +template +struct DependentScalarType { + using type = c10::ScalarType; +}; + +constexpr int kTorchVersion = TORCH_VERSION_MAJOR * 100 + TORCH_VERSION_MINOR; + +// Unsigned integer scalar types are only available in PyTorch >= 2.4. +template +inline at::ScalarType ToAtenUnsignedDataType(DataType dtype) { + if constexpr (kVersion >= 204) { + using ST = typename DependentScalarType::type; + switch (dtype) { + case DataType::kUInt16: + return ST::UInt16; + case DataType::kUInt32: + return ST::UInt32; + case DataType::kUInt64: + return ST::UInt64; + default: + assert(false && "not an unsigned integer dtype"); + return at::kFloat; + } + } else { + (void)dtype; + assert(false && "unsigned integer types require PyTorch 2.4 or later"); + return at::kFloat; + } +} + +} // namespace detail + +inline at::ScalarType ToAtenDataType(DataType dtype) { + switch (dtype) { + case DataType::kInt8: + return at::kChar; + case DataType::kInt16: + return at::kShort; + case DataType::kInt32: + return at::kInt; + case DataType::kInt64: + return at::kLong; + case DataType::kUInt8: + return at::kByte; + case DataType::kUInt16: + case DataType::kUInt32: + case DataType::kUInt64: + return detail::ToAtenUnsignedDataType(dtype); + case DataType::kFloat16: + return at::kHalf; + case DataType::kBFloat16: + return at::kBFloat16; + case DataType::kFloat32: + return at::kFloat; + case DataType::kFloat64: + return at::kDouble; + default: + assert(false && "unsupported dtype for ATen conversion"); + return at::kFloat; + } +} + +// Build an ATen tensor from explicit metadata. 
Use this instead of reading +// shape/strides from the `Tensor` parameter, which may have been moved-from +// by the `call()` dispatch path (see `operator.h`). +template +inline at::Tensor ToAtenTensor(void* data, const Tensor::Shape& shape, + const Tensor::Strides& strides, DataType dtype, + int device_index = 0) { + std::vector at_shape(shape.begin(), shape.end()); + std::vector at_strides(strides.begin(), strides.end()); + + auto options = at::TensorOptions().dtype(ToAtenDataType(dtype)); + + if constexpr (kDev != Device::Type::kCpu) { + std::string device_str = + std::string(detail::TorchDeviceName::kValue) + ":" + + std::to_string(device_index); + options = options.device(device_str); + } + + return at::from_blob(data, at_shape, at_strides, options); +} + +} // namespace infini::ops + +#endif diff --git a/tests/test_add.py b/tests/test_add.py index 8b8166c3..825fc932 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -29,6 +29,9 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) +# TODO: Generate implementation indices dynamically from +# `Add.active_implementation_indices` instead of hardcoding. +@pytest.mark.parametrize("implementation_index", (0, 1)) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -39,13 +42,29 @@ + tuple((dtype, 0, 0) for dtype in _INT_DTYPES + _UINT_DTYPES), ) def test_add( - shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol + shape, + input_strides, + other_strides, + out_strides, + implementation_index, + dtype, + device, + rtol, + atol, ): if device == "musa" and dtype in _UINT_DTYPES: pytest.skip( "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`." 
) + active_indices = infini.ops.Add.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + if implementation_index == 1 and dtype in _UINT_DTYPES: + pytest.skip("ATen `add` does not support unsigned integer types") + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: input = randint_strided( 0, 100, shape, input_strides, dtype=dtype, device=device @@ -59,11 +78,18 @@ def test_add( out = empty_strided(shape, out_strides, dtype=dtype, device=device) - return Payload(_add, _torch_add, (input, other, out), {}, rtol=rtol, atol=atol) + return Payload( + lambda *args: _add(*args, implementation_index=implementation_index), + _torch_add, + (input, other, out), + {}, + rtol=rtol, + atol=atol, + ) -def _add(input, other, out): - infini.ops.add(input, other, out) +def _add(input, other, out, implementation_index=0): + infini.ops.add(input, other, out, implementation_index=implementation_index) return out diff --git a/tests/test_gemm.py b/tests/test_gemm.py index d3b26884..26e102d2 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -20,7 +20,9 @@ @pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1)) @pytest.mark.parametrize("trans_a", (False, True)) @pytest.mark.parametrize("trans_b", (False, True)) -@pytest.mark.parametrize("implementation_index", (0, 1)) +# TODO: Generate implementation indices dynamically from +# `Gemm.active_implementation_indices` instead of hardcoding. 
+@pytest.mark.parametrize("implementation_index", (0, 1, 2)) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -62,6 +64,13 @@ def test_gemm( if implementation_index == 1 and dtype in (torch.float16, torch.bfloat16): pytest.skip("cuBLASLt half-precision exceeds current tolerances") + if ( + implementation_index == 2 + and device == "cpu" + and dtype in (torch.float16, torch.bfloat16) + ): + pytest.skip("ATen CPU `addmm`/`baddbmm` does not support half-precision") + a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) b = randn_strided(b_shape, b_strides, dtype=dtype, device=device)