|
8 | 8 | from cuda.tile._ir.ir import Block |
9 | 9 | from cuda.tile._compile import compile_tile |
10 | 10 | from cuda.tile.compilation import ( |
11 | | - ParameterConstraint, ArrayConstraint, ListConstraint, ConstantConstraint, KernelSignature |
| 11 | + ParameterConstraint, ArrayConstraint, ListConstraint, |
| 12 | + ConstantConstraint, ScalarConstraint, KernelSignature |
12 | 13 | ) |
13 | 14 | from cuda.tile._cext import CallingConvention |
| 15 | +from cuda.tile._exception import TileTypeError |
14 | 16 | from typing import Sequence |
15 | 17 |
|
16 | 18 |
|
@@ -104,6 +106,109 @@ def kernel(x): |
104 | 106 | assert get_op_divby(body, MakeTensorView) == [{}] |
105 | 107 |
|
106 | 108 |
|
| 109 | +# --- ct.assume_divisible_by --- |
| 110 | + |
| 111 | +def test_assume_divisible_by_emits_op(): |
| 112 | + def kernel(x, n: int): |
| 113 | + n = ct.assume_divisible_by(n, 16) |
| 114 | + ct.store(x, (n,), 0) |
| 115 | + |
| 116 | + body = get_ir(kernel, ( |
| 117 | + array_arg(ndim=1, base_div=1, stride_const=(1,)), |
| 118 | + ScalarConstraint(ct.int32), |
| 119 | + )) |
| 120 | + ops = [op.divisor for op in body.traverse() if isinstance(op, AssumeDivBy)] |
| 121 | + assert ops == [16] |
| 122 | + |
| 123 | + |
| 124 | +def test_assume_divisible_by_propagates_to_dynamic_slice(): |
| 125 | + def kernel(x, start_factor: int, extent: int): |
| 126 | + start_factor = ct.assume_divisible_by(start_factor, 32) |
| 127 | + extent = ct.assume_divisible_by(extent, 32) |
| 128 | + start = ct.bid(0) * start_factor |
| 129 | + stop = start + extent |
| 130 | + sub_x = x.slice(axis=0, start=start, stop=stop) |
| 131 | + tile = ct.load(sub_x, index=(0,), shape=(1,)) |
| 132 | + ct.store(sub_x, index=(0,), tile=tile) |
| 133 | + |
| 134 | + body = get_ir(kernel, ( |
| 135 | + array_arg(dtype=ct.bfloat16, base_div=16, stride_const=(1,)), |
| 136 | + ScalarConstraint(ct.int32), |
| 137 | + ScalarConstraint(ct.int32), |
| 138 | + )) |
| 139 | + |
| 140 | + divby = get_op_divby(body, MakeTensorView)[0] |
| 141 | + assert divby.get('base_ptr') == 16 and divby.get('shape[0]') == 32 |
| 142 | + |
| 143 | + |
| 144 | +def test_assume_divisible_by_divisor_one_is_noop(): |
| 145 | + def kernel(x, n: int): |
| 146 | + n = ct.assume_divisible_by(n, 1) |
| 147 | + ct.store(x, (n,), 0) |
| 148 | + |
| 149 | + body = get_ir(kernel, ( |
| 150 | + array_arg(ndim=1, base_div=1, stride_const=(1,)), |
| 151 | + ScalarConstraint(ct.int32), |
| 152 | + )) |
| 153 | + ops = [op for op in body.traverse() if isinstance(op, AssumeDivBy)] |
| 154 | + assert ops == [] |
| 155 | + |
| 156 | + |
| 157 | +def test_assume_divisible_by_non_power_of_two_divisor(): |
| 158 | + # divisor=12 has largest power-of-2 factor 4. |
| 159 | + # The propagate_divby pass extracts that power-of-2 when inserting |
| 160 | + # AssumeDivBy before MakeTensorView, so shape[0] ends up with divisor=4. |
| 161 | + def kernel(x, extent: int): |
| 162 | + extent = ct.assume_divisible_by(extent, 12) |
| 163 | + sub_x = x.slice(axis=0, start=0, stop=extent) |
| 164 | + tile = ct.load(sub_x, index=(0,), shape=(1,)) |
| 165 | + ct.store(sub_x, index=(0,), tile=tile) |
| 166 | + |
| 167 | + body = get_ir(kernel, ( |
| 168 | + array_arg(ndim=1, base_div=1, stride_const=(1,)), |
| 169 | + ScalarConstraint(ct.int32), |
| 170 | + )) |
| 171 | + divby = get_op_divby(body, MakeTensorView)[0] |
| 172 | + assert divby.get('shape[0]') == 4 |
| 173 | + |
| 174 | + |
| 175 | +def test_assume_divisible_by_type_error_on_float(): |
| 176 | + def kernel(x, f: float): |
| 177 | + f = ct.assume_divisible_by(f, 16) |
| 178 | + ct.store(x, (0,), 0) |
| 179 | + |
| 180 | + with pytest.raises(TileTypeError, match="integer scalar"): |
| 181 | + get_ir(kernel, ( |
| 182 | + array_arg(ndim=1, stride_const=(1,)), |
| 183 | + ScalarConstraint(ct.float32), |
| 184 | + )) |
| 185 | + |
| 186 | + |
| 187 | +def test_assume_divisible_by_error_on_nonconstant_divisor(): |
| 188 | + def kernel(x, n: int, d: int): |
| 189 | + n = ct.assume_divisible_by(n, d) |
| 190 | + ct.store(x, (n,), 0) |
| 191 | + |
| 192 | + with pytest.raises(TileTypeError, match="integer constant"): |
| 193 | + get_ir(kernel, ( |
| 194 | + array_arg(ndim=1, stride_const=(1,)), |
| 195 | + ScalarConstraint(ct.int32), |
| 196 | + ScalarConstraint(ct.int32), |
| 197 | + )) |
| 198 | + |
| 199 | + |
| 200 | +def test_assume_divisible_by_error_on_nonpositive_divisor(): |
| 201 | + def kernel(x, n: int): |
| 202 | + n = ct.assume_divisible_by(n, 0) |
| 203 | + ct.store(x, (n,), 0) |
| 204 | + |
| 205 | + with pytest.raises(TileTypeError, match="positive divisor"): |
| 206 | + get_ir(kernel, ( |
| 207 | + array_arg(ndim=1, stride_const=(1,)), |
| 208 | + ScalarConstraint(ct.int32), |
| 209 | + )) |
| 210 | + |
| 211 | + |
107 | 212 | # --- Control flow propagation --- |
108 | 213 |
|
109 | 214 | def test_if_else(): |
|
0 commit comments