Skip to content

Commit c4141be

Browse files
authored
feat(nvidia): add ntops RMSNorm backend (#616)
1 parent d6804bf commit c4141be

8 files changed

Lines changed: 452 additions & 2 deletions

File tree

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF)
2525

2626
option(WITH_TORCH "Enable PyTorch C++ backend" OFF)
2727

28+
option(WITH_NINETOOTHED "Enable NineToothed-generated kernels" OFF)
29+
2830
# Default OFF until CANN's `extract_host_stub.py` path handling is fixed for
2931
# `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed
3032
# object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the
@@ -293,6 +295,14 @@ if(_gpu_backend_count GREATER 1)
293295
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_HYGON`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.")
294296
endif()
295297

298+
if(WITH_NINETOOTHED AND NOT WITH_NVIDIA)
299+
message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because NineToothed AOT uses `caller=\"cuda\"`.")
300+
endif()
301+
302+
if(WITH_NINETOOTHED)
303+
set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run NineToothed code generation")
304+
endif()
305+
296306
if(WITH_NVIDIA)
297307
add_compile_definitions(WITH_NVIDIA=1)
298308
enable_language(CUDA)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import argparse
2+
import importlib.util
3+
import pathlib
4+
import shutil
5+
import sys
6+
7+
_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1]
8+
_OPS_DIR = _PROJECT_DIR / "src" / "ninetoothed" / "ops"
9+
10+
11+
def _find_op_modules():
12+
return {
13+
path.parent.name: path
14+
for path in sorted(_OPS_DIR.glob("*/build.py"))
15+
if path.is_file()
16+
}
17+
18+
19+
def _build_manifest(output_dir):
20+
return sorted(
21+
str(path)
22+
for path in pathlib.Path(output_dir).rglob("*.cpp")
23+
if not path.name.endswith(".tmp.cpp")
24+
)
25+
26+
27+
def _write_cmake_manifest(output_dir, sources):
28+
manifest_path = pathlib.Path(output_dir) / "manifest.cmake"
29+
lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"]
30+
lines.extend(f' "{source}"' for source in sources)
31+
lines.append(")")
32+
lines.append("")
33+
lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{output_dir}")')
34+
lines.append("")
35+
manifest_path.write_text("\n".join(lines) + "\n")
36+
37+
38+
def _load_op_module(op):
39+
path = _find_op_modules()[op]
40+
sys.path.insert(0, str(path.parent))
41+
spec = importlib.util.spec_from_file_location(path.stem, path)
42+
module = importlib.util.module_from_spec(spec)
43+
assert spec.loader is not None
44+
sys.modules[spec.name] = module
45+
spec.loader.exec_module(module)
46+
47+
return module
48+
49+
50+
def generate(ops, *, output_dir):
51+
op_modules = _find_op_modules()
52+
unknown_ops = tuple(op for op in ops if op not in op_modules)
53+
54+
if unknown_ops:
55+
raise ValueError(f"unsupported NineToothed ops: {', '.join(unknown_ops)}")
56+
57+
output_dir = pathlib.Path(output_dir)
58+
shutil.rmtree(output_dir, ignore_errors=True)
59+
output_dir.mkdir(parents=True, exist_ok=True)
60+
61+
for op in ops:
62+
module = _load_op_module(op)
63+
module.build(output_dir)
64+
65+
sources = _build_manifest(output_dir)
66+
_write_cmake_manifest(output_dir, sources)
67+
68+
return sources
69+
70+
71+
def _parse_args():
72+
parser = argparse.ArgumentParser(
73+
description="Generate NineToothed operator sources for InfiniOps."
74+
)
75+
parser.add_argument("--output-dir", required=True)
76+
parser.add_argument("--ops", nargs="+", default=tuple(_find_op_modules()))
77+
78+
return parser.parse_args()
79+
80+
81+
def main():
82+
args = _parse_args()
83+
generate(args.ops, output_dir=args.output_dir)
84+
85+
86+
if __name__ == "__main__":
87+
main()

