Skip to content

Commit 53bf338

Browse files
committed
Add atan2 and tanh rounding_mode for tileiras 13.2
- Add ct.atan2(y, x) binary operation for arctangent - Add optional rounding_mode parameter to ct.tanh() supporting RoundingMode.FULL and RoundingMode.APPROX - Add version checks that raise clear errors when 13.2 features are used with older tileiras - Add requires_tileiras() test utility to skip tests based on tileiras version - Add backwards compatibility test suite with mocked versions Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent bb67dd8 commit 53bf338

File tree

13 files changed

+264
-53
lines changed

13 files changed

+264
-53
lines changed

changelog.d/tileiras-13-2-ops.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
Add support for tileiras 13.2 features:
5+
- New `ct.atan2(y, x)` operation for computing the arctangent of y/x
6+
- Optional `rounding_mode` parameter for `ct.tanh()` (supports `RoundingMode.FULL` and `RoundingMode.APPROX`)
7+
8+
Both features require tileiras 13.2 and will raise a clear error message when used with older versions.

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
argmin,
6666
assert_,
6767
astype,
68+
atan2,
6869
atomic_add,
6970
atomic_and,
7071
atomic_cas,
@@ -198,6 +199,7 @@
198199
"argmin",
199200
"assert_",
200201
"astype",
202+
"atan2",
201203
"atomic_add",
202204
"atomic_and",
203205
"atomic_cas",

src/cuda/tile/_compile.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,11 @@ def wrapper(*args, **kwargs):
7979

8080
def _get_final_ir(pyfunc,
8181
args: Sequence[ir.KernelArgument],
82-
config: TileContextConfig) -> ir.Function:
82+
config: TileContextConfig,
83+
tileiras_version: BytecodeVersion = BytecodeVersion.V_13_1) -> ir.Function:
8384
func_hir: hir.Function = get_function_hir(pyfunc, entry_point=True)
8485

