Skip to content

Commit 4f9e5c9

Browse files
committed
Expose assume_divby
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent bba2494 commit 4f9e5c9

7 files changed

Lines changed: 195 additions & 2 deletions

File tree

changelog.d/assume-divby.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added `ct.assume_divisible_by(x, divisor)`, a compiler hint that declares an integer scalar to be divisible by a constant. The divisibility fact is propagated through arithmetic, which lets the compiler prove alignment for derived indices and pointer offsets and emit wider memory operations than it could with unknown divisibility.

docs/source/operations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ Utility
214214
printf
215215
print
216216
assert_
217+
assume_divisible_by
217218

218219

219220
.. _operations-metaprogramming:

docs/source/performance.rst

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ Performance Tuning
1010
Several performance tuning techniques are available in cuTile:
1111

1212
* architecture-specific configuration values, using :class:`ByTarget`;
13-
* load/store hints such as ``latency`` and ``allow_tma``.
13+
* load/store hints such as ``latency`` and ``allow_tma``;
14+
* divisibility hints via :func:`assume_divisible_by`.
1415

1516

1617
Architecture-specific configuration
@@ -53,6 +54,43 @@ Example
5354
:end-before: example-end
5455

5556

57+
.. _divisibility-hints:
58+
59+
Divisibility hints
60+
------------------
61+
62+
:func:`assume_divisible_by` is a compiler hint that declares an integer
63+
scalar to be divisible by a compile-time constant. No check is performed at
64+
runtime:
65+
66+
.. code-block:: python
67+
68+
n = ct.assume_divisible_by(n, 16)
69+
70+
The compiler propagates the divisibility metadata through arithmetic operations, so
71+
that derived indices and pointer offsets inherit the same fact. This matters
72+
most when a runtime scalar is used to compute a dynamic array slice:
73+
74+
.. code-block:: python
75+
76+
@ct.kernel
77+
def kernel(x, dim_offset: int, dim_size: int):
78+
dim_offset = ct.assume_divisible_by(dim_offset, 16)
79+
dim_size = ct.assume_divisible_by(dim_size, 16)
80+
start = ct.bid(0) * dim_offset
81+
sub_x = x.slice(axis=0, start=start, stop=start + dim_size)
82+
tile = ct.load(sub_x, index=(0,), shape=(128,))
83+
ct.store(sub_x, index=(0,), tile=tile)
84+
85+
Without the hints, the compiler treats ``dim_offset`` and ``dim_size`` as
86+
fully unknown and cannot prove alignment for the derived view. With the
87+
hints, it can infer alignment all the way into the view's base address and
88+
shape, enabling wider memory operations.
89+
90+
The hint is a programmer declaration, not an enforcement. Behavior is undefined
91+
if the hinted value is not actually divisible by the declared divisor at runtime.
92+
93+
5694
.. _autotuning:
5795

5896
Autotuning

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
argmax,
7575
argmin,
7676
assert_,
77+
assume_divisible_by,
7778
astile,
7879
astype,
7980
atan2,
@@ -236,6 +237,7 @@
236237
"argmax",
237238
"argmin",
238239
"assert_",
240+
"assume_divisible_by",
239241
"astile",
240242
"astype",
241243
"atan2",

src/cuda/tile/_ir/ops.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,6 +2132,19 @@ def assume_div_by(x: Var, divisor: int | None) -> Var:
21322132
return add_operation(AssumeDivBy, x.get_type(), x=x, divisor=divisor)
21332133

21342134

2135+
@impl(ct.assume_divisible_by)
def assume_divisible_by_impl(x: Var, divisor: Var) -> Var:
    """Validate the arguments of ``ct.assume_divisible_by`` and emit the hint.

    Args:
        x: Variable to annotate; its type must be an integer scalar
            (checked with ``is_0d_tile(..., is_integral)``).
        divisor: Variable that ``require_constant_int`` resolves to a
            Python ``int``.

    Returns:
        Whatever ``assume_div_by`` returns for ``x`` and the resolved divisor.

    Raises:
        TileTypeError: If ``x`` is not an integer scalar, or if the resolved
            divisor is smaller than 1.
    """
    x_type = x.get_type()
    if not is_0d_tile(x_type, is_integral):
        message = f"`assume_divisible_by` requires an integer scalar, got {x_type}"
        raise TileTypeError(message)

    # The divisor has to be known at compile time; `require_constant_int`
    # is responsible for rejecting anything that is not a constant integer.
    constant_divisor = require_constant_int(divisor)
    if constant_divisor < 1:
        message = (
            f"`assume_divisible_by` requires a positive divisor, got {constant_divisor}"
        )
        raise TileTypeError(message)

    return assume_div_by(x, constant_divisor)
2146+
2147+
21352148
@dataclass(eq=False)
21362149
class AssumeBounded(Operation, opcode="assume_bounded"):
21372150
lower_bound: int | None = attribute()

