@@ -702,10 +702,113 @@ def _generate_generated_dispatch_source(impl_paths, definitions):
702702"""
703703
704704
705+ def _strip_top_level_const (type_spelling ):
706+ type_spelling = " " .join (type_spelling .split ())
707+
708+ while type_spelling .startswith ("const " ):
709+ type_spelling = type_spelling [len ("const " ) :]
710+
711+ return type_spelling
712+
713+
714+ def _generate_operator_call_instantiation_entries (operator ):
715+ def _generate_template_arguments (node ):
716+ return ", " .join (
717+ _strip_top_level_const (arg .type .spelling )
718+ for arg in node .get_arguments ()
719+ if arg .spelling != "stream"
720+ )
721+
722+ def _generate_parameters (node ):
723+ return ", " .join (
724+ f"const { _strip_top_level_const (arg .type .spelling )} & { arg .spelling } "
725+ for arg in node .get_arguments ()
726+ if arg .spelling != "stream"
727+ )
728+
729+ def _append_optional_params (prefix , params ):
730+ if params :
731+ return f"{ prefix } , { params } "
732+
733+ return prefix
734+
735+ pascal_case_op_name = _snake_to_pascal (operator .name )
736+ declarations = []
737+ definitions = []
738+
739+ for call in operator .calls :
740+ template_arguments = _generate_template_arguments (call )
741+ params = _generate_parameters (call )
742+ function_params = _append_optional_params (
743+ "const Handle& handle, const Config& config" , params
744+ )
745+ instantiation = (
746+ f"Operator<{ pascal_case_op_name } >::Call<{ template_arguments } >"
747+ f"({ function_params } )"
748+ )
749+
750+ declarations .append (f"extern template auto { instantiation } ;" )
751+ definitions .append (f"template auto { instantiation } ;" )
752+
753+ return declarations , definitions
754+
755+
756+ def _generate_operator_call_instantiation_header (op_names , declarations ):
757+ header_base_includes = "\n " .join (
758+ f'#include "base/{ op_name } .h"' for op_name in op_names
759+ )
760+
761+ return f"""#ifndef INFINI_OPS_OPERATOR_CALL_INSTANTIATIONS_H_
762+ #define INFINI_OPS_OPERATOR_CALL_INSTANTIATIONS_H_
763+
764+ #include <cstdint>
765+ #include <optional>
766+ #include <vector>
767+
768+ #include "config.h"
769+ #include "handle.h"
770+ #include "operator.h"
771+
772+ { header_base_includes }
773+
774+ namespace infini::ops {{
775+
776+ { chr (10 ).join (declarations )}
777+
778+ }} // namespace infini::ops
779+
780+ #endif
781+ """
782+
783+
784+ def _generate_operator_call_instantiation_source (devices , impl_paths , definitions ):
785+ device_includes = "\n " .join (
786+ f'#include "{ path } "' for path in _device_marker_headers (devices )
787+ )
788+ impl_includes = "\n " .join (
789+ f'#include "{ _to_include_path (impl_path )} "' for impl_path in impl_paths
790+ )
791+
792+ return f"""#include "infini/operator_call_instantiations.h"
793+
794+ // clang-format off
795+ { device_includes }
796+ { impl_includes }
797+ // clang-format on
798+
799+ namespace infini::ops {{
800+
801+ { chr (10 ).join (definitions )}
802+
803+ }} // namespace infini::ops
804+ """
805+
806+
705807def _device_marker_headers (devices ):
706808 paths = {
707809 "cpu" : "native/cpu/device_.h" ,
708810 "nvidia" : "native/cuda/nvidia/device_.h" ,
811+ "hygon" : "native/cuda/hygon/device_.h" ,
709812 "cambricon" : "native/cambricon/device_.h" ,
710813 "ascend" : "native/ascend/device_.h" ,
711814 "metax" : "native/cuda/metax/device_.h" ,
@@ -819,6 +922,10 @@ def _generate_op_artifacts(item):
819922 dispatch_declarations , dispatch_definitions = _generate_generated_dispatch_entries (
820923 operator
821924 )
925+ (
926+ call_instantiation_declarations ,
927+ call_instantiation_definitions ,
928+ ) = _generate_operator_call_instantiation_entries (operator )
822929
823930 return {
824931 "op_name" : op_name ,
@@ -830,6 +937,8 @@ def _generate_op_artifacts(item):
830937 "legacy_c_header" : legacy_c_header ,
831938 "dispatch_declarations" : dispatch_declarations ,
832939 "dispatch_definitions" : dispatch_definitions ,
940+ "call_instantiation_declarations" : call_instantiation_declarations ,
941+ "call_instantiation_definitions" : call_instantiation_definitions ,
833942 "impl_paths" : impl_paths ,
834943 }
835944
@@ -918,6 +1027,11 @@ def _dispatch_gen_batch_size():
9181027 for artifact in artifacts
9191028 for declaration in artifact ["dispatch_declarations" ]
9201029 ]
1030+ call_instantiation_declarations = [
1031+ declaration
1032+ for artifact in artifacts
1033+ for declaration in artifact ["call_instantiation_declarations" ]
1034+ ]
9211035 use_monolithic_bindings = _use_monolithic_bindings ()
9221036 op_includes = []
9231037
@@ -947,6 +1061,14 @@ def _dispatch_gen_batch_size():
9471061 )
9481062 (_BINDINGS_DIR / "generated_dispatch.h" ).write_text (dispatch_header )
9491063
1064+ call_instantiation_header = _generate_operator_call_instantiation_header (
1065+ op_names , call_instantiation_declarations
1066+ )
1067+ (_INCLUDE_DIR / "infini" ).mkdir (exist_ok = True )
1068+ (_INCLUDE_DIR / "infini" / "operator_call_instantiations.h" ).write_text (
1069+ call_instantiation_header
1070+ )
1071+
9501072 dispatch_batch_size = _dispatch_gen_batch_size ()
9511073
9521074 for dispatch_batch_index , start in enumerate (
@@ -968,6 +1090,28 @@ def _dispatch_gen_batch_size():
9681090 dispatch_source
9691091 )
9701092
1093+ for call_instantiation_batch_index , start in enumerate (
1094+ range (0 , len (artifacts ), dispatch_batch_size )
1095+ ):
1096+ batch = artifacts [start : start + dispatch_batch_size ]
1097+ impl_paths = list (
1098+ dict .fromkeys (
1099+ impl_path for artifact in batch for impl_path in artifact ["impl_paths" ]
1100+ )
1101+ )
1102+ definitions = [
1103+ definition
1104+ for artifact in batch
1105+ for definition in artifact ["call_instantiation_definitions" ]
1106+ ]
1107+ call_instantiation_source = _generate_operator_call_instantiation_source (
1108+ args .devices , impl_paths , definitions
1109+ )
1110+ (
1111+ _GENERATED_SRC_DIR
1112+ / f"operator_call_instantiations_{ call_instantiation_batch_index } .cc"
1113+ ).write_text (call_instantiation_source )
1114+
9711115 bind_func_calls = "\n " .join (
9721116 f"{ bind_func_name } (m);" for bind_func_name in bind_func_names
9731117 )
0 commit comments