Skip to content

Commit b1c554b

Browse files
authored
Wrap scalar int as NineToothedTensor in launcher (#152)
1 parent 7355513 commit b1c554b

1 file changed

Lines changed: 29 additions & 4 deletions

File tree

src/ninetoothed/aot.py

Lines changed: 29 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -696,10 +696,30 @@ def _load_launch_func(kernel_name, output_dir):
696696
launch_func_name = f"launch_{kernel_name}"
697697
launch_func = getattr(library, launch_func_name)
698698

699+
def _num_tensor_args(kernel_name, output_dir):
    """Count the ``NineToothedTensor`` parameters of a kernel's launch function.

    Reads ``<output_dir>/<kernel_name>.h``, locates the prototype
    ``NineToothedResult launch_<kernel_name>(...);`` and counts the
    comma-separated parameters whose declaration begins with the
    ``NineToothedTensor`` type name followed by whitespace.

    Args:
        kernel_name: Name of the kernel; the header file name and the
            ``launch_`` function name are both derived from it.
        output_dir: Directory that contains the kernel's header file.

    Returns:
        The number of parameters declared as ``NineToothedTensor``.

    Raises:
        RuntimeError: If no matching launch prototype is found in the header.
    """
    header_path = pathlib.Path(output_dir) / f"{kernel_name}.h"
    text = header_path.read_text()
    # The parameter list is matched with a character class instead of `.` so
    # that it may wrap across several lines, while the exclusion of `;`, `{`
    # and `}` keeps the match from running past the end of the declaration or
    # into a function body.
    pattern = (
        rf"NineToothedResult\s+launch_{re.escape(kernel_name)}"
        r"\s*\(([^;{}]*?)\)\s*;"
    )
    match = re.search(pattern, text)

    if match is None:
        raise RuntimeError(
            f"Could not find launch signature for `{kernel_name}` in `{header_path}`."
        )

    params = (param.strip() for param in match.group(1).split(","))

    # Accept any whitespace (space, tab, newline) after the type name, not
    # only a single literal space.
    return sum(1 for param in params if re.match(r"NineToothedTensor\s", param))
715+
716+
num_tensor_args = _num_tensor_args(kernel_name, output_dir)
717+
699718
dtype_to_index = _DTYPE_TO_INDEX
700719
from_torch_tensor = _ArgumentTensor.from_torch_tensor
701720
from_scalar = _ArgumentTensor.from_scalar
702721
c_double = ctypes.c_double
722+
c_int64 = ctypes.c_int64
703723
c_void_p = ctypes.c_void_p
704724
current_device = torch.cuda.current_device
705725
get_current_raw_stream = torch._C._cuda_getCurrentRawStream
@@ -709,12 +729,17 @@ def _run_launch_func(*args):
709729
arguments = [None] * len(args)
710730

711731
for i, arg in enumerate(args):
712-
if isinstance(arg, Tensor_cls):
713-
arguments[i] = from_torch_tensor(arg)
732+
if i < num_tensor_args:
733+
if isinstance(arg, Tensor_cls):
734+
arguments[i] = from_torch_tensor(arg)
735+
elif type(arg) is float:
736+
arguments[i] = from_scalar(arg, c_double)
737+
elif type(arg) is int:
738+
arguments[i] = from_scalar(arg, c_int64)
739+
else:
740+
arguments[i] = arg
714741
elif type(arg) is str:
715742
arguments[i] = dtype_to_index[arg]
716-
elif type(arg) is float:
717-
arguments[i] = from_scalar(arg, c_double)
718743
else:
719744
arguments[i] = arg
720745

0 commit comments

Comments
 (0)