@@ -696,10 +696,30 @@ def _load_launch_func(kernel_name, output_dir):
696696 launch_func_name = f"launch_{ kernel_name } "
697697 launch_func = getattr (library , launch_func_name )
698698
def _num_tensor_args(kernel_name, output_dir):
    """Count the tensor parameters of a generated launch function.

    Reads ``<output_dir>/<kernel_name>.h``, locates the declaration
    ``NineToothedResult launch_<kernel_name>(...);`` and returns how many of
    its comma-separated parameters have ``NineToothedTensor`` as their first
    token.

    Args:
        kernel_name: Name of the kernel; also the stem of the header file.
        output_dir: Directory that contains the generated header.

    Returns:
        The number of ``NineToothedTensor`` parameters in the declaration.

    Raises:
        RuntimeError: If the launch declaration is not found in the header.
        OSError: If the header file cannot be read.
    """
    header_path = pathlib.Path(output_dir) / f"{kernel_name}.h"
    text = header_path.read_text()
    pattern = (
        rf"NineToothedResult\s+launch_{re.escape(kernel_name)}\s*\((.*?)\)\s*;"
    )
    # DOTALL lets `.` cross newlines so a declaration whose parameter list is
    # wrapped over several lines is still recognized.
    match = re.search(pattern, text, re.DOTALL)

    if match is None:
        raise RuntimeError(
            f"Could not find launch signature for `{kernel_name}` in `{header_path}`."
        )

    # Compare the first whitespace-delimited token instead of a
    # "NineToothedTensor " prefix: this accepts any whitespace after the type
    # (tab, newline) as well as unnamed prototype parameters, and an empty
    # parameter list simply yields no tokens.
    token_lists = (param.split() for param in match.group(1).split(","))

    return sum(1 for tokens in token_lists if tokens[:1] == ["NineToothedTensor"])
715+
716+ num_tensor_args = _num_tensor_args (kernel_name , output_dir )
717+
699718 dtype_to_index = _DTYPE_TO_INDEX
700719 from_torch_tensor = _ArgumentTensor .from_torch_tensor
701720 from_scalar = _ArgumentTensor .from_scalar
702721 c_double = ctypes .c_double
722+ c_int64 = ctypes .c_int64
703723 c_void_p = ctypes .c_void_p
704724 current_device = torch .cuda .current_device
705725 get_current_raw_stream = torch ._C ._cuda_getCurrentRawStream
@@ -709,12 +729,17 @@ def _run_launch_func(*args):
709729 arguments = [None ] * len (args )
710730
711731 for i , arg in enumerate (args ):
712- if isinstance (arg , Tensor_cls ):
713- arguments [i ] = from_torch_tensor (arg )
732+ if i < num_tensor_args :
733+ if isinstance (arg , Tensor_cls ):
734+ arguments [i ] = from_torch_tensor (arg )
735+ elif type (arg ) is float :
736+ arguments [i ] = from_scalar (arg , c_double )
737+ elif type (arg ) is int :
738+ arguments [i ] = from_scalar (arg , c_int64 )
739+ else :
740+ arguments [i ] = arg
714741 elif type (arg ) is str :
715742 arguments [i ] = dtype_to_index [arg ]
716- elif type (arg ) is float :
717- arguments [i ] = from_scalar (arg , c_double )
718743 else :
719744 arguments [i ] = arg
720745
0 commit comments