@@ -702,10 +702,113 @@ def _generate_generated_dispatch_source(impl_paths, definitions):
702702"""
703703
704704
705+ def _strip_top_level_const (type_spelling ):
706+ type_spelling = " " .join (type_spelling .split ())
707+
708+ while type_spelling .startswith ("const " ):
709+ type_spelling = type_spelling [len ("const " ) :]
710+
711+ return type_spelling
712+
713+
def _generate_operator_call_instantiation_entries(operator):
    """Build explicit-instantiation lines for an operator's ``Call`` overloads.

    For every node in ``operator.calls`` one C++ line of the form
    ``Operator<Op>::Call<Args...>(const Handle& handle, const Config& config,
    const Arg& name, ...)`` is produced, once prefixed with
    ``extern template auto`` (declaration) and once with ``template auto``
    (definition).

    Arguments whose spelling is ``stream`` are left out of both the template
    argument list and the parameter list.

    Calls that reduce to the same template argument list are emitted only
    once, because C++ rejects a repeated explicit instantiation of the same
    specialization.

    Args:
        operator: Object exposing ``name`` and ``calls``. Each call must
            provide ``get_arguments()`` yielding items with ``spelling`` and
            ``type.spelling`` (presumably libclang cursors; confirm against
            the caller).

    Returns:
        A ``(declarations, definitions)`` tuple of two equally long lists of
        C++ source lines, in first-seen call order.
    """

    def _instantiated_arguments(node):
        # Collect (type, name) pairs once so the template arguments and the
        # parameter list can never disagree about which arguments are kept.
        return [
            (_strip_top_level_const(arg.type.spelling), arg.spelling)
            for arg in node.get_arguments()
            if arg.spelling != "stream"
        ]

    pascal_case_op_name = _snake_to_pascal(operator.name)
    declarations = []
    definitions = []
    seen_template_arguments = set()

    for call in operator.calls:
        arguments = _instantiated_arguments(call)
        template_arguments = ", ".join(arg_type for arg_type, _ in arguments)

        # The template arguments fully determine the instantiated
        # specialization; parameter names do not.
        if template_arguments in seen_template_arguments:
            continue

        seen_template_arguments.add(template_arguments)

        function_params = ", ".join(
            ["const Handle& handle", "const Config& config"]
            + [f"const {arg_type}& {arg_name}" for arg_type, arg_name in arguments]
        )
        instantiation = (
            f"Operator<{pascal_case_op_name}>::Call<{template_arguments}>"
            f"({function_params})"
        )

        declarations.append(f"extern template auto {instantiation};")
        definitions.append(f"template auto {instantiation};")

    return declarations, definitions
754+
755+
756+ def _generate_operator_call_instantiation_header (op_names , declarations ):
757+ header_base_includes = "\n " .join (
758+ f'#include "base/{ op_name } .h"' for op_name in op_names
759+ )
760+
761+ return f"""#ifndef INFINI_OPS_OPERATOR_CALL_INSTANTIATIONS_H_
762+ #define INFINI_OPS_OPERATOR_CALL_INSTANTIATIONS_H_
763+
764+ #include <cstdint>
765+ #include <optional>
766+ #include <vector>
767+
768+ #include "config.h"
769+ #include "handle.h"
770+ #include "operator.h"
771+
772+ { header_base_includes }
773+
774+ namespace infini::ops {{
775+
776+ { chr (10 ).join (declarations )}
777+
778+ }} // namespace infini::ops
779+
780+ #endif
781+ """
782+
783+
def _generate_operator_call_instantiation_source(devices, impl_paths, definitions):
    """Render one source file that defines operator ``Call`` instantiations.

    The file includes ``infini/operator_call_instantiations.h``, then the
    headers returned by ``_device_marker_headers(devices)``, then one include
    per entry of ``impl_paths`` (mapped through ``_to_include_path``), and
    finally the given definition lines inside ``namespace infini::ops``.

    Args:
        devices: Passed unchanged to ``_device_marker_headers``.
        impl_paths: Iterable of implementation paths; each is converted with
            ``_to_include_path`` and emitted as an ``#include`` line.
        definitions: Iterable of ready-made C++ definition lines, written one
            per line in the given order.

    Returns:
        The complete source text as a string.
    """
    # Include paths are emitted without padding: whitespace inside the quotes
    # would become part of the file name the preprocessor looks up.
    device_includes = "\n".join(
        f'#include "{path}"' for path in _device_marker_headers(devices)
    )
    impl_includes = "\n".join(
        f'#include "{_to_include_path(impl_path)}"' for impl_path in impl_paths
    )
    definition_lines = "\n".join(definitions)

    # NOTE(review): the clang-format guards presumably keep the device headers
    # ahead of the implementation headers; confirm that the order matters.
    return f"""#include "infini/operator_call_instantiations.h"