src/cuda/tile/_stub.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,36 @@ def kernel(x):
11831183
"""
11841184

11851185

1186+
@stub
def assume_divisible_by(x: Scalar, divisor: int) -> Scalar:
    """Hints to the compiler that ``x`` is a multiple of ``divisor``.

    This is a declaration made by the caller, not a check: behavior is
    undefined if ``x`` is not actually divisible by ``divisor`` at runtime.

    Args:
        x: An integer scalar.
        divisor (const int): The assumed divisor. Must be a positive
            integer constant.

    Returns:
        An integer scalar holding the same value as ``x``.

    Examples:

        .. testcode::
            :template: kernel_wrapper.py

            n = ct.bid(0) + 128
            n = ct.assume_divisible_by(n, 128)
            print(n)

        .. testoutput::

            128
    """
1214+
1215+
11861216
@stub
11871217
def load(array: Array, /,
11881218
index: Shape,

test/test_propagate_divby.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
from cuda.tile._ir.ir import Block
99
from cuda.tile._compile import compile_tile
1010
from cuda.tile.compilation import (
11-
ParameterConstraint, ArrayConstraint, ListConstraint, ConstantConstraint, KernelSignature
11+
ParameterConstraint, ArrayConstraint, ListConstraint,
12+
ConstantConstraint, ScalarConstraint, KernelSignature
1213
)
1314
from cuda.tile._cext import CallingConvention
15+
from cuda.tile._exception import TileTypeError
1416
from typing import Sequence
1517

1618

@@ -104,6 +106,109 @@ def kernel(x):
104106
assert get_op_divby(body, MakeTensorView) == [{}]
105107

106108

109+
# --- ct.assume_divisible_by ---
110+
111+
def test_assume_divisible_by_emits_op():
    """The hint shows up in the IR as a single AssumeDivBy op with divisor 16."""
    def kernel(x, n: int):
        n = ct.assume_divisible_by(n, 16)
        ct.store(x, (n,), 0)

    constraints = (
        array_arg(ndim=1, base_div=1, stride_const=(1,)),
        ScalarConstraint(ct.int32),
    )
    body = get_ir(kernel, constraints)

    divisors = [
        op.divisor
        for op in body.traverse()
        if isinstance(op, AssumeDivBy)
    ]
    assert divisors == [16]
122+
123+
124+
def test_assume_divisible_by_propagates_to_dynamic_slice():
    """Hinted scalars feed divisibility into the MakeTensorView of a dynamic slice."""
    def kernel(x, start_factor: int, extent: int):
        start_factor = ct.assume_divisible_by(start_factor, 32)
        extent = ct.assume_divisible_by(extent, 32)
        start = ct.bid(0) * start_factor
        stop = start + extent
        sub_x = x.slice(axis=0, start=start, stop=stop)
        tile = ct.load(sub_x, index=(0,), shape=(1,))
        ct.store(sub_x, index=(0,), tile=tile)

    constraints = (
        array_arg(dtype=ct.bfloat16, base_div=16, stride_const=(1,)),
        ScalarConstraint(ct.int32),
        ScalarConstraint(ct.int32),
    )
    body = get_ir(kernel, constraints)

    view_divby = get_op_divby(body, MakeTensorView)[0]
    # Checked separately so a failure names the field that is wrong.
    assert view_divby.get('base_ptr') == 16
    assert view_divby.get('shape[0]') == 32
142+
143+
144+
def test_assume_divisible_by_divisor_one_is_noop():
    """A divisor of 1 carries no information, so no AssumeDivBy op is expected."""
    def kernel(x, n: int):
        n = ct.assume_divisible_by(n, 1)
        ct.store(x, (n,), 0)

    constraints = (
        array_arg(ndim=1, base_div=1, stride_const=(1,)),
        ScalarConstraint(ct.int32),
    )
    body = get_ir(kernel, constraints)

    assume_ops = [op for op in body.traverse() if isinstance(op, AssumeDivBy)]
    assert assume_ops == []
155+
156+
157+
def test_assume_divisible_by_non_power_of_two_divisor():
    """A divisor of 12 reaches the tensor view as its power-of-2 factor, 4.

    The propagate_divby pass extracts the largest power-of-2 factor of the
    divisor when inserting AssumeDivBy before MakeTensorView, so shape[0]
    ends up with divisor=4.
    """
    def kernel(x, extent: int):
        extent = ct.assume_divisible_by(extent, 12)
        sub_x = x.slice(axis=0, start=0, stop=extent)
        tile = ct.load(sub_x, index=(0,), shape=(1,))
        ct.store(sub_x, index=(0,), tile=tile)

    constraints = (
        array_arg(ndim=1, base_div=1, stride_const=(1,)),
        ScalarConstraint(ct.int32),
    )
    body = get_ir(kernel, constraints)

    view_divby = get_op_divby(body, MakeTensorView)[0]
    assert view_divby.get('shape[0]') == 4
173+
174+
175+
def test_assume_divisible_by_type_error_on_float():
    """Passing a float scalar is rejected with an "integer scalar" TileTypeError."""
    def kernel(x, f: float):
        f = ct.assume_divisible_by(f, 16)
        ct.store(x, (0,), 0)

    constraints = (
        array_arg(ndim=1, stride_const=(1,)),
        ScalarConstraint(ct.float32),
    )
    with pytest.raises(TileTypeError, match="integer scalar"):
        get_ir(kernel, constraints)
185+
186+
187+
def test_assume_divisible_by_error_on_nonconstant_divisor():
    """A runtime (non-constant) divisor is rejected with an "integer constant" error."""
    def kernel(x, n: int, d: int):
        n = ct.assume_divisible_by(n, d)
        ct.store(x, (n,), 0)

    constraints = (
        array_arg(ndim=1, stride_const=(1,)),
        ScalarConstraint(ct.int32),
        ScalarConstraint(ct.int32),
    )
    with pytest.raises(TileTypeError, match="integer constant"):
        get_ir(kernel, constraints)
198+
199+
200+
def test_assume_divisible_by_error_on_nonpositive_divisor():
    """A divisor of 0 is rejected with a "positive divisor" TileTypeError."""
    def kernel(x, n: int):
        n = ct.assume_divisible_by(n, 0)
        ct.store(x, (n,), 0)

    constraints = (
        array_arg(ndim=1, stride_const=(1,)),
        ScalarConstraint(ct.int32),
    )
    with pytest.raises(TileTypeError, match="positive divisor"):
        get_ir(kernel, constraints)
210+
211+
107212
# --- Control flow propagation ---
108213

109214
def test_if_else():

0 commit comments

Comments
 (0)