Skip to content

Commit 41afcf8

Browse files
authored
Specialize AOT sizes and strides as int32 (#155)
* Specialize AOT sizes and strides as `int32`
* Add a test case for AOT `int32` overflow checks
1 parent efbf963 commit 41afcf8

2 files changed

Lines changed: 124 additions & 13 deletions

File tree

src/ninetoothed/aot.py

Lines changed: 110 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,19 @@ def _find_tensor_by_source_name(tensors, name):
8181
variant_specs = _enumerate_variant_specs(
8282
launch_arg_names, tensors, _find_tensor_by_source_name
8383
)
84+
_, tensor_ndims, _ = _per_tensor_dim_options(
85+
launch_arg_names, tensors, _find_tensor_by_source_name
86+
)
8487

8588
output_contents = {}
8689

87-
for variant_suffix, divisibility_spec, contiguity_spec in variant_specs:
90+
for (
91+
variant_suffix,
92+
divisibility_spec,
93+
contiguity_spec,
94+
size_type,
95+
stride_type,
96+
) in variant_specs:
8897
variant_outputs = _build_variant(
8998
source_file,
9099
kernel_func,
@@ -99,11 +108,13 @@ def _find_tensor_by_source_name(tensors, name):
99108
num_stages=num_stages,
100109
divisibility_spec=divisibility_spec,
101110
contiguity_spec=contiguity_spec,
111+
size_type=size_type,
112+
stride_type=stride_type,
102113
)
103114
output_contents.update(variant_outputs)
104115

105116
dispatcher_source, dispatcher_header = _generate_dispatcher(
106-
kernel_name, launch_arg_names, variant_specs
117+
kernel_name, launch_arg_names, variant_specs, tensor_ndims
107118
)
108119

109120
output_contents[f"{kernel_name}.cpp"] = dispatcher_source
@@ -112,7 +123,7 @@ def _find_tensor_by_source_name(tensors, name):
112123
return output_contents
113124

114125

115-
def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
126+
def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs, tensor_ndims):
116127
tensor_params = ", ".join(f"NineToothedTensor {name}" for name in launch_arg_names)
117128
signature_params = (
118129
f"NineToothedStream stream, {tensor_params}"
@@ -138,14 +149,30 @@ def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
138149
externs = []
139150
branches = []
140151

141-
for variant_suffix, divisibility_spec, contiguity_spec in variant_specs:
152+
fallback_call = None
153+
154+
for (
155+
variant_suffix,
156+
divisibility_spec,
157+
contiguity_spec,
158+
size_type,
159+
stride_type,
160+
) in variant_specs:
142161
variant_name = f"launch_{kernel_name}_{variant_suffix}"
143162
externs.append(
144163
f'extern "C" NineToothedResult {variant_name}({signature_params});'
145164
)
146165

147166
call = f"return {variant_name}({call_args});"
148167

168+
if (size_type, stride_type) == (
169+
ninetoothed.dtype.int64,
170+
ninetoothed.dtype.int64,
171+
):
172+
fallback_call = call
173+
174+
continue
175+
149176
checks = tuple(
150177
f"{name}.shape[{dim}] % 16 == 0" for name, dim in divisibility_spec
151178
) + tuple(f"{name}.strides[{dim}] == 1" for name, dim in contiguity_spec)
@@ -155,11 +182,23 @@ def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
155182
else:
156183
branches.append(f"{_INDENTATION}{call}")
157184

185+
prelude_lines = []
186+
187+
if fallback_call is not None and launch_arg_names:
188+
overflow_terms = _overflow_terms(launch_arg_names, tensor_ndims)
189+
190+
if overflow_terms:
191+
prelude_lines.append(
192+
f"{_INDENTATION}if ({' || '.join(overflow_terms)}) {fallback_call}"
193+
)
194+
195+
body_lines = prelude_lines + branches
196+
158197
source = (
159198
f'#include "{_HEADER_PATH}"\n\n'
160199
+ "\n".join(externs)
161200
+ f'\n\nextern "C" {signature} {{\n'
162-
+ "\n".join(branches)
201+
+ "\n".join(body_lines)
163202
+ "\n}\n"
164203
)
165204

@@ -181,6 +220,8 @@ def _build_variant(
181220
num_stages,
182221
divisibility_spec,
183222
contiguity_spec,
223+
size_type=ninetoothed.dtype.int32,
224+
stride_type=ninetoothed.dtype.int32,
184225
):
185226
divisibility_set = {
186227
(naming.remove_prefixes(name), dim) for name, dim in divisibility_spec
@@ -211,9 +252,9 @@ def _build_variant(
211252
bare_source_name = naming.remove_prefixes(source_name)
212253

213254
if (bare_source_name, dim_index) in divisibility_set:
214-
param_types.append(f"{ninetoothed.dtype.int64}:16")
255+
param_types.append(f"{size_type}:16")
215256
else:
216-
param_types.append(ninetoothed.dtype.int64)
257+
param_types.append(size_type)
217258
elif match := Tensor.stride_pattern().fullmatch(param):
218259
source_name = match.group(1)
219260
dim_index = int(match.group(3))
@@ -224,7 +265,7 @@ def _build_variant(
224265
constexpr_param_indices.append(len(param_types) - 1)
225266
constexpr_strides.append((source_name, dim_index))
226267
else:
227-
param_types.append(f"{ninetoothed.dtype.int64}:16")
268+
param_types.append(f"{stride_type}:16")
228269
else:
229270
source_name = param
230271
tensor = find_tensor(tensors, source_name)
@@ -331,15 +372,28 @@ def _spec_from_combo(combo):
331372
for divisibility_spec in dim_specs:
332373
for contiguity_spec in dim_specs:
333374
suffix = _variant_suffix(
334-
divisibility_spec, contiguity_spec, launch_arg_names, tensor_ndims
375+
divisibility_spec,
376+
contiguity_spec,
377+
launch_arg_names,
378+
tensor_ndims,
379+
size_type=ninetoothed.dtype.int32,
380+
stride_type=ninetoothed.dtype.int32,
381+
)
382+
specs.append(
383+
(
384+
suffix,
385+
divisibility_spec,
386+
contiguity_spec,
387+
ninetoothed.dtype.int32,
388+
ninetoothed.dtype.int32,
389+
)
335390
)
336-
specs.append((suffix, divisibility_spec, contiguity_spec))
337391

338392
def _num_innermost(spec):
339393
return sum(1 for name, dim in spec if innermost_dims.get(name) == dim)
340394

341395
def _specificity(entry):
342-
_, divisibility_spec, contiguity_spec = entry
396+
_, divisibility_spec, contiguity_spec, _, _ = entry
343397

344398
return (
345399
-len(divisibility_spec),
@@ -350,6 +404,24 @@ def _specificity(entry):
350404

351405
specs.sort(key=_specificity)
352406

407+
fallback_suffix = _variant_suffix(
408+
(),
409+
(),
410+
launch_arg_names,
411+
tensor_ndims,
412+
size_type=ninetoothed.dtype.int64,
413+
stride_type=ninetoothed.dtype.int64,
414+
)
415+
specs.append(
416+
(
417+
fallback_suffix,
418+
(),
419+
(),
420+
ninetoothed.dtype.int64,
421+
ninetoothed.dtype.int64,
422+
)
423+
)
424+
353425
return specs
354426

355427

@@ -381,15 +453,40 @@ def _per_tensor_dim_options(launch_arg_names, tensors, find_tensor):
381453
return per_tensor_dims, tensor_ndims, innermost_dims
382454

383455

384-
def _variant_suffix(divisibility_spec, contiguity_spec, launch_arg_names, tensor_ndims):
456+
def _overflow_terms(launch_arg_names, tensor_ndims):
457+
int32_min = -(2**31)
458+
int32_max = 2**31 - 1
459+
460+
return tuple(
461+
term
462+
for name, ndim in zip(launch_arg_names, tensor_ndims)
463+
for dim in range(ndim)
464+
for term in (
465+
f"{name}.shape[{dim}] > {int32_max}ULL",
466+
f"{name}.strides[{dim}] > {int32_max}LL",
467+
f"{name}.strides[{dim}] < {int32_min}LL",
468+
)
469+
)
470+
471+
472+
def _variant_suffix(
473+
divisibility_spec,
474+
contiguity_spec,
475+
launch_arg_names,
476+
tensor_ndims,
477+
size_type=ninetoothed.dtype.int32,
478+
stride_type=ninetoothed.dtype.int32,
479+
):
385480
divisibility_part = _divisibility_suffix(
386481
divisibility_spec, launch_arg_names, tensor_ndims
387482
)
388483
contiguity_part = _contiguity_suffix(
389484
contiguity_spec, launch_arg_names, tensor_ndims
390485
)
391486

392-
return f"{divisibility_part}_{contiguity_part}"
487+
return (
488+
f"{divisibility_part}_{contiguity_part}_size_{size_type}_stride_{stride_type}"
489+
)
393490

394491

395492
def _divisibility_suffix(divisibility_spec, launch_arg_names, tensor_ndims):

tests/test_aot.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch.nn.functional as F
66

77
import ninetoothed
8+
import ninetoothed.aot
89
import ninetoothed.generation
910
import tests.test_addmm as addmm
1011
import tests.test_attention as attention
@@ -342,6 +343,19 @@ def _application(input, scale, output):
342343
assert torch.allclose(output, expected)
343344

344345

346+
def test_overflow_terms():
    """Check the `int32` overflow conditions generated for AOT dispatch."""
    # `input` has two dimensions and yields three conditions per dimension;
    # `scale` has zero dimensions and must contribute nothing.
    expected = (
        "input.shape[0] > 2147483647ULL",
        "input.strides[0] > 2147483647LL",
        "input.strides[0] < -2147483648LL",
        "input.shape[1] > 2147483647ULL",
        "input.strides[1] > 2147483647LL",
        "input.strides[1] < -2147483648LL",
    )

    actual = ninetoothed.aot._overflow_terms(("input", "scale"), (2, 0))

    assert actual == expected
357+
358+
345359
def _generate_kernel_name_suffix():
346360
count = _generate_kernel_name_suffix._kernel_count
347361
_generate_kernel_name_suffix._kernel_count += 1

0 commit comments

Comments
 (0)