scripts/generate_wrappers.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,11 +1017,13 @@ def _index_impl_headers(impl_roots, scan_dirs):
10171017
return by_operator
10181018

10191019

1020-
def _get_all_ops(devices, with_torch=False):
1020+
def _get_all_ops(devices, with_torch=False, with_ninetoothed=False):
10211021
scan_dirs = set(devices)
10221022

10231023
if with_torch:
10241024
scan_dirs.add("torch")
1025+
if with_ninetoothed:
1026+
scan_dirs.add("ninetoothed")
10251027

10261028
ops = {}
10271029

@@ -1140,6 +1142,11 @@ def _dispatch_gen_batch_size():
11401142
action="store_true",
11411143
help="Include PyTorch C++ backend implementations.",
11421144
)
1145+
parser.add_argument(
1146+
"--with-ninetoothed",
1147+
action="store_true",
1148+
help="Include NineToothed backend implementations.",
1149+
)
11431150

11441151
args = parser.parse_args()
11451152

@@ -1159,7 +1166,11 @@ def _dispatch_gen_batch_size():
11591166
if ops_json.exists():
11601167
ops = json.loads(ops_json.read_text())
11611168
else:
1162-
ops = _get_all_ops(args.devices, with_torch=args.with_torch)
1169+
ops = _get_all_ops(
1170+
args.devices,
1171+
with_torch=args.with_torch,
1172+
with_ninetoothed=args.with_ninetoothed,
1173+
)
11631174

11641175
bind_func_names = []
11651176

src/CMakeLists.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,39 @@ if(WITH_NVIDIA)
4949
)
5050
endif()
5151

