@@ -536,12 +536,42 @@ def _generate_tensor_caster(name, is_data=False):
536536
537537
538538def _generate_generated_dispatch_entries (operator ):
539+ optional_tensor_params = _find_optional_tensor_params (operator .name )
540+ vector_tensor_params = _find_vector_tensor_params (operator .name )
541+ vector_int64_params = _find_vector_int64_params (operator .name )
542+
543+ def _is_optional_tensor (arg ):
544+ if arg .spelling in optional_tensor_params :
545+ return True
546+
547+ return "std::optional" in arg .type .spelling and "Tensor" in arg .type .spelling
548+
549+ def _is_vector_tensor (arg ):
550+ if arg .spelling in vector_tensor_params :
551+ return True
552+
553+ return "std::vector" in arg .type .spelling and "Tensor" in arg .type .spelling
554+
555+ def _is_vector_int64 (arg ):
556+ return arg .spelling in vector_int64_params
557+
539558 def _generate_params (node ):
540- return ", " .join (
541- f"{ arg .type .spelling } { arg .spelling } "
542- for arg in node .get_arguments ()
543- if arg .spelling != "stream"
544- )
559+ parts = []
560+
561+ for arg in node .get_arguments ():
562+ if arg .spelling == "stream" :
563+ continue
564+
565+ if _is_optional_tensor (arg ):
566+ parts .append (f"std::optional<Tensor> { arg .spelling } " )
567+ elif _is_vector_tensor (arg ):
568+ parts .append (f"std::vector<Tensor> { arg .spelling } " )
569+ elif _is_vector_int64 (arg ):
570+ parts .append (f"std::vector<int64_t> { arg .spelling } " )
571+ else :
572+ parts .append (f"{ arg .type .spelling } { arg .spelling } " )
573+
574+ return ", " .join (parts )
545575
546576 def _generate_arguments (node ):
547577 return ", " .join (
0 commit comments