Skip to content

Commit a02cb31

Browse files
committed
fix: export operator call instantiations
1 parent 6a3aad4 commit a02cb31

6 files changed

Lines changed: 402 additions & 17 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ option(BUILD_CUSTOM_KERNEL "Build custom AscendC kernel PyTorch extension (requi
3232

3333
option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
3434
option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF)
35+
option(GENERATE_OPERATOR_CALL_INSTANTIATIONS
36+
"Generate explicit operator call instantiations" ON)
3537
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)
3638

3739
set(_DEFAULT_HYGON_DTK_ROOT "/opt/dtk")

include/infini/ops.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef INFINI_OPS_H_
2+
#define INFINI_OPS_H_
3+
4+
#ifdef __cplusplus
5+
#include <infini/operator_call_instantiations.h>
6+
#endif
7+
8+
#endif // INFINI_OPS_H_

scripts/generate_wrappers.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,10 +702,113 @@ def _generate_generated_dispatch_source(impl_paths, definitions):
702702
"""
703703

704704

705+
def _strip_top_level_const(type_spelling):
706+
type_spelling = " ".join(type_spelling.split())
707+
708+
while type_spelling.startswith("const "):
709+
type_spelling = type_spelling[len("const ") :]
710+
711+
return type_spelling
712+
713+
714+
def _generate_operator_call_instantiation_entries(operator):
715+
def _generate_template_arguments(node):
716+
return ", ".join(
717+
_strip_top_level_const(arg.type.spelling)
718+
for arg in node.get_arguments()
719+
if arg.spelling != "stream"
720+
)
721+
722+
def _generate_parameters(node):
723+
return ", ".join(
724+
f"const {_strip_top_level_const(arg.type.spelling)}& {arg.spelling}"
725+
for arg in node.get_arguments()
726+
if arg.spelling != "stream"
727+
)
728+
729+
def _append_optional_params(prefix, params):
730+
if params:
731+
return f"{prefix}, {params}"
732+
733+
return prefix
734+
735+
pascal_case_op_name = _snake_to_pascal(operator.name)
736+
declarations = []
737+
definitions = []
738+
739+
for call in operator.calls:
740+
template_arguments = _generate_template_arguments(call)
741+
params = _generate_parameters(call)
742+
function_params = _append_optional_params(
743+
"const Handle& handle, const Config& config", params
744+
)
745+
instantiation = (
746+
f"Operator<{pascal_case_op_name}>::Call<{template_arguments}>"
747+
f"({function_params})"
748+
)
749+
750+
declarations.append(f"extern template auto {instantiation};")
751+
definitions.append(f"template auto {instantiation};")
752+
753+
return declarations, definitions
754+
755+
756+
def _generate_operator_call_instantiation_header(op_names, declarations):
757+
header_base_includes = "\n".join(
758+
f'#include "base/{op_name}.h"' for op_name in op_names
759+
)
760+
761+
return f"""#ifndef INFINI_OPS_OPERATOR_CALL_INSTANTIATIONS_H_
762+
#define INFINI_OPS_OPERATOR_CALL_INSTANTIATIONS_H_
763+
764+
#include <cstdint>
765+
#include <optional>
766+
#include <vector>
767+
768+
#include "config.h"
769+
#include "handle.h"
770+
#include "operator.h"
771+
772+
{header_base_includes}
773+
774+
namespace infini::ops {{
775+
776+
{chr(10).join(declarations)}
777+
778+
}} // namespace infini::ops
779+
780+
#endif
781+
"""
782+
783+
784+
def _generate_operator_call_instantiation_source(devices, impl_paths, definitions):
785+
device_includes = "\n".join(
786+
f'#include "{path}"' for path in _device_marker_headers(devices)
787+
)
788+
impl_includes = "\n".join(
789+
f'#include "{_to_include_path(impl_path)}"' for impl_path in impl_paths
790+
)
791+
792+
return f"""#include "infini/operator_call_instantiations.h"
793+
794+
// clang-format off
795+
{device_includes}
796+
{impl_includes}
797+
// clang-format on
798+
799+
namespace infini::ops {{
800+
801+
{chr(10).join(definitions)}
802+
803+
}} // namespace infini::ops
804+
"""
805+
806+
705807
def _device_marker_headers(devices):
706808
paths = {
707809
"cpu": "native/cpu/device_.h",
708810
"nvidia": "native/cuda/nvidia/device_.h",
811+
"hygon": "native/cuda/hygon/device_.h",
709812
"cambricon": "native/cambricon/device_.h",
710813
"ascend": "native/ascend/device_.h",
711814
"metax": "native/cuda/metax/device_.h",
@@ -819,6 +922,10 @@ def _generate_op_artifacts(item):
819922
dispatch_declarations, dispatch_definitions = _generate_generated_dispatch_entries(
820923
operator
821924
)
925+
(
926+
call_instantiation_declarations,
927+
call_instantiation_definitions,
928+
) = _generate_operator_call_instantiation_entries(operator)
822929

823930
return {
824931
"op_name": op_name,
@@ -830,6 +937,8 @@ def _generate_op_artifacts(item):
830937
"legacy_c_header": legacy_c_header,
831938
"dispatch_declarations": dispatch_declarations,
832939
"dispatch_definitions": dispatch_definitions,
940+
"call_instantiation_declarations": call_instantiation_declarations,
941+
"call_instantiation_definitions": call_instantiation_definitions,
833942
"impl_paths": impl_paths,
834943
}
835944

@@ -918,6 +1027,11 @@ def _dispatch_gen_batch_size():
9181027
for artifact in artifacts
9191028
for declaration in artifact["dispatch_declarations"]
9201029
]
1030+
call_instantiation_declarations = [
1031+
declaration
1032+
for artifact in artifacts
1033+
for declaration in artifact["call_instantiation_declarations"]
1034+
]
9211035
use_monolithic_bindings = _use_monolithic_bindings()
9221036
op_includes = []
9231037

@@ -947,6 +1061,14 @@ def _dispatch_gen_batch_size():
9471061
)
9481062
(_BINDINGS_DIR / "generated_dispatch.h").write_text(dispatch_header)
9491063

1064+
call_instantiation_header = _generate_operator_call_instantiation_header(
1065+
op_names, call_instantiation_declarations
1066+
)
1067+
(_INCLUDE_DIR / "infini").mkdir(exist_ok=True)
1068+
(_INCLUDE_DIR / "infini" / "operator_call_instantiations.h").write_text(
1069+
call_instantiation_header
1070+
)
1071+
9501072
dispatch_batch_size = _dispatch_gen_batch_size()
9511073

9521074
for dispatch_batch_index, start in enumerate(
@@ -968,6 +1090,28 @@ def _dispatch_gen_batch_size():
9681090
dispatch_source
9691091
)
9701092

1093+
for call_instantiation_batch_index, start in enumerate(
1094+
range(0, len(artifacts), dispatch_batch_size)
1095+
):
1096+
batch = artifacts[start : start + dispatch_batch_size]
1097+
impl_paths = list(
1098+
dict.fromkeys(
1099+
impl_path for artifact in batch for impl_path in artifact["impl_paths"]
1100+
)
1101+
)
1102+
definitions = [
1103+
definition
1104+
for artifact in batch
1105+
for definition in artifact["call_instantiation_definitions"]
1106+
]
1107+
call_instantiation_source = _generate_operator_call_instantiation_source(
1108+
args.devices, impl_paths, definitions
1109+
)
1110+
(
1111+
_GENERATED_SRC_DIR
1112+
/ f"operator_call_instantiations_{call_instantiation_batch_index}.cc"
1113+
).write_text(call_instantiation_source)
1114+
9711115
bind_func_calls = "\n".join(
9721116
f"{bind_func_name}(m);" for bind_func_name in bind_func_names
9731117
)

src/CMakeLists.txt

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
add_library(infiniops SHARED)
22

3+
include(GNUInstallDirs)
4+
35
file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc")
46
target_sources(infiniops PRIVATE ${BASE_SRCS})
57

@@ -467,14 +469,20 @@ if(WITH_TORCH)
467469
endif()
468470
endif()
469471

470-
target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
472+
target_include_directories(infiniops
473+
PUBLIC
474+
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
475+
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
476+
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/generated/include>
477+
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>
478+
)
471479

472-
if(GENERATE_PYTHON_BINDINGS)
480+
if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS)
473481
find_package(Python COMPONENTS Interpreter REQUIRED)
474-
# Always regenerate bindings so the included kernel headers match the
475-
# active device list. Stale generated files (e.g., committed for one
476-
# platform) would omit specializations for other enabled backends,
477-
# causing link-time or runtime failures.
482+
# Always regenerate wrappers so emitted call instantiations and bindings
483+
# match the active device list. Stale generated files would omit
484+
# specializations for enabled backends, causing link-time or runtime
485+
# failures.
478486

479487
set(GENERATOR_ARGS --devices ${DEVICE_LIST})
480488
if(WITH_TORCH)
@@ -492,7 +500,76 @@ if(GENERATE_PYTHON_BINDINGS)
492500
else()
493501
message(STATUS "Generating wrappers - done")
494502
endif()
503+
endif()
504+
505+
if(GENERATE_OPERATOR_CALL_INSTANTIATIONS)
506+
file(GLOB_RECURSE OPERATOR_CALL_INSTANTIATION_SOURCES CONFIGURE_DEPENDS
507+
"${PROJECT_SOURCE_DIR}/generated/src/operator_call_instantiations_*.cc")
508+
509+
if(WITH_NVIDIA OR WITH_HYGON)
510+
set_source_files_properties(${OPERATOR_CALL_INSTANTIATION_SOURCES}
511+
PROPERTIES LANGUAGE CUDA)
512+
target_sources(infiniops PRIVATE ${OPERATOR_CALL_INSTANTIATION_SOURCES})
513+
elseif(WITH_ILUVATAR)
514+
set(_iluvatar_call_instantiation_include_flags
515+
"-I${CMAKE_CURRENT_SOURCE_DIR}"
516+
"-I${PROJECT_SOURCE_DIR}"
517+
"-I${PROJECT_SOURCE_DIR}/generated"
518+
"-I${PROJECT_SOURCE_DIR}/generated/include")
519+
foreach(_dir IN LISTS TORCH_INCLUDE_DIRS CUDAToolkit_INCLUDE_DIRS)
520+
list(APPEND _iluvatar_call_instantiation_include_flags "-I${_dir}")
521+
endforeach()
522+
523+
set(_iluvatar_call_instantiation_defs -DWITH_ILUVATAR=1)
524+
if(WITH_CPU)
525+
list(APPEND _iluvatar_call_instantiation_defs -DWITH_CPU=1)
526+
endif()
527+
if(WITH_TORCH)
528+
list(APPEND _iluvatar_call_instantiation_defs -DWITH_TORCH=1)
529+
endif()
530+
if(DEFINED TORCH_CXX11_ABI)
531+
list(APPEND _iluvatar_call_instantiation_defs
532+
"-D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
533+
endif()
495534

535+
set(ILUVATAR_CALL_INSTANTIATION_OBJECTS)
536+
set(_iluvatar_call_instantiation_object_dir
537+
"${CMAKE_CURRENT_BINARY_DIR}/iluvatar_call_instantiation_objs")
538+
foreach(_src IN LISTS OPERATOR_CALL_INSTANTIATION_SOURCES)
539+
get_filename_component(_name "${_src}" NAME_WE)
540+
set(_obj "${_iluvatar_call_instantiation_object_dir}/${_name}.o")
541+
set(_dep "${_obj}.d")
542+
set(_depfile_arg)
543+
if(CMAKE_GENERATOR MATCHES "Ninja")
544+
set(_depfile_arg DEPFILE "${_dep}")
545+
endif()
546+
add_custom_command(
547+
OUTPUT "${_obj}"
548+
COMMAND ${CMAKE_COMMAND} -E make_directory
549+
"${_iluvatar_call_instantiation_object_dir}"
550+
COMMAND ${ILUVATAR_CUDA_COMPILER}
551+
${_iluvatar_call_instantiation_defs}
552+
${_iluvatar_call_instantiation_include_flags}
553+
${ILUVATAR_CUDA_FLAGS}
554+
-MMD -MF "${_dep}"
555+
-c "${_src}" -o "${_obj}"
556+
DEPENDS "${_src}"
557+
${_depfile_arg}
558+
COMMENT "Compiling ${_name}.cc with CoreX clang++"
559+
VERBATIM
560+
)
561+
list(APPEND ILUVATAR_CALL_INSTANTIATION_OBJECTS "${_obj}")
562+
endforeach()
563+
564+
set_source_files_properties(${ILUVATAR_CALL_INSTANTIATION_OBJECTS}
565+
PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
566+
target_sources(infiniops PRIVATE ${ILUVATAR_CALL_INSTANTIATION_OBJECTS})
567+
else()
568+
target_sources(infiniops PRIVATE ${OPERATOR_CALL_INSTANTIATION_SOURCES})
569+
endif()
570+
endif()
571+
572+
if(GENERATE_PYTHON_BINDINGS)
496573
file(GLOB_RECURSE PYBIND11_SOURCES CONFIGURE_DEPENDS
497574
"${PROJECT_SOURCE_DIR}/generated/bindings/*.cc")
498575

@@ -675,3 +752,44 @@ if(GENERATE_PYTHON_BINDINGS)
675752
DESTINATION .)
676753
endif()
677754
endif()
755+
756+
install(TARGETS infiniops
757+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
758+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
759+
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
760+
)
761+
762+
install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/
763+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
764+
)
765+
766+
if(GENERATE_OPERATOR_CALL_INSTANTIATIONS)
767+
install(FILES
768+
${PROJECT_SOURCE_DIR}/generated/include/infini/operator_call_instantiations.h
769+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/infini
770+
)
771+
endif()
772+
773+
file(GLOB INFINIOPS_PUBLIC_CORE_HEADERS CONFIGURE_DEPENDS
774+
"${CMAKE_CURRENT_SOURCE_DIR}/*.h")
775+
776+
install(FILES ${INFINIOPS_PUBLIC_CORE_HEADERS}
777+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
778+
)
779+
780+
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/base/
781+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/base
782+
FILES_MATCHING PATTERN "*.h"
783+
)
784+
785+
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/common/
786+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/common
787+
FILES_MATCHING PATTERN "*.h"
788+
)
789+
790+
if(EXISTS ${PROJECT_SOURCE_DIR}/generated/base)
791+
install(DIRECTORY ${PROJECT_SOURCE_DIR}/generated/base/
792+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/base
793+
FILES_MATCHING PATTERN "*.h"
794+
)
795+
endif()

0 commit comments

Comments
 (0)