Skip to content

Commit c16b0a6

Browse files
voltjiabitzyz
authored and committed
feat(nvidia): add ntops RMSNorm backend (#616)
1 parent d6804bf commit c16b0a6

10 files changed

Lines changed: 723 additions & 2 deletions

File tree

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF)
2525

2626
option(WITH_TORCH "Enable PyTorch C++ backend" OFF)
2727

28+
# Opt-in switch (default OFF) for building the NineToothed-generated kernels.
# A later check in this file rejects the option unless `WITH_NVIDIA` is also ON.
option(WITH_NINETOOTHED "Enable NineToothed-generated kernels" OFF)
29+
2830
# Default OFF until CANN's `extract_host_stub.py` path handling is fixed for
2931
# `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed
3032
# object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the
@@ -293,6 +295,14 @@ if(_gpu_backend_count GREATER 1)
293295
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_HYGON`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.")
294296
endif()
295297

298+
# NineToothed configuration: validate the backend combination first, then
# expose the cache entry that selects the interpreter used for code generation.
if(WITH_NINETOOTHED)
  if(NOT WITH_NVIDIA)
    message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because NineToothed AOT uses `caller=\"cuda\"`.")
  endif()

  # Empty by default; the consumer of this variable picks a fallback
  # interpreter when it is left unset.
  set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run NineToothed code generation")
endif()
305+
296306
if(WITH_NVIDIA)
297307
add_compile_definitions(WITH_NVIDIA=1)
298308
enable_language(CUDA)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import argparse
2+
import importlib.util
3+
import pathlib
4+
import shutil
5+
import sys
6+
7+
_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1]
8+
_OPS_DIR = _PROJECT_DIR / "src" / "ninetoothed" / "ops"
9+
10+
11+
def _find_op_modules():
    """Discover the NineToothed operator build scripts.

    Scans `_OPS_DIR` for `<op>/build.py` files.

    Returns:
        dict: Maps each operator name (the name of the directory containing
        `build.py`) to the `pathlib.Path` of that script, inserted in sorted
        path order.
    """
    op_modules = {}

    for build_script in sorted(_OPS_DIR.glob("*/build.py")):
        if build_script.is_file():
            op_modules[build_script.parent.name] = build_script

    return op_modules
17+
18+
19+
def _build_manifest(output_dir):
20+
return sorted(
21+
str(path)
22+
for path in pathlib.Path(output_dir).rglob("*.cpp")
23+
if not path.name.endswith(".tmp.cpp")
24+
)
25+
26+
27+
def _write_cmake_manifest(output_dir, sources):
28+
manifest_path = pathlib.Path(output_dir) / "manifest.cmake"
29+
lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"]
30+
lines.extend(f' "{source}"' for source in sources)
31+
lines.append(")")
32+
lines.append("")
33+
lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{output_dir}")')
34+
lines.append("")
35+
manifest_path.write_text("\n".join(lines) + "\n")
36+
37+
38+
def _load_op_module(op):
    """Import and return the `build.py` module of the operator named `op`.

    Args:
        op: Operator name, i.e. a key of the mapping returned by
            `_find_op_modules()`.

    Returns:
        The executed module object.

    Raises:
        KeyError: If `op` has no `build.py` under the ops directory.
        ImportError: If no import spec or loader can be created for the script.
    """
    path = _find_op_modules()[op]
    # Presumably lets the build script import modules that sit next to it;
    # the most recently loaded operator directory takes precedence.
    sys.path.insert(0, str(path.parent))

    # Every script is called `build.py`, so using the bare stem as the module
    # name would make all operators share the `sys.modules["build"]` slot
    # (and shadow any installed `build` package). Qualify it with the op name.
    module_name = f"_infiniops_ninetoothed_{op}_{path.stem}"
    spec = importlib.util.spec_from_file_location(module_name, path)

    # An explicit error instead of `assert`, which is stripped under `-O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load NineToothed build script: {path}")

    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module

    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(spec.name, None)
        raise

    return module