52+
if(WITH_NINETOOTHED)
53+
find_package(Python COMPONENTS Interpreter REQUIRED)
54+
55+
if(NINETOOTHED_PYTHON_EXECUTABLE)
56+
set(_ninetoothed_python "${NINETOOTHED_PYTHON_EXECUTABLE}")
57+
elseif(_TORCH_PYTHON)
58+
set(_ninetoothed_python "${_TORCH_PYTHON}")
59+
else()
60+
set(_ninetoothed_python "${Python_EXECUTABLE}")
61+
endif()
62+
message(STATUS "NineToothed codegen Python: ${_ninetoothed_python}")
63+
64+
set(_ninetoothed_output_dir "${CMAKE_CURRENT_BINARY_DIR}/ninetoothed")
65+
set(_ninetoothed_generator_args
66+
"${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py"
67+
--output-dir "${_ninetoothed_output_dir}")
68+
69+
execute_process(
70+
COMMAND "${_ninetoothed_python}" ${_ninetoothed_generator_args}
71+
WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
72+
RESULT_VARIABLE _ninetoothed_generation_result
73+
)
74+
75+
if(NOT _ninetoothed_generation_result EQUAL 0)
76+
message(FATAL_ERROR "Generating NineToothed operator sources failed with `${_ninetoothed_python}`. Set `NINETOOTHED_PYTHON_EXECUTABLE` to a Python with `ninetoothed`, `ntops`, `triton`, `sympy`, and CUDA dependencies installed.")
77+
endif()
78+
79+
include("${_ninetoothed_output_dir}/manifest.cmake")
80+
target_include_directories(infiniops PRIVATE
81+
${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
82+
target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES})
83+
endif()
84+
5285
if(WITH_ILUVATAR)
5386
set(ILUVATAR_PATTERNS
5487
"native/cuda/*.cc"
@@ -496,6 +529,9 @@ if(GENERATE_CPP_OPERATOR_API OR GENERATE_PYTHON_BINDINGS)
496529
if(WITH_TORCH)
497530
list(APPEND GENERATOR_ARGS --with-torch)
498531
endif()
532+
if(WITH_NINETOOTHED)
533+
list(APPEND GENERATOR_ARGS --with-ninetoothed)
534+
endif()
499535

500536
execute_process(
501537
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS}
@@ -730,6 +766,10 @@ if(GENERATE_PYTHON_BINDINGS)
730766
${PROJECT_SOURCE_DIR}/include
731767
${PROJECT_SOURCE_DIR}/generated/include
732768
)
769+
if(WITH_NINETOOTHED)
770+
target_include_directories(ops PRIVATE
771+
${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
772+
endif()
733773
target_link_libraries(ops PRIVATE infiniops)
734774

735775
# Cambricon generated dispatch is compiled into the Python extension and
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import itertools
2+
3+
import ninetoothed
4+
import ntops
5+
6+
_BLOCK_SIZES = (256, 512)
7+
8+
_DTYPES = ("float32", "float16", "bfloat16")
9+
10+
_DEFAULT_NDIMS = (2, 3, 4)
11+
12+
_CONFIGS = tuple(
13+
(
14+
(),
15+
{
16+
"ndim": ndim,
17+
"num_normalized_dims": 1,
18+
"input_dtype": dtype,
19+
"weight_dtype": dtype,
20+
"output_dtype": dtype,
21+
"block_size": block_size,
22+
},
23+
{},
24+
)
25+
for ndim, dtype, block_size in itertools.product(
26+
_DEFAULT_NDIMS,
27+
(getattr(ninetoothed, name) for name in _DTYPES),
28+
_BLOCK_SIZES,
29+
)
30+
)
31+
32+
33+
def build(output_dir):
34+
variant_dir = output_dir / "rms_norm"
35+
variant_dir.mkdir(parents=True, exist_ok=True)
36+
ninetoothed.build(
37+
ntops.kernels.rms_norm.premake,
38+
_CONFIGS,
39+
meta_parameters=("block_size",),
40+
caller="cuda",
41+
kernel_name="infini_ops_ninetoothed_rms_norm",
42+
output_dir=variant_dir,
43+
lazy=False,
44+
)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#ifndef INFINI_OPS_NINETOOTHED_RMS_NORM_H_
2+
#define INFINI_OPS_NINETOOTHED_RMS_NORM_H_
3+
4+
#include <cassert>
5+
#include <cstdint>
6+
#include <vector>
7+
8+
#include "base/rms_norm.h"
9+
#include "data_type.h"
10+
#include "ninetoothed/tensor.h"
11+
#include "rms_norm/infini_ops_ninetoothed_rms_norm.h"
12+
13+
namespace infini::ops {
14+
15+
template <>
16+
class Operator<RmsNorm, Device::Type::kNvidia, 9> : public RmsNorm {
17+
public:
18+
using RmsNorm::RmsNorm;
19+
using RmsNorm::operator();
20+
21+
void operator()(const Tensor input, const Tensor weight, float eps,
22+
Tensor out) const override {
23+
assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() &&
24+
"operator `RmsNorm` requires all input and output tensors to have "
25+
"the same dtype");
26+
assert(input.shape() == out.shape() &&
27+
"NineToothed `RmsNorm` requires input and output tensors with the "
28+
"same shape");
29+
assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) &&
30+
"NineToothed `RmsNorm` requires a 1D weight matching the last "
31+
"dimension");
32+
assert(
33+
(out.ndim() == 2 || out.ndim() == 3 || out.ndim() == 4) &&
34+
"NineToothed `RmsNorm` currently supports rank-2, rank-3, and rank-4 "
35+
"tensors");
36+
37+
std::vector<std::uint64_t> weight_sizes;
38+
std::vector<std::int64_t> weight_strides;
39+
double eps_value = static_cast<double>(eps);
40+
std::int64_t num_normalized_elements =
41+
static_cast<std::int64_t>(out.size(-1));
42+
std::uint64_t empty_shape[1] = {};
43+
std::int64_t empty_strides[1] = {};
44+
45+
weight_sizes.assign(out.shape().begin(), out.shape().end());
46+
weight_strides.assign(out.ndim(), 0);
47+
weight_strides.back() =
48+
weight.strides().empty() ? 1 : weight.strides().back();
49+
50+
const int dtype_index = ninetoothed::DataTypeIndex(out.dtype());
51+
assert(
52+
dtype_index >= 0 &&
53+
"NineToothed `RmsNorm` supports only float16, bfloat16, and float32");
54+
55+
ninetoothed::Tensor input_tensor(input);
56+
ninetoothed::Tensor weight_tensor(const_cast<void*>(weight.data()),
57+
weight_sizes.data(),
58+
weight_strides.data());
59+
ninetoothed::Tensor eps_tensor(eps_value, empty_shape, empty_strides);
60+
ninetoothed::Tensor out_tensor(out);
61+
ninetoothed::Tensor num_normalized_elements_tensor(
62+
num_normalized_elements, empty_shape, empty_strides);
63+
64+
auto result = launch_infini_ops_ninetoothed_rms_norm(
65+
static_cast<NineToothedStream>(stream_), input_tensor, weight_tensor,
66+
eps_tensor, out_tensor, num_normalized_elements_tensor,
67+
static_cast<int>(out.ndim()), 1, dtype_index, dtype_index, dtype_index);
68+
69+
assert(result == 0 && "NineToothed `RmsNorm` launch failed");
70+
}
71+
};
72+
73+
} // namespace infini::ops
74+
75+
#endif

