@@ -81,10 +81,18 @@ def _find_tensor_by_source_name(tensors, name):
8181 variant_specs = _enumerate_variant_specs (
8282 launch_arg_names , tensors , _find_tensor_by_source_name
8383 )
84+ _ , tensor_ndims , _ = _per_tensor_dim_options (
85+ launch_arg_names , tensors , _find_tensor_by_source_name
86+ )
8487
8588 output_contents = {}
8689
87- for variant_suffix , divisibility_spec , contiguity_spec in variant_specs :
90+ for (
91+ variant_suffix ,
92+ divisibility_spec ,
93+ contiguity_spec ,
94+ index_dtype ,
95+ ) in variant_specs :
8896 variant_outputs = _build_variant (
8997 source_file ,
9098 kernel_func ,
@@ -99,11 +107,12 @@ def _find_tensor_by_source_name(tensors, name):
99107 num_stages = num_stages ,
100108 divisibility_spec = divisibility_spec ,
101109 contiguity_spec = contiguity_spec ,
110+ index_dtype = index_dtype ,
102111 )
103112 output_contents .update (variant_outputs )
104113
105114 dispatcher_source , dispatcher_header = _generate_dispatcher (
106- kernel_name , launch_arg_names , variant_specs
115+ kernel_name , launch_arg_names , variant_specs , tensor_ndims
107116 )
108117
109118 output_contents [f"{ kernel_name } .cpp" ] = dispatcher_source
@@ -112,7 +121,7 @@ def _find_tensor_by_source_name(tensors, name):
112121 return output_contents
113122
114123
115- def _generate_dispatcher (kernel_name , launch_arg_names , variant_specs ):
124+ def _generate_dispatcher (kernel_name , launch_arg_names , variant_specs , tensor_ndims ):
116125 tensor_params = ", " .join (f"NineToothedTensor { name } " for name in launch_arg_names )
117126 signature_params = (
118127 f"NineToothedStream stream, { tensor_params } "
@@ -138,14 +147,25 @@ def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
138147 externs = []
139148 branches = []
140149
141- for variant_suffix , divisibility_spec , contiguity_spec in variant_specs :
150+ fallback_call = None
151+
152+ for (
153+ variant_suffix ,
154+ divisibility_spec ,
155+ contiguity_spec ,
156+ index_dtype ,
157+ ) in variant_specs :
142158 variant_name = f"launch_{ kernel_name } _{ variant_suffix } "
143159 externs .append (
144160 f'extern "C" NineToothedResult { variant_name } ({ signature_params } );'
145161 )
146162
147163 call = f"return { variant_name } ({ call_args } );"
148164
165+ if index_dtype == ninetoothed .dtype .int64 :
166+ fallback_call = call
167+ continue
168+
149169 checks = tuple (
150170 f"{ name } .shape[{ dim } ] % 16 == 0" for name , dim in divisibility_spec
151171 ) + tuple (f"{ name } .strides[{ dim } ] == 1" for name , dim in contiguity_spec )
@@ -155,11 +175,26 @@ def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
155175 else :
156176 branches .append (f"{ _INDENTATION } { call } " )
157177
178+ prelude_lines = []
179+ if fallback_call is not None and launch_arg_names :
180+ overflow_terms = []
181+ for name , ndim in zip (launch_arg_names , tensor_ndims ):
182+ for d in range (ndim ):
183+ overflow_terms .append (f"{ name } .shape[{ d } ] > 2147483647ULL" )
184+ overflow_terms .append (f"{ name } .strides[{ d } ] > 2147483647LL" )
185+ overflow_terms .append (f"{ name } .strides[{ d } ] < -2147483648LL" )
186+ if overflow_terms :
187+ prelude_lines .append (
188+ f"{ _INDENTATION } if ({ ' || ' .join (overflow_terms )} ) { fallback_call } "
189+ )
190+
191+ body_lines = prelude_lines + branches
192+
158193 source = (
159194 f'#include "{ _HEADER_PATH } "\n \n '
160195 + "\n " .join (externs )
161196 + f'\n \n extern "C" { signature } {{\n '
162- + "\n " .join (branches )
197+ + "\n " .join (body_lines )
163198 + "\n }\n "
164199 )
165200
@@ -181,6 +216,7 @@ def _build_variant(
181216 num_stages ,
182217 divisibility_spec ,
183218 contiguity_spec ,
219+ index_dtype = ninetoothed .dtype .int32 ,
184220):
185221 divisibility_set = {
186222 (naming .remove_prefixes (name ), dim ) for name , dim in divisibility_spec
@@ -211,9 +247,9 @@ def _build_variant(
211247 bare_source_name = naming .remove_prefixes (source_name )
212248
213249 if (bare_source_name , dim_index ) in divisibility_set :
214- param_types .append (f"{ ninetoothed . dtype . int64 } :16" )
250+ param_types .append (f"{ index_dtype } :16" )
215251 else :
216- param_types .append (ninetoothed . dtype . int64 )
252+ param_types .append (index_dtype )
217253 elif match := Tensor .stride_pattern ().fullmatch (param ):
218254 source_name = match .group (1 )
219255 dim_index = int (match .group (3 ))
@@ -224,7 +260,7 @@ def _build_variant(
224260 constexpr_param_indices .append (len (param_types ) - 1 )
225261 constexpr_strides .append ((source_name , dim_index ))
226262 else :
227- param_types .append (f"{ ninetoothed . dtype . int64 } :16" )
263+ param_types .append (f"{ index_dtype } :16" )
228264 else :
229265 source_name = param
230266 tensor = find_tensor (tensors , source_name )
@@ -331,15 +367,21 @@ def _spec_from_combo(combo):
331367 for divisibility_spec in dim_specs :
332368 for contiguity_spec in dim_specs :
333369 suffix = _variant_suffix (
334- divisibility_spec , contiguity_spec , launch_arg_names , tensor_ndims
370+ divisibility_spec ,
371+ contiguity_spec ,
372+ launch_arg_names ,
373+ tensor_ndims ,
374+ index_dtype = ninetoothed .dtype .int32 ,
375+ )
376+ specs .append (
377+ (suffix , divisibility_spec , contiguity_spec , ninetoothed .dtype .int32 )
335378 )
336- specs .append ((suffix , divisibility_spec , contiguity_spec ))
337379
338380 def _num_innermost (spec ):
339381 return sum (1 for name , dim in spec if innermost_dims .get (name ) == dim )
340382
341383 def _specificity (entry ):
342- _ , divisibility_spec , contiguity_spec = entry
384+ _ , divisibility_spec , contiguity_spec , _ = entry
343385
344386 return (
345387 - len (divisibility_spec ),
@@ -350,6 +392,11 @@ def _specificity(entry):
350392
351393 specs .sort (key = _specificity )
352394
395+ fallback_suffix = _variant_suffix (
396+ (), (), launch_arg_names , tensor_ndims , index_dtype = ninetoothed .dtype .int64
397+ )
398+ specs .append ((fallback_suffix , (), (), ninetoothed .dtype .int64 ))
399+
353400 return specs
354401
355402
@@ -381,15 +428,21 @@ def _per_tensor_dim_options(launch_arg_names, tensors, find_tensor):
381428 return per_tensor_dims , tensor_ndims , innermost_dims
382429
383430
384- def _variant_suffix (divisibility_spec , contiguity_spec , launch_arg_names , tensor_ndims ):
431+ def _variant_suffix (
432+ divisibility_spec ,
433+ contiguity_spec ,
434+ launch_arg_names ,
435+ tensor_ndims ,
436+ index_dtype = ninetoothed .dtype .int32 ,
437+ ):
385438 divisibility_part = _divisibility_suffix (
386439 divisibility_spec , launch_arg_names , tensor_ndims
387440 )
388441 contiguity_part = _contiguity_suffix (
389442 contiguity_spec , launch_arg_names , tensor_ndims
390443 )
391444
392- return f"{ divisibility_part } _{ contiguity_part } "
445+ return f"{ divisibility_part } _{ contiguity_part } _index_ { index_dtype } "
393446
394447
395448def _divisibility_suffix (divisibility_spec , launch_arg_names , tensor_ndims ):
0 commit comments