Skip to content

Commit f50684b

Browse files
committed
Add a test case for AOT int32 overflow checks
1 parent 0e12d0b commit f50684b

1 file changed

Lines changed: 14 additions & 0 deletions

File tree

tests/test_aot.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch.nn.functional as F
66

77
import ninetoothed
8+
import ninetoothed.aot
89
import ninetoothed.generation
910
import tests.test_addmm as addmm
1011
import tests.test_attention as attention
@@ -342,6 +343,19 @@ def _application(input, scale, output):
342343
assert torch.allclose(output, expected)
343344

344345

346+
def test_overflow_terms():
347+
terms = ninetoothed.aot._overflow_terms(("input", "scale"), (2, 0))
348+
349+
assert terms == (
350+
"input.shape[0] > 2147483647ULL",
351+
"input.strides[0] > 2147483647LL",
352+
"input.strides[0] < -2147483648LL",
353+
"input.shape[1] > 2147483647ULL",
354+
"input.strides[1] > 2147483647LL",
355+
"input.strides[1] < -2147483648LL",
356+
)
357+
358+
345359
def _generate_kernel_name_suffix():
346360
count = _generate_kernel_name_suffix._kernel_count
347361
_generate_kernel_name_suffix._kernel_count += 1

0 commit comments

Comments
 (0)