Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,27 @@ endif()

if(WITH_NVIDIA)
add_compile_definitions(WITH_NVIDIA=1)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
set(CMAKE_CUDA_ARCHITECTURES "native")
else()
# Detect GPU architecture via `nvidia-smi`.
execute_process(
COMMAND nvidia-smi --query-gpu=compute_cap --format=csv,noheader
OUTPUT_VARIABLE _gpu_cc OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_QUIET RESULT_VARIABLE _nvsmi_result
)
if(_nvsmi_result EQUAL 0 AND _gpu_cc)
string(REGEX MATCH "^[0-9]+\\.[0-9]+" _first_cc "${_gpu_cc}")
string(REPLACE "." "" _arch "${_first_cc}")
set(CMAKE_CUDA_ARCHITECTURES "${_arch}")
message(STATUS "Auto-detected CUDA architecture: `sm_${_arch}`.")
else()
message(WARNING "Could not detect GPU architecture; defaulting to `sm_80`.")
set(CMAKE_CUDA_ARCHITECTURES "80")
endif()
endif()
endif()
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
endif()
Expand Down
190 changes: 190 additions & 0 deletions scripts/_generate_legacy_c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from _operator_utils import snake_to_pascal

# Override `PascalCase` names to match InfiniCore's existing C API conventions.
_C_API_NAME_OVERRIDES = {
"rms_norm": "RMSNorm",
"swiglu": "SwiGLU",
}

# Override which constructor/call overload index to use per operator.
# Default is `-1` (the last one, typically the simplest).
_CONSTRUCTOR_INDEX_OVERRIDES = {
"rms_norm": 0,
}

_CALL_INDEX_OVERRIDES = {}


def generate_legacy_c(operator, paths):
# The C++ class name from InfiniOps (e.g. `RmsNorm`, `Swiglu`).
cpp_name = snake_to_pascal(operator.name)
# The C API name, which may differ from the C++ class name.
pascal_name = _C_API_NAME_OVERRIDES.get(operator.name, cpp_name)
constructor_index = _CONSTRUCTOR_INDEX_OVERRIDES.get(operator.name, -1)
call_index = _CALL_INDEX_OVERRIDES.get(operator.name, -1)

def _generate_source(operator):
return f"""#include "../../handle.h"
#include "../../tensor.h"
#include "infiniop/ops/{operator.name}.h"
#include "base/{operator.name}.h"
#include "make.h"

static infini::ops::DataType DataTypeFromInfiniDType(
const infiniDtype_t& dtype) {{
static constexpr infini::ops::ConstexprMap<infiniDtype_t,
infini::ops::DataType, 12>
kInfiniDTypeToDataType{{
{{{{{{INFINI_DTYPE_I8, infini::ops::DataType::kInt8}},
{{INFINI_DTYPE_I16, infini::ops::DataType::kInt16}},
{{INFINI_DTYPE_I32, infini::ops::DataType::kInt32}},
{{INFINI_DTYPE_I64, infini::ops::DataType::kInt64}},
{{INFINI_DTYPE_U8, infini::ops::DataType::kUInt8}},
{{INFINI_DTYPE_U16, infini::ops::DataType::kUInt16}},
{{INFINI_DTYPE_U32, infini::ops::DataType::kUInt32}},
{{INFINI_DTYPE_U64, infini::ops::DataType::kUInt64}},
{{INFINI_DTYPE_F16, infini::ops::DataType::kFloat16}},
{{INFINI_DTYPE_BF16, infini::ops::DataType::kBFloat16}},
{{INFINI_DTYPE_F32, infini::ops::DataType::kFloat32}},
{{INFINI_DTYPE_F64, infini::ops::DataType::kFloat64}}}}}}}};

return kInfiniDTypeToDataType.at(dtype);
}}

static infini::ops::Device::Type DeviceTypeFromInfiniDevice(
const infiniDevice_t& device) {{
static constexpr infini::ops::ConstexprMap<
infiniDevice_t, infini::ops::Device::Type,
static_cast<std::size_t>(INFINI_DEVICE_TYPE_COUNT)>
kInfiniDeviceToDeviceType{{
{{{{{{INFINI_DEVICE_CPU, infini::ops::Device::Type::kCpu}},
{{INFINI_DEVICE_NVIDIA, infini::ops::Device::Type::kNvidia}},
{{INFINI_DEVICE_CAMBRICON, infini::ops::Device::Type::kCambricon}},
{{INFINI_DEVICE_ASCEND, infini::ops::Device::Type::kAscend}},
{{INFINI_DEVICE_METAX, infini::ops::Device::Type::kMetax}},
{{INFINI_DEVICE_MOORE, infini::ops::Device::Type::kMoore}},
{{INFINI_DEVICE_ILUVATAR, infini::ops::Device::Type::kIluvatar}},
{{INFINI_DEVICE_KUNLUN, infini::ops::Device::Type::kKunlun}},
{{INFINI_DEVICE_HYGON, infini::ops::Device::Type::kHygon}},
{{INFINI_DEVICE_QY, infini::ops::Device::Type::kQy}}}}}}}};

return kInfiniDeviceToDeviceType.at(device);
}}

__INFINI_C {_generate_create_func_def(operator)}

__INFINI_C {_generate_get_workspace_size_func_def(operator)}

__INFINI_C {_generate_call_func_def(operator)}

__INFINI_C {_generate_destroy_func_def(operator)}
"""