src/ninetoothed/tensor.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#ifndef INFINI_OPS_NINETOOTHED_TENSOR_H_
2+
#define INFINI_OPS_NINETOOTHED_TENSOR_H_
3+
4+
#include <cstdint>
5+
#include <type_traits>
6+
7+
#include "data_type.h"
8+
#include "tensor.h"
9+
10+
namespace infini::ops::ninetoothed {
11+
12+
inline int DataTypeIndex(DataType dtype) {
13+
switch (dtype) {
14+
case DataType::kFloat16:
15+
return 8;
16+
case DataType::kBFloat16:
17+
return 9;
18+
case DataType::kFloat32:
19+
return 10;
20+
default:
21+
return -1;
22+
}
23+
}
24+
25+
class Tensor {
26+
public:
27+
explicit Tensor(const ::infini::ops::Tensor& tensor)
28+
: Tensor(const_cast<void*>(tensor.data()),
29+
reinterpret_cast<std::uint64_t*>(
30+
const_cast<::infini::ops::Tensor::Size*>(
31+
tensor.shape().data())),
32+
reinterpret_cast<std::int64_t*>(
33+
const_cast<::infini::ops::Tensor::Stride*>(
34+
tensor.strides().data()))) {
35+
static_assert(sizeof(::infini::ops::Tensor::Size) == sizeof(std::uint64_t));
36+
static_assert(sizeof(::infini::ops::Tensor::Stride) ==
37+
sizeof(std::int64_t));
38+
static_assert(std::is_unsigned_v<::infini::ops::Tensor::Size>);
39+
static_assert(std::is_signed_v<::infini::ops::Tensor::Stride>);
40+
}
41+
42+
Tensor(void* data, std::uint64_t* shape, std::int64_t* strides)
43+
: data_(data), shape_(shape), strides_(strides) {}
44+
45+
template <typename T>
46+
Tensor(T& value, std::uint64_t* shape, std::int64_t* strides)
47+
: Tensor(static_cast<void*>(&value), shape, strides) {}
48+
49+
template <typename NineToothedTensor>
50+
operator NineToothedTensor() const {
51+
return NineToothedTensor{data_, shape_, strides_};
52+
}
53+
54+
private:
55+
void* data_;
56+
57+
std::uint64_t* shape_;
58+
59+
std::int64_t* strides_;
60+
};
61+
62+
} // namespace infini::ops::ninetoothed
63+
64+
#endif

0 commit comments

Comments
 (0)