Skip to content

Commit 3ce8881

Browse files
committed
Specialize AOT strides and sizes as int32 with int64 fallback
The AOT path previously declared all size and stride parameters as `int64` in the Triton signature, causing the compiled kernel to use 64-bit integer arithmetic throughout the address computation chain. For typical tensor dimensions (< 2^31 elements), `int32` suffices and matches what Triton's JIT path auto-selects, yielding significantly fewer PTX 64-bit instructions (1246 → 9 in conv2d, ~1.46× speedup). This change introduces an `index_dtype` axis to variant enumeration: all divisibility × contiguity combinations use `int32` by default, and a single `int64` fallback variant with no hints is appended. The C++ dispatcher checks whether any shape or stride value exceeds `int32` range before dispatching to an `int32` variant; if overflow is detected, it falls back to the `int64` kernel.
1 parent d1c79e9 commit 3ce8881

1 file changed

Lines changed: 66 additions & 13 deletions

File tree

src/ninetoothed/aot.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,18 @@ def _find_tensor_by_source_name(tensors, name):
8181
variant_specs = _enumerate_variant_specs(
8282
launch_arg_names, tensors, _find_tensor_by_source_name
8383
)
84+
_, tensor_ndims, _ = _per_tensor_dim_options(
85+
launch_arg_names, tensors, _find_tensor_by_source_name
86+
)
8487

8588
output_contents = {}
8689

87-
for variant_suffix, divisibility_spec, contiguity_spec in variant_specs:
90+
for (
91+
variant_suffix,
92+
divisibility_spec,
93+
contiguity_spec,
94+
index_dtype,
95+
) in variant_specs:
8896
variant_outputs = _build_variant(
8997
source_file,
9098
kernel_func,
@@ -99,11 +107,12 @@ def _find_tensor_by_source_name(tensors, name):
99107
num_stages=num_stages,
100108
divisibility_spec=divisibility_spec,
101109
contiguity_spec=contiguity_spec,
110+
index_dtype=index_dtype,
102111
)
103112
output_contents.update(variant_outputs)
104113

105114
dispatcher_source, dispatcher_header = _generate_dispatcher(
106-
kernel_name, launch_arg_names, variant_specs
115+
kernel_name, launch_arg_names, variant_specs, tensor_ndims
107116
)
108117

109118
output_contents[f"{kernel_name}.cpp"] = dispatcher_source
@@ -112,7 +121,7 @@ def _find_tensor_by_source_name(tensors, name):
112121
return output_contents
113122

114123

