@@ -81,10 +81,19 @@ def _find_tensor_by_source_name(tensors, name):
8181 variant_specs = _enumerate_variant_specs (
8282 launch_arg_names , tensors , _find_tensor_by_source_name
8383 )
84+ _ , tensor_ndims , _ = _per_tensor_dim_options (
85+ launch_arg_names , tensors , _find_tensor_by_source_name
86+ )
8487
8588 output_contents = {}
8689
87- for variant_suffix , divisibility_spec , contiguity_spec in variant_specs :
90+ for (
91+ variant_suffix ,
92+ divisibility_spec ,
93+ contiguity_spec ,
94+ size_type ,
95+ stride_type ,
96+ ) in variant_specs :
8897 variant_outputs = _build_variant (
8998 source_file ,
9099 kernel_func ,
@@ -99,11 +108,13 @@ def _find_tensor_by_source_name(tensors, name):
99108 num_stages = num_stages ,
100109 divisibility_spec = divisibility_spec ,
101110 contiguity_spec = contiguity_spec ,
111+ size_type = size_type ,
112+ stride_type = stride_type ,
102113 )
103114 output_contents .update (variant_outputs )
104115
105116 dispatcher_source , dispatcher_header = _generate_dispatcher (
106- kernel_name , launch_arg_names , variant_specs
117+ kernel_name , launch_arg_names , variant_specs , tensor_ndims
107118 )
108119
109120 output_contents [f"{ kernel_name } .cpp" ] = dispatcher_source
@@ -112,7 +123,7 @@ def _find_tensor_by_source_name(tensors, name):
112123 return output_contents
113124
114125
115- def _generate_dispatcher (kernel_name , launch_arg_names , variant_specs ):
126+ def _generate_dispatcher (kernel_name , launch_arg_names , variant_specs , tensor_ndims ):
116127 tensor_params = ", " .join (f"NineToothedTensor { name } " for name in launch_arg_names )
117128 signature_params = (
118129 f"NineToothedStream stream, { tensor_params } "
@@ -138,14 +149,30 @@ def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
138149 externs = []
139150 branches = []
140151
141- for variant_suffix , divisibility_spec , contiguity_spec in variant_specs :
152+ fallback_call = None
153+
154+ for (
155+ variant_suffix ,
156+ divisibility_spec ,
157+ contiguity_spec ,
158+ size_type ,
159+ stride_type ,
160+ ) in variant_specs :
142161 variant_name = f"launch_{ kernel_name } _{ variant_suffix } "
143162 externs .append (
144163 f'extern "C" NineToothedResult { variant_name } ({ signature_params } );'
145164 )
146165
147166 call = f"return { variant_name } ({ call_args } );"
148167
168+ if (size_type , stride_type ) == (
169+ ninetoothed .dtype .int64 ,
170+ ninetoothed .dtype .int64 ,
171+ ):
172+ fallback_call = call
173+
174+ continue
175+
149176 checks = tuple (
150177 f"{ name } .shape[{ dim } ] % 16 == 0" for name , dim in divisibility_spec
151178 ) + tuple (f"{ name } .strides[{ dim } ] == 1" for name , dim in contiguity_spec )
@@ -155,11 +182,23 @@ def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
155182 else :
156183 branches .append (f"{ _INDENTATION } { call } " )
157184
185+ prelude_lines = []
186+
187+ if fallback_call is not None and launch_arg_names :
188+ overflow_terms = _overflow_terms (launch_arg_names , tensor_ndims )
189+
190+ if overflow_terms :
191+ prelude_lines .append (
192+ f"{ _INDENTATION } if ({ ' || ' .join (overflow_terms )} ) { fallback_call } "
193+ )
194+
195+ body_lines = prelude_lines + branches
196+
158197 source = (
159198 f'#include "{ _HEADER_PATH } "\n \n '
160199 + "\n " .join (externs )
161200 + f'\n \n extern "C" { signature } {{\n '
162- + "\n " .join (branches )
201+ + "\n " .join (body_lines )
163202 + "\n }\n "
164203 )
165204
@@ -181,6 +220,8 @@ def _build_variant(
181220 num_stages ,
182221 divisibility_spec ,
183222 contiguity_spec ,
223+ size_type = ninetoothed .dtype .int32 ,
224+ stride_type = ninetoothed .dtype .int32 ,
184225):
185226 divisibility_set = {
186227 (naming .remove_prefixes (name ), dim ) for name , dim in divisibility_spec
@@ -211,9 +252,9 @@ def _build_variant(
211252 bare_source_name = naming .remove_prefixes (source_name )
212253
213254 if (bare_source_name , dim_index ) in divisibility_set :
214- param_types .append (f"{ ninetoothed . dtype . int64 } :16" )
255+ param_types .append (f"{ size_type } :16" )
215256 else :
216- param_types .append (ninetoothed . dtype . int64 )
257+ param_types .append (size_type )
217258 elif match := Tensor .stride_pattern ().fullmatch (param ):
218259 source_name = match .group (1 )
219260 dim_index = int (match .group (3 ))
@@ -224,7 +265,7 @@ def _build_variant(
224265 constexpr_param_indices .append (len (param_types ) - 1 )
225266 constexpr_strides .append ((source_name , dim_index ))
226267 else :
227- param_types .append (f"{ ninetoothed . dtype . int64 } :16" )
268+ param_types .append (f"{ stride_type } :16" )
228269 else :
229270 source_name = param
230271 tensor = find_tensor (tensors , source_name )
@@ -331,15 +372,28 @@ def _spec_from_combo(combo):
331372 for divisibility_spec in dim_specs :
332373 for contiguity_spec in dim_specs :
333374 suffix = _variant_suffix (
334- divisibility_spec , contiguity_spec , launch_arg_names , tensor_ndims
375+ divisibility_spec ,
376+ contiguity_spec ,
377+ launch_arg_names ,
378+ tensor_ndims ,
379+ size_type = ninetoothed .dtype .int32 ,
380+ stride_type = ninetoothed .dtype .int32 ,
381+ )
382+ specs .append (
383+ (
384+ suffix ,
385+ divisibility_spec ,
386+ contiguity_spec ,
387+ ninetoothed .dtype .int32 ,
388+ ninetoothed .dtype .int32 ,
389+ )
335390 )
336- specs .append ((suffix , divisibility_spec , contiguity_spec ))
337391
338392 def _num_innermost (spec ):
339393 return sum (1 for name , dim in spec if innermost_dims .get (name ) == dim )
340394
341395 def _specificity (entry ):
342- _ , divisibility_spec , contiguity_spec = entry
396+ _ , divisibility_spec , contiguity_spec , _ , _ = entry
343397
344398 return (
345399 - len (divisibility_spec ),
@@ -350,6 +404,24 @@ def _specificity(entry):
350404
351405 specs .sort (key = _specificity )
352406
407+ fallback_suffix = _variant_suffix (
408+ (),
409+ (),
410+ launch_arg_names ,
411+ tensor_ndims ,
412+ size_type = ninetoothed .dtype .int64 ,
413+ stride_type = ninetoothed .dtype .int64 ,
414+ )
415+ specs .append (
416+ (
417+ fallback_suffix ,
418+ (),
419+ (),
420+ ninetoothed .dtype .int64 ,
421+ ninetoothed .dtype .int64 ,
422+ )
423+ )
424+
353425 return specs
354426
355427
@@ -381,15 +453,40 @@ def _per_tensor_dim_options(launch_arg_names, tensors, find_tensor):
381453 return per_tensor_dims , tensor_ndims , innermost_dims
382454
383455
384- def _variant_suffix (divisibility_spec , contiguity_spec , launch_arg_names , tensor_ndims ):
456+ def _overflow_terms (launch_arg_names , tensor_ndims ):
457+ int32_min = - (2 ** 31 )
458+ int32_max = 2 ** 31 - 1
459+
460+ return tuple (
461+ term
462+ for name , ndim in zip (launch_arg_names , tensor_ndims )
463+ for dim in range (ndim )
464+ for term in (
465+ f"{ name } .shape[{ dim } ] > { int32_max } ULL" ,
466+ f"{ name } .strides[{ dim } ] > { int32_max } LL" ,
467+ f"{ name } .strides[{ dim } ] < { int32_min } LL" ,
468+ )
469+ )
470+
471+
472+ def _variant_suffix (
473+ divisibility_spec ,
474+ contiguity_spec ,
475+ launch_arg_names ,
476+ tensor_ndims ,
477+ size_type = ninetoothed .dtype .int32 ,
478+ stride_type = ninetoothed .dtype .int32 ,
479+ ):
385480 divisibility_part = _divisibility_suffix (
386481 divisibility_spec , launch_arg_names , tensor_ndims
387482 )
388483 contiguity_part = _contiguity_suffix (
389484 contiguity_spec , launch_arg_names , tensor_ndims
390485 )
391486
392- return f"{ divisibility_part } _{ contiguity_part } "
487+ return (
488+ f"{ divisibility_part } _{ contiguity_part } _size_{ size_type } _stride_{ stride_type } "
489+ )
393490
394491
395492def _divisibility_suffix (divisibility_spec , launch_arg_names , tensor_ndims ):
0 commit comments