Skip to content

Commit 1400daf

Browse files
authored
fix(scripts): apply regex fallback in dispatch entries generation (#620)
1 parent 76094ad commit 1400daf

1 file changed

Lines changed: 35 additions & 5 deletions

File tree

scripts/generate_wrappers.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -536,12 +536,42 @@ def _generate_tensor_caster(name, is_data=False):
536536

537537

538538
def _generate_generated_dispatch_entries(operator):
539+
optional_tensor_params = _find_optional_tensor_params(operator.name)
540+
vector_tensor_params = _find_vector_tensor_params(operator.name)
541+
vector_int64_params = _find_vector_int64_params(operator.name)
542+
543+
def _is_optional_tensor(arg):
544+
if arg.spelling in optional_tensor_params:
545+
return True
546+
547+
return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling
548+
549+
def _is_vector_tensor(arg):
550+
if arg.spelling in vector_tensor_params:
551+
return True
552+
553+
return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling
554+
555+
def _is_vector_int64(arg):
556+
return arg.spelling in vector_int64_params
557+
539558
def _generate_params(node):
540-
return ", ".join(
541-
f"{arg.type.spelling} {arg.spelling}"
542-
for arg in node.get_arguments()
543-
if arg.spelling != "stream"
544-
)
559+
parts = []
560+
561+
for arg in node.get_arguments():
562+
if arg.spelling == "stream":
563+
continue
564+
565+
if _is_optional_tensor(arg):
566+
parts.append(f"std::optional<Tensor> {arg.spelling}")
567+
elif _is_vector_tensor(arg):
568+
parts.append(f"std::vector<Tensor> {arg.spelling}")
569+
elif _is_vector_int64(arg):
570+
parts.append(f"std::vector<int64_t> {arg.spelling}")
571+
else:
572+
parts.append(f"{arg.type.spelling} {arg.spelling}")
573+
574+
return ", ".join(parts)
545575

546576
def _generate_arguments(node):
547577
return ", ".join(

0 commit comments

Comments
 (0)