115-
def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
124+
def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs, tensor_ndims):
116125
tensor_params = ", ".join(f"NineToothedTensor {name}" for name in launch_arg_names)
117126
signature_params = (
118127
f"NineToothedStream stream, {tensor_params}"
@@ -138,14 +147,25 @@ def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
138147
externs = []
139148
branches = []
140149

141-
for variant_suffix, divisibility_spec, contiguity_spec in variant_specs:
150+
fallback_call = None
151+
152+
for (
153+
variant_suffix,
154+
divisibility_spec,
155+
contiguity_spec,
156+
index_dtype,
157+
) in variant_specs:
142158
variant_name = f"launch_{kernel_name}_{variant_suffix}"
143159
externs.append(
144160
f'extern "C" NineToothedResult {variant_name}({signature_params});'
145161
)
146162

147163
call = f"return {variant_name}({call_args});"
148164

165+
if index_dtype == ninetoothed.dtype.int64:
166+
fallback_call = call
167+
continue
168+
149169
checks = tuple(
150170
f"{name}.shape[{dim}] % 16 == 0" for name, dim in divisibility_spec
151171
) + tuple(f"{name}.strides[{dim}] == 1" for name, dim in contiguity_spec)
@@ -155,11 +175,26 @@ def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
155175
else:
156176
branches.append(f"{_INDENTATION}{call}")
157177

178+
prelude_lines = []
179+
if fallback_call is not None and launch_arg_names:
180+
overflow_terms = []
181+
for name, ndim in zip(launch_arg_names, tensor_ndims):
182+
for d in range(ndim):
183+
overflow_terms.append(f"{name}.shape[{d}] > 2147483647ULL")
184+
overflow_terms.append(f"{name}.strides[{d}] > 2147483647LL")
185+
overflow_terms.append(f"{name}.strides[{d}] < -2147483648LL")
186+
if overflow_terms:
187+
prelude_lines.append(
188+
f"{_INDENTATION}if ({' || '.join(overflow_terms)}) {fallback_call}"
189+
)
190+
191+
body_lines = prelude_lines + branches
192+
158193
source = (
159194
f'#include "{_HEADER_PATH}"\n\n'
160195
+ "\n".join(externs)
161196
+ f'\n\nextern "C" {signature} {{\n'
162-
+ "\n".join(branches)
197+
+ "\n".join(body_lines)
163198
+ "\n}\n"
164199
)
165200

@@ -181,6 +216,7 @@ def _build_variant(
181216
num_stages,
182217
divisibility_spec,
183218
contiguity_spec,
219+
index_dtype=ninetoothed.dtype.int32,
184220
):
185221
divisibility_set = {
186222
(naming.remove_prefixes(name), dim) for name, dim in divisibility_spec
@@ -211,9 +247,9 @@ def _build_variant(
211247
bare_source_name = naming.remove_prefixes(source_name)
212248

213249
if (bare_source_name, dim_index) in divisibility_set:
214-
param_types.append(f"{ninetoothed.dtype.int64}:16")
250+
param_types.append(f"{index_dtype}:16")
215251
else:
216-
param_types.append(ninetoothed.dtype.int64)
252+
param_types.append(index_dtype)
217253
elif match := Tensor.stride_pattern().fullmatch(param):
218254
source_name = match.group(1)
219255
dim_index = int(match.group(3))
@@ -224,7 +260,7 @@ def _build_variant(
224260
constexpr_param_indices.append(len(param_types) - 1)
225261
constexpr_strides.append((source_name, dim_index))
226262
else:
227-
param_types.append(f"{ninetoothed.dtype.int64}:16")
263+
param_types.append(f"{index_dtype}:16")
228264
else:
229265
source_name = param
230266
tensor = find_tensor(tensors, source_name)
@@ -331,15 +367,21 @@ def _spec_from_combo(combo):
331367
for divisibility_spec in dim_specs:
332368
for contiguity_spec in dim_specs:
333369
suffix = _variant_suffix(
334-
divisibility_spec, contiguity_spec, launch_arg_names, tensor_ndims
370+
divisibility_spec,
371+
contiguity_spec,
372+
launch_arg_names,
373+
tensor_ndims,
374+
index_dtype=ninetoothed.dtype.int32,
375+
)
376+
specs.append(
377+
(suffix, divisibility_spec, contiguity_spec, ninetoothed.dtype.int32)
335378
)
336-
specs.append((suffix, divisibility_spec, contiguity_spec))
337379

338380
def _num_innermost(spec):
339381
return sum(1 for name, dim in spec if innermost_dims.get(name) == dim)
340382

341383
def _specificity(entry):
342-
_, divisibility_spec, contiguity_spec = entry
384+
_, divisibility_spec, contiguity_spec, _ = entry
343385

344386
return (
345387
-len(divisibility_spec),
@@ -350,6 +392,11 @@ def _specificity(entry):
350392

351393
specs.sort(key=_specificity)
352394

395+
fallback_suffix = _variant_suffix(
396+
(), (), launch_arg_names, tensor_ndims, index_dtype=ninetoothed.dtype.int64
397+
)
398+
specs.append((fallback_suffix, (), (), ninetoothed.dtype.int64))
399+
353400
return specs
354401

355402

@@ -381,15 +428,21 @@ def _per_tensor_dim_options(launch_arg_names, tensors, find_tensor):
381428
return per_tensor_dims, tensor_ndims, innermost_dims
382429

383430

384-
def _variant_suffix(divisibility_spec, contiguity_spec, launch_arg_names, tensor_ndims):
431+
def _variant_suffix(
432+
divisibility_spec,
433+
contiguity_spec,
434+
launch_arg_names,
435+
tensor_ndims,
436+
index_dtype=ninetoothed.dtype.int32,
437+
):
385438
divisibility_part = _divisibility_suffix(
386439
divisibility_spec, launch_arg_names, tensor_ndims
387440
)
388441
contiguity_part = _contiguity_suffix(
389442
contiguity_spec, launch_arg_names, tensor_ndims
390443
)
391444

392-
return f"{divisibility_part}_{contiguity_part}"
445+
return f"{divisibility_part}_{contiguity_part}_index_{index_dtype}"
393446

394447

395448
def _divisibility_suffix(divisibility_spec, launch_arg_names, tensor_ndims):

0 commit comments

Comments (0)