85-
ir_ctx = ir.IRContext(config)
86+
ir_ctx = ir.IRContext(config, tileiras_version)
8687
func_body = hir2ir(func_hir, args, ir_ctx)
8788
eliminate_assign_ops(func_body)
8889
dead_code_elimination_pass(func_body)
@@ -188,7 +189,7 @@ def compile_tile(pyfunc,
188189

189190
param_names = tuple(inspect.signature(pyfunc).parameters.keys())
190191
ir_args = _bind_kernel_arguments(param_names, args, get_constant_annotations(pyfunc))
191-
func_ir = _get_final_ir(pyfunc, ir_args, context.config)
192+
func_ir = _get_final_ir(pyfunc, ir_args, context.config, bytecode_version)
192193

193194
if 'CUTILEIR' in context.config.log_keys:
194195
code = (f"==== CuTile IR for {func_ir.name}==== \n\n"

src/cuda/tile/_ir/ir.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,24 @@
2222
TileTypeError, Loc, TileInternalError
2323
)
2424
from .. import TileSyntaxError
25-
from .._cext import TileContext
2625
from .._context import TileContextConfig
26+
from cuda.tile._bytecode.version import BytecodeVersion
2727

2828
if TYPE_CHECKING:
2929
from cuda.tile._ir2bytecode import BytecodeContext
3030

3131

3232
class IRContext:
33-
def __init__(self, config: TileContextConfig):
33+
def __init__(self, config: TileContextConfig, tileiras_version: BytecodeVersion):
3434
self._all_vars: Dict[str, str] = {}
3535
self._counter_by_name: Dict[str, Iterator[int]] = defaultdict(itertools.count)
3636
self._temp_counter = itertools.count()
3737
self.typemap: Dict[str, Type] = dict()
3838
self.constants: Dict[str, Any] = dict()
3939
self._loose_typemap: Dict[str, Type] = dict()
40-
self.config: TileContext = config
40+
self.config: TileContextConfig = config
4141
self._aggregate_values: Dict[str, Any] = dict()
42+
self.tileiras_version: BytecodeVersion = tileiras_version
4243

4344
# Make a Var with a unique name based on `name`.
4445
def make_var(self, name: str, loc: Loc, undefined: bool = False) -> Var:

src/cuda/tile/_ir/op_impl.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
from cuda.tile._datatype import (
1313
is_integral, is_float, is_restricted_float,
1414
is_boolean, is_signed, DType)
15-
from cuda.tile._exception import TileTypeError
15+
from cuda.tile._bytecode.version import BytecodeVersion
16+
from cuda.tile._exception import TileTypeError, TileUnsupportedFeatureError
1617
from cuda.tile._ir.ops_utils import get_dtype
1718

1819
from .typing_support import datatype, get_signature
19-
from .ir import Var, TupleValue
20+
from .ir import Var, TupleValue, Builder
2021
from .type import TupleTy, TileTy, DTypeSpec, EnumTy, StringTy, ArrayTy, SliceType, \
2122
ListTy, LooselyTypedScalar, RangeIterType, FunctionTy, ClosureTy, BoundMethodTy, \
2223
DTypeConstructor, Type
@@ -36,9 +37,19 @@ def _verify_params_match(stub_sig: inspect.Signature, func_sig: inspect.Signatur
3637
op_implementations = dict()
3738

3839

39-
def impl(stub, *, fixed_args: Sequence[Any] = ()):
40+
def impl(stub, *, fixed_args: Sequence[Any] = (),
41+
min_version: Optional[BytecodeVersion] = None):
4042
stub_sig = get_signature(stub)
4143

44+
def _check_version():
45+
cur_version = Builder.get_current().ir_ctx.tileiras_version
46+
if min_version is not None and cur_version < min_version:
47+
raise TileUnsupportedFeatureError(
48+
f"{stub.__name__} requires tileiras "
49+
f"{min_version.major()}.{min_version.minor()} or later. "
50+
f"Current version is {cur_version.major()}.{cur_version.minor()}."
51+
)
52+
4253
def decorate(func):
4354
orig_func = func
4455
if len(fixed_args) > 0:
@@ -50,6 +61,7 @@ def decorate(func):
5061
if is_coroutine:
5162
@functools.wraps(func)
5263
async def wrapper(*args, **kwargs):
64+
_check_version()
5365
# Memorize the stub and the args so that we can automatically
5466
# provide context for error messages.
5567
old = _current_stub.stub_and_args
@@ -61,6 +73,7 @@ async def wrapper(*args, **kwargs):
6173
else:
6274
@functools.wraps(func)
6375
def wrapper(*args, **kwargs):
76+
_check_version()
6477
# Memorize the stub and the args so that we can automatically
6578
# provide context for error messages.
6679
old = _current_stub.stub_and_args

src/cuda/tile/_ir/ops.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@
4545
from .ops_utils import (
4646
BINOP_REGISTRY, UNARYOP_REGISTRY,
4747
check_rd_and_ftz, PaddingMode,
48-
rounding_mode_to_bytecode, get_dtype, change_dtype, memory_order_to_bytecode,
48+
rounding_mode_to_bytecode, get_default_rounding_mode, get_dtype,
49+
change_dtype, memory_order_to_bytecode,
4950
memory_scope_to_bytecode, broadcast_shapes2, is_shape_broadcastable_to, BroadcastError,
5051
promote_types, promote_dtypes, check_implicit_cast
5152
)
@@ -68,6 +69,7 @@
6869
get_list_partition_view_tile_size, tensor_view_typeid, tensor_view_typeid_for_list, dtype_typeid
6970
)
7071
import cuda.tile._bytecode as bc
72+
from cuda.tile._bytecode.version import BytecodeVersion
7173
from .._debug import CUDA_TILE_TESTING_DISABLE_DIV
7274

7375

@@ -689,10 +691,11 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
689691
lhs = ctx.cast(ctx.get_value(self.lhs), ctx.typeof(self.lhs), result_type)
690692
rhs = ctx.cast(ctx.get_value(self.rhs), ctx.typeof(self.rhs), result_type)
691693
acc = ctx.cast(ctx.get_value(self.acc), ctx.typeof(self.acc), result_type)
694+
rm = self.rounding_mode if self.rounding_mode is not None else get_default_rounding_mode()
692695
return bc.encode_FmaOp(ctx.builder,
693696
ctx.typeid_of(self.result_var),
694697
lhs, rhs, acc,
695-
rounding_mode_to_bytecode[self.rounding_mode],
698+
rounding_mode_to_bytecode[rm],
696699
self.flush_to_zero)
697700

698701

@@ -966,7 +969,8 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
966969
dtype = get_dtype(result_ty)
967970
kind = "float" if datatype.is_float(dtype) else "int"
968971
res_typeid = typeid(ctx.type_table, result_ty)
969-
rounding_mode = rounding_mode_to_bytecode[self.rounding_mode]
972+
rm = self.rounding_mode if self.rounding_mode is not None else get_default_rounding_mode()
973+
rounding_mode = rounding_mode_to_bytecode[rm]
970974
lhs = ctx.get_value(self.lhs)
971975
rhs = ctx.get_value(self.rhs)
972976

@@ -1006,6 +1010,8 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
10061010
flush_to_zero=self.flush_to_zero)
10071011
case "pow", "float":
10081012
return bc.encode_PowOp(ctx.builder, res_typeid, lhs, rhs)
1013+
case "atan2", "float":
1014+
return bc.encode_Atan2Op(ctx.builder, res_typeid, lhs, rhs)
10091015
case "min", "int":
10101016
return bc.encode_MinIOp(ctx.builder, res_typeid, lhs, rhs,
10111017
signedness=datatype.get_signedness(dtype))
@@ -1087,6 +1093,11 @@ def binary_arithmetic_impl_with_ftz(fn: str, x: Var, y: Var, flush_to_zero: Var)
10871093
return binary_arithmetic(fn, x, y, flush_to_zero=flush_to_zero)
10881094