48+
49+
50+
def generate(ops, *, output_dir):
    """Generate NineToothed sources for `ops` and write a CMake manifest.

    `output_dir` is deleted and recreated first, so anything already in it is
    lost. Each operator's `build.py` module is then loaded and its
    `build(output_dir)` function called (what it writes is defined by the
    operator script, not here). Finally the generated `*.cpp` files are
    collected and recorded in `manifest.cmake`.

    Args:
        ops: Iterable of operator names. Duplicates are built once, in first
            occurrence order.
        output_dir: Target directory (string or path). It is made absolute so
            the manifest does not depend on the caller's working directory.

    Returns:
        list[str]: Sorted paths of the generated C++ sources.

    Raises:
        ValueError: If any requested operator has no build script.
    """
    # Materialise once: `ops` is walked twice below (validation, then build),
    # which would silently build nothing if it were a one-shot iterator.
    ops = tuple(dict.fromkeys(ops))
    op_modules = _find_op_modules()
    unknown_ops = tuple(op for op in ops if op not in op_modules)

    if unknown_ops:
        raise ValueError(f"unsupported NineToothed ops: {', '.join(unknown_ops)}")

    # CMake resolves relative source paths against its own source directory,
    # so the manifest must carry absolute paths.
    output_dir = pathlib.Path(output_dir).absolute()
    shutil.rmtree(output_dir, ignore_errors=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    for op in ops:
        module = _load_op_module(op)
        module.build(output_dir)

    sources = _build_manifest(output_dir)
    _write_cmake_manifest(output_dir, sources)

    return sources
69+
70+
71+
def _parse_args():
    """Parse the command-line options of this script.

    Options:
        --output-dir: Required; directory that receives the generated files.
        --ops: One or more operator names; defaults to every operator found
            by `_find_op_modules()`.

    Returns:
        argparse.Namespace: With attributes `output_dir` and `ops`.
    """
    available_ops = tuple(_find_op_modules())
    cli = argparse.ArgumentParser(
        description="Generate NineToothed operator sources for InfiniOps."
    )
    cli.add_argument("--output-dir", required=True)
    cli.add_argument("--ops", nargs="+", default=available_ops)
    namespace = cli.parse_args()

    return namespace
79+
80+
81+
def main():
    """Command-line entry point: parse the options and run the generation."""
    options = _parse_args()
    generate(options.ops, output_dir=options.output_dir)


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

scripts/generate_wrappers.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,11 +1017,13 @@ def _index_impl_headers(impl_roots, scan_dirs):
10171017
return by_operator
10181018

10191019

1020-
def _get_all_ops(devices, with_torch=False):
1020+
def _get_all_ops(devices, with_torch=False, with_ninetoothed=False):
10211021
scan_dirs = set(devices)
10221022

10231023
if with_torch:
10241024
scan_dirs.add("torch")
1025+
if with_ninetoothed:
1026+
scan_dirs.add("ninetoothed")
10251027

10261028
ops = {}
10271029

@@ -1140,6 +1142,11 @@ def _dispatch_gen_batch_size():
11401142
action="store_true",
11411143
help="Include PyTorch C++ backend implementations.",
11421144
)
1145+
parser.add_argument(
1146+
"--with-ninetoothed",
1147+
action="store_true",
1148+
help="Include NineToothed backend implementations.",
1149+
)
11431150

11441151
args = parser.parse_args()
11451152

@@ -1159,7 +1166,11 @@ def _dispatch_gen_batch_size():
11591166
if ops_json.exists():
11601167
ops = json.loads(ops_json.read_text())
11611168
else:
1162-
ops = _get_all_ops(args.devices, with_torch=args.with_torch)
1169+
ops = _get_all_ops(
1170+
args.devices,
1171+
with_torch=args.with_torch,
1172+
with_ninetoothed=args.with_ninetoothed,
1173+
)
11631174