def _generate_header(operator):
return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__
#define __INFINIOP_{operator.name.upper()}_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniop{pascal_name}Descriptor_t;

__INFINI_C __export {_generate_create_func_decl(operator)};

__INFINI_C __export {_generate_get_workspace_size_func_decl(operator)};

__INFINI_C __export {_generate_call_func_decl(operator)};

__INFINI_C __export {_generate_destroy_func_decl(operator)};

#endif
"""

def _generate_create_func_def(operator):
constructor = operator.constructors[constructor_index]

return f"""{_generate_create_func_decl(operator)} {{
*desc_ptr = reinterpret_cast<infiniop{pascal_name}Descriptor_t>(infini::ops::Make{cpp_name}({{}}, {_generate_arguments(constructor)}).release());

return INFINI_STATUS_SUCCESS;
}}"""

def _generate_get_workspace_size_func_def(operator):
return f"""{_generate_get_workspace_size_func_decl(operator)} {{
*size = 0; // desc->workspace_size();

return INFINI_STATUS_SUCCESS;
}}"""

def _generate_call_func_def(operator):
call = operator.calls[call_index]

return f"""{_generate_call_func_decl(operator)} {{
auto *op = reinterpret_cast<infini::ops::OperatorBase *>(desc);
op->set_stream(stream);
op->set_workspace(workspace);
op->set_workspace_size_in_bytes(workspace_size);
static_cast<const infini::ops::{cpp_name} &>(*op)({_generate_arguments(call, is_data=True)});

return INFINI_STATUS_SUCCESS;
}}"""

def _generate_destroy_func_def(operator):
return f"""{_generate_destroy_func_decl(operator)} {{
delete reinterpret_cast<infini::ops::OperatorBase *>(desc);

return INFINI_STATUS_SUCCESS;
}}"""

def _generate_create_func_decl(operator):
constructor = operator.constructors[constructor_index]
params = _generate_params(constructor)

return f"infiniStatus_t infiniopCreate{pascal_name}Descriptor(infiniopHandle_t handle, infiniop{pascal_name}Descriptor_t *desc_ptr, {params})"

def _generate_get_workspace_size_func_decl(operator):
return f"infiniStatus_t infiniopGet{pascal_name}WorkspaceSize(infiniop{pascal_name}Descriptor_t desc, size_t *size)"

def _generate_call_func_decl(operator):
call = operator.calls[call_index]
params = _generate_params(call, call=True)
params = params.replace("void * stream, ", "")

return f"infiniStatus_t infiniop{pascal_name}(infiniop{pascal_name}Descriptor_t desc, void *workspace, size_t workspace_size, {params}, void *stream)"

def _generate_destroy_func_decl(operator):
return f"infiniStatus_t infiniopDestroy{pascal_name}Descriptor(infiniop{pascal_name}Descriptor_t desc)"

def _generate_params(node, call=False):
arguments = tuple(node.get_arguments())
arguments = (arguments[-1], *arguments[:-1])

def _handle_tensor(spelling):
if call:
return spelling.replace("Tensor", "void *")

return spelling.replace("Tensor", "infiniopTensorDescriptor_t")

def _handle_std_optional(spelling):
return spelling.replace("std::optional<", "").replace(">", "")

return ", ".join(
f"{_handle_std_optional(_handle_tensor(arg.type.spelling))} {arg.spelling}"
for arg in arguments
)

def _generate_arguments(node, is_data=False):
return ", ".join(
_generate_tensor_caster(arg.spelling, is_data=is_data)
if "Tensor" in arg.type.spelling
else arg.spelling
for arg in node.get_arguments()
if arg.spelling != "handle" and arg.spelling != "stream"
)

def _generate_tensor_caster(name, is_data=False):
if is_data:
return f"infini::ops::Tensor(const_cast<void *>({name}), infini::ops::Tensor::Shape{{}})"

return f"infini::ops::Tensor{{nullptr, {name}->shape(), DataTypeFromInfiniDType({name}->dtype()), infini::ops::Device{{DeviceTypeFromInfiniDevice(handle->device), handle->device_id}}, {name}->strides()}}"

return _generate_source(operator), _generate_header(operator)
104 changes: 104 additions & 0 deletions scripts/_generate_pybind11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from _operator_utils import snake_to_pascal


def generate_pybind11(operator):
def _generate_params(node):
return (
", ".join(
f"{arg.type.spelling} {arg.spelling}"
for arg in node.get_arguments()
if arg.spelling != "stream"
)
.replace("const Tensor", "py::object")
.replace("Tensor", "py::object")
)

def _generate_arguments(node):
return ", ".join(
f"TensorFromPybind11Handle({arg.spelling})"
if "Tensor" in arg.type.spelling
else arg.spelling
for arg in node.get_arguments()
if arg.spelling != "stream"
)

op_name = operator.name

def _generate_init(constructor):
constructor_params = _generate_params(constructor)

return f""" .def(py::init([]({constructor_params}) {{
return std::unique_ptr<Self>{{static_cast<Self*>(Self::make({_generate_arguments(constructor)}).release())}};
}}))"""

def _generate_py_args(node):
return ", ".join(
f'py::arg("{arg.spelling}")'
for arg in node.get_arguments()
if arg.spelling != "stream"
)

def _generate_call(op_name, call, method=True):
call_params = _generate_params(call)
call_args = _generate_arguments(call)

if not method:
params = (
f"{call_params}, std::size_t implementation_index"
if call_params
else "std::size_t implementation_index"
)
py_args = _generate_py_args(call)
py_args_str = f"{py_args}, " if py_args else ""

return f""" m.def("{op_name}", []({params}) {{
Config config;
config.set_implementation_index(implementation_index);
return Self::call({{}}, config, {call_args});
}}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);"""

return f""" .def("__call__", [](const Self& self, {call_params}) {{
return static_cast<const Operator<Self>&>(self)({call_args});
}})"""

inits = "\n".join(
_generate_init(constructor) for constructor in operator.constructors
)
calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls)
callers = "\n".join(
_generate_call(operator.name, call, method=False) for call in operator.calls
)

pascal_case_op_name = snake_to_pascal(op_name)

return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_
#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "base/{op_name}.h"
#include "config.h"
#include "pybind11_utils.h"

namespace py = pybind11;

namespace infini::ops {{

void Bind{pascal_case_op_name}(py::module& m) {{
using Self = {pascal_case_op_name};

py::class_<Self>(m, "{pascal_case_op_name}")
{inits}
{calls}
.def_static("active_implementation_indices", [](const std::string& device) {{
return Self::active_implementation_indices(DeviceTypeFromString(device));
}});

{callers}
}}

}} // namespace infini::ops

#endif
"""
Loading