Skip to content

Commit ca40c6f

Browse files
committed
fix: export operator call instantiations
1 parent ed66e76 commit ca40c6f

6 files changed

Lines changed: 402 additions & 17 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ option(BUILD_CUSTOM_KERNEL "Build custom AscendC kernel PyTorch extension (requi
3434

3535
option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
3636
option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF)
37+
option(GENERATE_OPERATOR_CALL_INSTANTIATIONS
38+
"Generate explicit operator call instantiations" ON)
3739
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)
3840

3941
set(_DEFAULT_HYGON_DTK_ROOT "/opt/dtk")

include/infini/ops.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef INFINI_OPS_H_
2+
#define INFINI_OPS_H_
3+
4+
#ifdef __cplusplus
5+
#include <infini/operator_call_instantiations.h>
6+
#endif
7+
8+
#endif // INFINI_OPS_H_

scripts/generate_wrappers.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,10 +702,113 @@ def _generate_generated_dispatch_source(impl_paths, definitions):
702702
"""
703703

704704

705+
def _strip_top_level_const(type_spelling):
706+
type_spelling = " ".join(type_spelling.split())
707+
708+
while type_spelling.startswith("const "):
709+
type_spelling = type_spelling[len("const ") :]
710+
711+
return type_spelling
712+
713+
714+
def _generate_operator_call_instantiation_entries(operator):
    """Build explicit-instantiation statements for each call of ``operator``.

    For every entry in ``operator.calls`` the arguments returned by
    ``get_arguments()`` are read, skipping any argument spelled ``stream``.
    Each remaining argument contributes its const-stripped type to the
    ``Call<...>`` template argument list and a ``const T& name`` parameter
    that follows the fixed ``handle``/``config`` parameters.

    Args:
        operator: Object exposing ``name`` (snake_case operator name) and
            ``calls`` (iterable of nodes providing ``get_arguments()``).

    Returns:
        A ``(declarations, definitions)`` tuple of string lists: the
        ``extern template`` declarations and the matching ``template``
        instantiation definitions, in call order.
    """
    leading_params = "const Handle& handle, const Config& config"
    op_class_name = _snake_to_pascal(operator.name)
    extern_statements = []
    instantiation_statements = []

    for call in operator.calls:
        # (type, name) pairs for every argument except the stream.
        typed_args = [
            (_strip_top_level_const(arg.type.spelling), arg.spelling)
            for arg in call.get_arguments()
            if arg.spelling != "stream"
        ]

        template_arguments = ", ".join(arg_type for arg_type, _ in typed_args)
        parameter_list = ", ".join(
            [leading_params]
            + [f"const {arg_type}& {arg_name}" for arg_type, arg_name in typed_args]
        )
        signature = (
            f"Operator<{op_class_name}>::Call<{template_arguments}>"
            f"({parameter_list})"
        )

        extern_statements.append(f"extern template auto {signature};")
        instantiation_statements.append(f"template auto {signature};")

    return extern_statements, instantiation_statements
754+
755+
756+
def _generate_operator_call_instantiation_header(op_names, declarations):
757+
header_base_includes = "\n".join(
758+
f'#include "base/{op_name}.h"' for op_name in op_names
759+
)
760+
761+
return f"""#ifndef INFINI_OPS_OPERATOR_CALL_INSTANTIATIONS_H_
762+
#define INFINI_OPS_OPERATOR_CALL_INSTANTIATIONS_H_
763+
764+
#include <cstdint>
765+
#include <optional>
766+
#include <vector>
767+
768+
#include "config.h"
769+
#include "handle.h"
770+
#include "operator.h"
771+
772+
{header_base_includes}
773+
774+
namespace infini::ops {{
775+
776+
{chr(10).join(declarations)}
777+
778+
}} // namespace infini::ops
779+
780+
#endif
781+
"""
782+
783+
784+
def _generate_operator_call_instantiation_source(devices, impl_paths, definitions):
    """Render one generated translation unit of call instantiations.

    The source first includes the generated instantiation header, then (with
    clang-format disabled) the device marker headers for ``devices`` followed
    by the implementation headers, and finally lists ``definitions`` inside
    ``namespace infini::ops``.

    Args:
        devices: Passed to ``_device_marker_headers`` to obtain the device
            marker header paths.
        impl_paths: Implementation paths; each is converted with
            ``_to_include_path`` before being included.
        definitions: Iterable of definition strings, emitted one per line.

    Returns:
        The complete source file contents as a single string ending in a
        newline.
    """
    device_include_lines = [
        f'#include "{marker_header}"'
        for marker_header in _device_marker_headers(devices)
    ]
    impl_include_lines = [
        f'#include "{_to_include_path(impl_path)}"' for impl_path in impl_paths
    ]

    sections = [
        '#include "infini/operator_call_instantiations.h"',
        "",
        "// clang-format off",
        "\n".join(device_include_lines),
        "\n".join(impl_include_lines),
        "// clang-format on",
        "",
        "namespace infini::ops {",
        "",
        "\n".join(definitions),
        "",
        "} // namespace infini::ops",
        # Trailing empty entry yields the final newline.
        "",
    ]

    return "\n".join(sections)
805+
806+
705807
def _device_marker_headers(devices):
706808
paths = {
707809
"cpu": "native/cpu/device_.h",
708810
"nvidia": "native/cuda/nvidia/device_.h",
811+
"hygon": "native/cuda/hygon/device_.h",
709812
"cambricon": "native/cambricon/device_.h",
710813
"ascend": "native/ascend/device_.h",
711814
"metax": "native/cuda/metax/device_.h",
@@ -821,6 +924,10 @@ def _generate_op_artifacts(item):
821924
dispatch_declarations, dispatch_definitions = _generate_generated_dispatch_entries(
822925
operator
823926
)
927+
(
928+
call_instantiation_declarations,
929+
call_instantiation_definitions,
930+
) = _generate_operator_call_instantiation_entries(operator)
824931

825932
return {
826933
"op_name": op_name,
@@ -832,6 +939,8 @@ def _generate_op_artifacts(item):
832939
"legacy_c_header": legacy_c_header,
833940
"dispatch_declarations": dispatch_declarations,
834941
"dispatch_definitions": dispatch_definitions,
942+
"call_instantiation_declarations": call_instantiation_declarations,
943+
"call_instantiation_definitions": call_instantiation_definitions,
835944
"impl_paths": impl_paths,
836945
}
837946

@@ -929,6 +1038,11 @@ def _dispatch_gen_batch_size():
9291038
for artifact in artifacts
9301039
for declaration in artifact["dispatch_declarations"]
9311040
]
1041+
call_instantiation_declarations = [
1042+
declaration
1043+
for artifact in artifacts
1044+
for declaration in artifact["call_instantiation_declarations"]
1045+
]
9321046
use_monolithic_bindings = _use_monolithic_bindings()
9331047
op_includes = []
9341048

@@ -958,6 +1072,14 @@ def _dispatch_gen_batch_size():
9581072
)
9591073
(_BINDINGS_DIR / "generated_dispatch.h").write_text(dispatch_header)
9601074

1075+
call_instantiation_header = _generate_operator_call_instantiation_header(
1076+
op_names, call_instantiation_declarations
1077+
)
1078+
(_INCLUDE_DIR / "infini").mkdir(exist_ok=True)
1079+
(_INCLUDE_DIR / "infini" / "operator_call_instantiations.h").write_text(
1080+
call_instantiation_header
1081+
)
1082+
9611083
dispatch_batch_size = _dispatch_gen_batch_size()
9621084

9631085
for dispatch_batch_index, start in enumerate(
@@ -979,6 +1101,28 @@ def _dispatch_gen_batch_size():
9791101
dispatch_source
9801102
)
9811103

1104+
for call_instantiation_batch_index, start in enumerate(
1105+
range(0, len(artifacts), dispatch_batch_size)
1106+
):
1107+
batch = artifacts[start : start + dispatch_batch_size]
1108+
impl_paths = list(
1109+
dict.fromkeys(
1110+
impl_path for artifact in batch for impl_path in artifact["impl_paths"]
1111+
)
1112+
)
1113+
definitions = [
1114+
definition
1115+
for artifact in batch
1116+
for definition in artifact["call_instantiation_definitions"]
1117+
]
1118+
call_instantiation_source = _generate_operator_call_instantiation_source(
1119+
args.devices, impl_paths, definitions
1120+
)
1121+
(
1122+
_GENERATED_SRC_DIR
1123+
/ f"operator_call_instantiations_{call_instantiation_batch_index}.cc"
1124+
).write_text(call_instantiation_source)
1125+
9821126
bind_func_calls = "\n".join(
9831127
f"{bind_func_name}(m);" for bind_func_name in bind_func_names
9841128
)

src/CMakeLists.txt

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
add_library(infiniops SHARED)
22

3+
include(GNUInstallDirs)
4+
35
file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc")
46
target_sources(infiniops PRIVATE ${BASE_SRCS})
57

@@ -500,14 +502,20 @@ if(WITH_TORCH)
500502
endif()
501503
endif()
502504

503-
target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
505+
# Public include roots for `infiniops`. Inside the build tree, consumers see
# this source directory, the project's `include/` directory and the generated
# `generated/include/` directory; after installation they see the install
# include directory instead.
target_include_directories(infiniops PUBLIC
    "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
    "$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>"
    "$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/generated/include>"
    "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>"
)
504512

505-
if(GENERATE_PYTHON_BINDINGS)
513+
if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS)
506514
find_package(Python COMPONENTS Interpreter REQUIRED)
507-
# Always regenerate bindings so the included kernel headers match the
508-
# active device list. Stale generated files (e.g., committed for one
509-
# platform) would omit specializations for other enabled backends,
510-
# causing link-time or runtime failures.
515+
# Always regenerate wrappers so emitted call instantiations and bindings
516+
# match the active device list. Stale generated files would omit
517+
# specializations for enabled backends, causing link-time or runtime
518+
# failures.
511519

512520
set(GENERATOR_ARGS --devices ${DEVICE_LIST})
513521
if(WITH_TORCH)
@@ -528,7 +536,76 @@ if(GENERATE_PYTHON_BINDINGS)
528536
else()
529537
message(STATUS "Generating wrappers - done")
530538
endif()
539+
endif()
540+
541+
# Add the generated `operator_call_instantiations_*.cc` translation units to
# `infiniops`. How they are compiled depends on the enabled backend:
#   * WITH_NVIDIA / WITH_HYGON: marked as CUDA sources and added directly.
#   * WITH_ILUVATAR: each file is compiled by a custom command that invokes
#     ${ILUVATAR_CUDA_COMPILER}; the resulting objects are added as external
#     objects.
#   * otherwise: added as ordinary sources.
if(GENERATE_OPERATOR_CALL_INSTANTIATIONS)
    file(GLOB_RECURSE OPERATOR_CALL_INSTANTIATION_SOURCES CONFIGURE_DEPENDS
        "${PROJECT_SOURCE_DIR}/generated/src/operator_call_instantiations_*.cc")

    if(WITH_NVIDIA OR WITH_HYGON)
        set_source_files_properties(${OPERATOR_CALL_INSTANTIATION_SOURCES}
            PROPERTIES LANGUAGE CUDA)
        target_sources(infiniops PRIVATE ${OPERATOR_CALL_INSTANTIATION_SOURCES})
    elseif(WITH_ILUVATAR)
        # Preprocessor definitions passed on the hand-written compile line.
        set(_call_inst_defines -DWITH_ILUVATAR=1)
        if(WITH_CPU)
            list(APPEND _call_inst_defines -DWITH_CPU=1)
        endif()
        if(WITH_TORCH)
            list(APPEND _call_inst_defines -DWITH_TORCH=1)
        endif()
        if(DEFINED TORCH_CXX11_ABI)
            list(APPEND _call_inst_defines
                "-D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
        endif()

        # Include search path: project directories first, then the Torch and
        # CUDA toolkit include directories.
        set(_call_inst_includes
            "-I${CMAKE_CURRENT_SOURCE_DIR}"
            "-I${PROJECT_SOURCE_DIR}"
            "-I${PROJECT_SOURCE_DIR}/generated"
            "-I${PROJECT_SOURCE_DIR}/generated/include")
        foreach(_inc_dir IN LISTS TORCH_INCLUDE_DIRS CUDAToolkit_INCLUDE_DIRS)
            list(APPEND _call_inst_includes "-I${_inc_dir}")
        endforeach()

        set(_call_inst_obj_dir
            "${CMAKE_CURRENT_BINARY_DIR}/iluvatar_call_instantiation_objs")
        set(ILUVATAR_CALL_INSTANTIATION_OBJECTS)

        # One custom compile rule per generated source file.
        foreach(_inst_src IN LISTS OPERATOR_CALL_INSTANTIATION_SOURCES)
            get_filename_component(_inst_stem "${_inst_src}" NAME_WE)
            set(_inst_obj "${_call_inst_obj_dir}/${_inst_stem}.o")
            set(_inst_dep "${_inst_obj}.d")

            # The compiler-written dependency file is only handed to the
            # build system when the generator name matches "Ninja".
            set(_inst_depfile)
            if(CMAKE_GENERATOR MATCHES "Ninja")
                set(_inst_depfile DEPFILE "${_inst_dep}")
            endif()

            add_custom_command(
                OUTPUT "${_inst_obj}"
                COMMAND ${CMAKE_COMMAND} -E make_directory
                    "${_call_inst_obj_dir}"
                COMMAND ${ILUVATAR_CUDA_COMPILER}
                    ${_call_inst_defines}
                    ${_call_inst_includes}
                    ${ILUVATAR_CUDA_FLAGS}
                    -MMD -MF "${_inst_dep}"
                    -c "${_inst_src}" -o "${_inst_obj}"
                DEPENDS "${_inst_src}"
                ${_inst_depfile}
                COMMENT "Compiling ${_inst_stem}.cc with CoreX clang++"
                VERBATIM
            )
            list(APPEND ILUVATAR_CALL_INSTANTIATION_OBJECTS "${_inst_obj}")
        endforeach()

        # The objects are produced by the custom commands above, so they are
        # flagged as generated external objects before being added.
        set_source_files_properties(${ILUVATAR_CALL_INSTANTIATION_OBJECTS}
            PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
        target_sources(infiniops PRIVATE ${ILUVATAR_CALL_INSTANTIATION_OBJECTS})
    else()
        target_sources(infiniops PRIVATE ${OPERATOR_CALL_INSTANTIATION_SOURCES})
    endif()
endif()
607+
608+
if(GENERATE_PYTHON_BINDINGS)
532609
file(GLOB_RECURSE PYBIND11_SOURCES CONFIGURE_DEPENDS
533610
"${PROJECT_SOURCE_DIR}/generated/bindings/*.cc")
534611

@@ -727,3 +804,44 @@ if(GENERATE_PYTHON_BINDINGS)
727804
DESTINATION .)
728805
endif()
729806
endif()
807+
808+
# Install rules for the `infiniops` library and its headers.

# Library artifacts go to the GNUInstallDirs library/binary directories.
install(TARGETS infiniops
    LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
    ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
    RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
)

# Contents of the project's `include/` directory.
install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/"
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)

# Generated call-instantiation header, installed only when the option is on.
if(GENERATE_OPERATOR_CALL_INSTANTIATIONS)
    install(
        FILES "${PROJECT_SOURCE_DIR}/generated/include/infini/operator_call_instantiations.h"
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/infini
    )
endif()

# Headers that sit directly in this source directory.
file(GLOB INFINIOPS_PUBLIC_CORE_HEADERS CONFIGURE_DEPENDS
    "${CMAKE_CURRENT_SOURCE_DIR}/*.h")

install(FILES ${INFINIOPS_PUBLIC_CORE_HEADERS}
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)

# Headers under `base/` and `common/`, keeping their directory names.
install(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/base/"
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/base
    FILES_MATCHING PATTERN "*.h"
)

install(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/common/"
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/common
    FILES_MATCHING PATTERN "*.h"
)

# Headers under `generated/base/`, if that directory exists at configure
# time; they are installed into the same `base/` destination.
if(EXISTS "${PROJECT_SOURCE_DIR}/generated/base")
    install(DIRECTORY "${PROJECT_SOURCE_DIR}/generated/base/"
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/base
        FILES_MATCHING PATTERN "*.h"
    )
endif()

0 commit comments

Comments
 (0)