10891095

1096+
@impl(ct.atan2, min_version=BytecodeVersion.V_13_2)
1097+
def atan2_impl(x1: Var, x2: Var) -> Var:
1098+
return binary_arithmetic("atan2", x1, x2)
1099+
1100+
10901101
@impl(ct.add, fixed_args=["add"])
10911102
@impl(ct.sub, fixed_args=["sub"])
10921103
@impl(ct.mul, fixed_args=["mul"])
@@ -1347,7 +1358,9 @@ def __init__(self, fn: str, operand: Var,
13471358
@override
13481359
def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
13491360
x = ctx.get_value(self.operand)
1350-
rounding_mode = rounding_mode_to_bytecode[self.rounding_mode]
1361+
rm = (self.rounding_mode if self.rounding_mode is not None
1362+
else get_default_rounding_mode(self.fn))
1363+
rounding_mode = rounding_mode_to_bytecode[rm]
13511364
flush_to_zero = self.flush_to_zero
13521365
input_type = ctx.typeof(self.operand)
13531366
input_dtype = get_dtype(input_type)
@@ -1368,9 +1381,8 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
13681381
case "sinh", True: return bc.encode_SinHOp(ctx.builder, res_type_id, x)
13691382
case "cosh", True: return bc.encode_CosHOp(ctx.builder, res_type_id, x)
13701383
case "tan", True: return bc.encode_TanOp(ctx.builder, res_type_id, x)
1371-
# TODO: rounding mode support depending on bytecode version
13721384
case "tanh", True: return bc.encode_TanHOp(ctx.builder, res_type_id, x,
1373-
rounding_mode=bc.RoundingMode.FULL)
1385+
rounding_mode=rounding_mode)
13741386
case "log", True: return bc.encode_LogOp(ctx.builder, res_type_id, x)
13751387
case "log2", True: return bc.encode_Log2Op(ctx.builder, res_type_id, x)
13761388
case "sqrt", True: return bc.encode_SqrtOp(ctx.builder, res_type_id, x,
@@ -1486,7 +1498,6 @@ def pos_impl(x: Var):
14861498
@impl(ct.log, fixed_args=["log", _UNARY_FLOAT])
14871499
@impl(ct.log2, fixed_args=["log2", _UNARY_FLOAT])
14881500
@impl(ct.tan, fixed_args=["tan", _UNARY_FLOAT])
1489-
@impl(ct.tanh, fixed_args=["tanh", _UNARY_FLOAT])
14901501
@impl(ct.sin, fixed_args=["sin", _UNARY_FLOAT])
14911502
@impl(ct.sinh, fixed_args=["sinh", _UNARY_FLOAT])
14921503
@impl(ct.cos, fixed_args=["cos", _UNARY_FLOAT])
@@ -1519,6 +1530,12 @@ def unary_impl_with_rd_and_ftz(fn: str, behavior: _UnaryBehavior,
15191530
return unary(fn, behavior, x, rounding_mode=rounding_mode, flush_to_zero=flush_to_zero)
15201531

15211532

1533+
@impl(ct.tanh, fixed_args=["tanh", _UNARY_FLOAT])
1534+
def unary_impl_with_rd(fn: str, behavior: _UnaryBehavior, x: Var, rounding_mode: Var) -> Var:
1535+
rounding_mode = require_optional_constant_enum(rounding_mode, RoundingMode)
1536+
return unary(fn, behavior, x, rounding_mode=rounding_mode)
1537+
1538+
15221539
@impl(getattr)
15231540
def getattr_impl(object: Var, name: Var) -> Var:
15241541
ty = object.get_type()

src/cuda/tile/_ir/ops_utils.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
import itertools
55
import math
66

7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
88
from typing import Optional, Tuple, Dict, Any, Sequence
99
from enum import Enum
1010

1111
from cuda.tile import _datatype as datatype
1212

13+
from cuda.tile._bytecode.version import BytecodeVersion
1314
from cuda.tile._numeric_semantics import RoundingMode, PaddingMode
14-
from cuda.tile._exception import Loc, TileTypeError, TileValueError
15+
from cuda.tile._exception import Loc, TileTypeError, TileValueError, TileUnsupportedFeatureError
1516
from cuda.tile._memory_model import MemoryOrder, MemoryScope
1617
import cuda.tile._bytecode as bc
1718

18-
from .ir import Operation
19+
from .ir import Operation, Builder
1920
from .type import TileTy, PointerTy, LooselyTypedScalar, make_tile_ty
2021
from .typing_support import typeof_pyval
2122
from .._datatype import DType, _DTypePromotionImpl, NumericDTypeCategory, NumericDTypeCategories, \
@@ -34,30 +35,29 @@ class ComparisonPredicates(Enum):
3435
@dataclass
3536
class MathOpDef:
3637
impl: callable # Python scalar fallback
37-
supported_rounding_modes: Tuple[RoundingMode, ...] = ()
38+
supported_rounding_modes: Dict[RoundingMode, Optional[BytecodeVersion]] = field(
39+
default_factory=dict)
3840
support_flush_to_zero: bool = False
3941

4042

43+
_RD_BASIC = {RoundingMode.RN: None, RoundingMode.RZ: None,
44+
RoundingMode.RM: None, RoundingMode.RP: None}
45+
_RD_TRUEDIV = {**_RD_BASIC, RoundingMode.FULL: None, RoundingMode.APPROX: None}
46+
_RD_SQRT = {**_RD_BASIC, RoundingMode.APPROX: None}
47+
_RD_TANH = {RoundingMode.FULL: None, RoundingMode.APPROX: BytecodeVersion.V_13_2}
48+
4149
BINOP_REGISTRY = {
42-
"add": MathOpDef(lambda x, y: x + y,
43-
(RoundingMode.RN, RoundingMode.RZ, RoundingMode.RM, RoundingMode.RP),
44-
support_flush_to_zero=True),
45-
"sub": MathOpDef(lambda x, y: x - y,
46-
(RoundingMode.RN, RoundingMode.RZ, RoundingMode.RM, RoundingMode.RP),
47-
support_flush_to_zero=True),
48-
"mul": MathOpDef(lambda x, y: x * y,
49-
(RoundingMode.RN, RoundingMode.RZ, RoundingMode.RM, RoundingMode.RP),
50-
support_flush_to_zero=True),
50+
"add": MathOpDef(lambda x, y: x + y, _RD_BASIC, support_flush_to_zero=True),
51+
"sub": MathOpDef(lambda x, y: x - y, _RD_BASIC, support_flush_to_zero=True),
52+
"mul": MathOpDef(lambda x, y: x * y, _RD_BASIC, support_flush_to_zero=True),
5153
"floordiv": MathOpDef(lambda x, y: x // y),
5254
"cdiv": MathOpDef(lambda x, y: (x + y - 1) // y),
53-
"truediv": MathOpDef(lambda x, y: x / y,
54-
(RoundingMode.RN, RoundingMode.RZ, RoundingMode.RM, RoundingMode.RP,
55-
RoundingMode.FULL, RoundingMode.APPROX),
56-
support_flush_to_zero=True),
55+
"truediv": MathOpDef(lambda x, y: x / y, _RD_TRUEDIV, support_flush_to_zero=True),
5756
"mod": MathOpDef(lambda x, y: x % y),
5857
"pow": MathOpDef(lambda x, y: x ** y),
59-
"max": MathOpDef(max, (), support_flush_to_zero=True),
60-
"min": MathOpDef(min, (), support_flush_to_zero=True),
58+
"atan2": MathOpDef(math.atan2),
59+
"max": MathOpDef(max, support_flush_to_zero=True),
60+
"min": MathOpDef(min, support_flush_to_zero=True),
6161
"and_": MathOpDef(lambda x, y: x & y),
6262
"or_": MathOpDef(lambda x, y: x | y),
6363
"xor": MathOpDef(lambda x, y: x ^ y),
@@ -81,30 +81,26 @@ class MathOpDef:
8181
"abs": MathOpDef(abs),
8282
"neg": MathOpDef(lambda x: -x),
8383
"exp": MathOpDef(math.exp),
84-
"exp2": MathOpDef(lambda x: 2 ** x, (), support_flush_to_zero=True),
84+
"exp2": MathOpDef(lambda x: 2 ** x, support_flush_to_zero=True),
8585
"sin": MathOpDef(math.sin),
8686
"sinh": MathOpDef(math.sinh),
8787
"cos": MathOpDef(math.cos),
8888
"cosh": MathOpDef(math.cosh),
8989
"tan": MathOpDef(math.tan),
90-
# TODO: RoundingMode support dependent on bytecode version
91-
"tanh": MathOpDef(math.tanh),
90+
"tanh": MathOpDef(math.tanh, _RD_TANH),
9291
"log": MathOpDef(math.log),
9392
"log2": MathOpDef(math.log2),
94-
"sqrt": MathOpDef(math.sqrt,
95-
(RoundingMode.RN, RoundingMode.RZ, RoundingMode.RM, RoundingMode.RP,
96-
RoundingMode.APPROX),
97-
support_flush_to_zero=True),
98-
"rsqrt": MathOpDef(lambda x: x ** -0.5, (), support_flush_to_zero=True),
93+
"sqrt": MathOpDef(math.sqrt, _RD_SQRT, support_flush_to_zero=True),
94+
"rsqrt": MathOpDef(lambda x: x ** -0.5, support_flush_to_zero=True),
9995
"invert": MathOpDef(lambda x: ~x),
10096
"not_": MathOpDef(lambda x: not x),
10197
"floor": MathOpDef(math.floor),
10298
"ceil": MathOpDef(math.ceil),
10399
}
104100

105101

106-
def get_default_rounding_mode():
107-
return RoundingMode.RN
102+
def get_default_rounding_mode(opname: Optional[str] = None):
103+
return RoundingMode.FULL if opname == 'tanh' else RoundingMode.RN
108104

109105

110106
rounding_mode_to_bytecode = {
@@ -117,8 +113,6 @@ def get_default_rounding_mode():
117113
RoundingMode.RZI: bc.RoundingMode.NEAREST_INT_TO_ZERO
118114
}
119115

120-
rounding_mode_to_bytecode[None] = rounding_mode_to_bytecode[get_default_rounding_mode()]
121-
122116

123117
def get_rounding_mode(op: Operation, constants: Dict[str, Any]) -> Optional[RoundingMode]:
124118
return (
@@ -146,6 +140,14 @@ def check_rd_and_ftz(fn: str, rounding_mode: Optional[RoundingMode], flush_to_ze
146140
if rounding_mode not in math_op_def.supported_rounding_modes:
147141
raise TileTypeError(
148142
f'Rounding mode {rounding_mode.value} is not supported for {fn}')
143+
min_version = math_op_def.supported_rounding_modes[rounding_mode]
144+
if min_version is not None:
145+
cur_version = Builder.get_current().ir_ctx.tileiras_version
146+
if cur_version < min_version:
147+
raise TileUnsupportedFeatureError(
148+
f'{fn} rounding_mode={rounding_mode.value} requires tileiras '
149+
f'{min_version.major()}.{min_version.minor()} or later. '
150+
f'Current version is {cur_version.major()}.{cur_version.minor()}.')
149151
if not datatype.is_float(dtype):
150152
raise TileTypeError(
151153
f'Rounding mode can only be used for float types, '

src/cuda/tile/_ir2bytecode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ def lower_scan(ctx: "BytecodeContext", x: bc.Value, input_ty: Type,
240240

241241
element_tile_typeid = tt.tile(element_type_id, ())
242242
with nested_builder.new_block((element_tile_typeid, element_tile_typeid)) as (a, b):
243-
rounding_mode_bc = rounding_mode_to_bytecode[rounding_mode]
243+
rm = rounding_mode if rounding_mode is not None else get_default_rounding_mode()
244+
rounding_mode_bc = rounding_mode_to_bytecode[rm]
244245
match scan_fn, use_float:
245246
case "add", True:
246247
res = bc.encode_AddFOp(ctx.builder, element_tile_typeid, a, b,

0 commit comments

Comments
 (0)