11641175
bind_func_names = []
11651176

src/CMakeLists.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,39 @@ if(WITH_NVIDIA)
4949
)
5050
endif()
5151

52+
# NineToothed backend: run the Python code generator at configure time and add
# the sources it produces to the `infiniops` target.
if(WITH_NINETOOTHED)
  find_package(Python COMPONENTS Interpreter REQUIRED)

  # Interpreter precedence: an explicit user override, then the interpreter in
  # `_TORCH_PYTHON` when that variable is set, then the one CMake found.
  if(NINETOOTHED_PYTHON_EXECUTABLE)
    set(_ninetoothed_python "${NINETOOTHED_PYTHON_EXECUTABLE}")
  elseif(_TORCH_PYTHON)
    set(_ninetoothed_python "${_TORCH_PYTHON}")
  else()
    set(_ninetoothed_python "${Python_EXECUTABLE}")
  endif()
  message(STATUS "NineToothed codegen Python: ${_ninetoothed_python}")

  set(_ninetoothed_generator_script
    "${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py")
  set(_ninetoothed_output_dir "${CMAKE_CURRENT_BINARY_DIR}/ninetoothed")
  set(_ninetoothed_generator_args
    "${_ninetoothed_generator_script}"
    --output-dir "${_ninetoothed_output_dir}")

  # Generation only happens while configuring, so a plain rebuild would keep
  # compiling stale sources after the generator or an operator script changes.
  # Registering those inputs makes the build tool re-run CMake when they do.
  file(GLOB_RECURSE _ninetoothed_generator_inputs CONFIGURE_DEPENDS
    "${PROJECT_SOURCE_DIR}/src/ninetoothed/ops/*.py")
  set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
    "${_ninetoothed_generator_script}"
    ${_ninetoothed_generator_inputs})

  execute_process(
    COMMAND "${_ninetoothed_python}" ${_ninetoothed_generator_args}
    WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
    RESULT_VARIABLE _ninetoothed_generation_result
  )

  if(NOT _ninetoothed_generation_result EQUAL 0)
    message(FATAL_ERROR "Generating NineToothed operator sources failed with `${_ninetoothed_python}`. Set `NINETOOTHED_PYTHON_EXECUTABLE` to a Python with `ninetoothed`, `ntops`, `triton`, `sympy`, and CUDA dependencies installed.")
  endif()

  # The generator writes `manifest.cmake`, which defines
  # `INFINIOPS_NINETOOTHED_SOURCES` and `INFINIOPS_NINETOOTHED_INCLUDE_DIRS`.
  include("${_ninetoothed_output_dir}/manifest.cmake")
  target_include_directories(infiniops PRIVATE
    ${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
  target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES})
endif()
84+
5285
if(WITH_ILUVATAR)
5386
set(ILUVATAR_PATTERNS
5487
"native/cuda/*.cc"
@@ -496,6 +529,9 @@ if(GENERATE_CPP_OPERATOR_API OR GENERATE_PYTHON_BINDINGS)
496529
if(WITH_TORCH)
497530
list(APPEND GENERATOR_ARGS --with-torch)
498531
endif()
532+
# Forward the NineToothed switch to `generate_wrappers.py`.
if(WITH_NINETOOTHED)
  list(APPEND GENERATOR_ARGS "--with-ninetoothed")
endif()
499535

500536
execute_process(
501537
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS}
@@ -730,6 +766,10 @@ if(GENERATE_PYTHON_BINDINGS)
730766
${PROJECT_SOURCE_DIR}/include
731767
${PROJECT_SOURCE_DIR}/generated/include
732768
)
769+
# Add the NineToothed include directories to the `ops` target as well.
if(WITH_NINETOOTHED)
  target_include_directories(ops PRIVATE ${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
endif()
733773
target_link_libraries(ops PRIVATE infiniops)
734774

735775
# Cambricon generated dispatch is compiled into the Python extension and

0 commit comments

Comments
 (0)