// clang-format off
{device_includes}
{impl_includes}
// clang-format on

namespace infini::ops {{

{definition_lines}

}} // namespace infini::ops
"""
805+
806+
705807def _device_marker_headers (devices ):
706808 paths = {
707809 "cpu" : "native/cpu/device_.h" ,
708810 "nvidia" : "native/cuda/nvidia/device_.h" ,
811+ "hygon" : "native/cuda/hygon/device_.h" ,
709812 "cambricon" : "native/cambricon/device_.h" ,
710813 "ascend" : "native/ascend/device_.h" ,
711814 "metax" : "native/cuda/metax/device_.h" ,
@@ -821,6 +924,10 @@ def _generate_op_artifacts(item):
821924 dispatch_declarations , dispatch_definitions = _generate_generated_dispatch_entries (
822925 operator
823926 )
927+ (
928+ call_instantiation_declarations ,
929+ call_instantiation_definitions ,
930+ ) = _generate_operator_call_instantiation_entries (operator )
824931
825932 return {
826933 "op_name" : op_name ,
@@ -832,6 +939,8 @@ def _generate_op_artifacts(item):
832939 "legacy_c_header" : legacy_c_header ,
833940 "dispatch_declarations" : dispatch_declarations ,
834941 "dispatch_definitions" : dispatch_definitions ,
942+ "call_instantiation_declarations" : call_instantiation_declarations ,
943+ "call_instantiation_definitions" : call_instantiation_definitions ,
835944 "impl_paths" : impl_paths ,
836945 }
837946
@@ -929,6 +1038,11 @@ def _dispatch_gen_batch_size():
9291038 for artifact in artifacts
9301039 for declaration in artifact ["dispatch_declarations" ]
9311040 ]
1041+ call_instantiation_declarations = [
1042+ declaration
1043+ for artifact in artifacts
1044+ for declaration in artifact ["call_instantiation_declarations" ]
1045+ ]
9321046 use_monolithic_bindings = _use_monolithic_bindings ()
9331047 op_includes = []
9341048
@@ -958,6 +1072,14 @@ def _dispatch_gen_batch_size():
9581072 )
9591073 (_BINDINGS_DIR / "generated_dispatch.h" ).write_text (dispatch_header )
9601074
1075+ call_instantiation_header = _generate_operator_call_instantiation_header (
1076+ op_names , call_instantiation_declarations
1077+ )
1078+ (_INCLUDE_DIR / "infini" ).mkdir (exist_ok = True )
1079+ (_INCLUDE_DIR / "infini" / "operator_call_instantiations.h" ).write_text (
1080+ call_instantiation_header
1081+ )
1082+
9611083 dispatch_batch_size = _dispatch_gen_batch_size ()
9621084
9631085 for dispatch_batch_index , start in enumerate (
@@ -979,6 +1101,28 @@ def _dispatch_gen_batch_size():
9791101 dispatch_source
9801102 )
9811103
1104+ for call_instantiation_batch_index , start in enumerate (
1105+ range (0 , len (artifacts ), dispatch_batch_size )
1106+ ):
1107+ batch = artifacts [start : start + dispatch_batch_size ]
1108+ impl_paths = list (
1109+ dict .fromkeys (
1110+ impl_path for artifact in batch for impl_path in artifact ["impl_paths" ]
1111+ )
1112+ )
1113+ definitions = [
1114+ definition
1115+ for artifact in batch
1116+ for definition in artifact ["call_instantiation_definitions" ]
1117+ ]
1118+ call_instantiation_source = _generate_operator_call_instantiation_source (
1119+ args .devices , impl_paths , definitions
1120+ )
1121+ (
1122+ _GENERATED_SRC_DIR
1123+ / f"operator_call_instantiations_{ call_instantiation_batch_index } .cc"
1124+ ).write_text (call_instantiation_source )
1125+
9821126 bind_func_calls = "\n " .join (
9831127 f"{ bind_func_name } (m);" for bind_func_name in bind_func_names
9841128 )
0 commit comments