44import itertools
55import math
66
7- from dataclasses import dataclass
7+ from dataclasses import dataclass , field
88from typing import Optional , Tuple , Dict , Any , Sequence
99from enum import Enum
1010
1111from cuda .tile import _datatype as datatype
1212
13+ from cuda .tile ._bytecode .version import BytecodeVersion
1314from cuda .tile ._numeric_semantics import RoundingMode , PaddingMode
14- from cuda .tile ._exception import Loc , TileTypeError , TileValueError
15+ from cuda .tile ._exception import Loc , TileTypeError , TileValueError , TileUnsupportedFeatureError
1516from cuda .tile ._memory_model import MemoryOrder , MemoryScope
1617import cuda .tile ._bytecode as bc
1718
18- from .ir import Operation
19+ from .ir import Operation , Builder
1920from .type import TileTy , PointerTy , LooselyTypedScalar , make_tile_ty
2021from .typing_support import typeof_pyval
2122from .._datatype import DType , _DTypePromotionImpl , NumericDTypeCategory , NumericDTypeCategories , \
@@ -34,30 +35,29 @@ class ComparisonPredicates(Enum):
3435@dataclass
3536class MathOpDef :
3637 impl : callable # Python scalar fallback
37- supported_rounding_modes : Tuple [RoundingMode , ...] = ()
38+ supported_rounding_modes : Dict [RoundingMode , Optional [BytecodeVersion ]] = field (
39+ default_factory = dict )
3840 support_flush_to_zero : bool = False
3941
4042
43+ _RD_BASIC = {RoundingMode .RN : None , RoundingMode .RZ : None ,
44+ RoundingMode .RM : None , RoundingMode .RP : None }
45+ _RD_TRUEDIV = {** _RD_BASIC , RoundingMode .FULL : None , RoundingMode .APPROX : None }
46+ _RD_SQRT = {** _RD_BASIC , RoundingMode .APPROX : None }
47+ _RD_TANH = {RoundingMode .FULL : None , RoundingMode .APPROX : BytecodeVersion .V_13_2 }
48+
4149BINOP_REGISTRY = {
42- "add" : MathOpDef (lambda x , y : x + y ,
43- (RoundingMode .RN , RoundingMode .RZ , RoundingMode .RM , RoundingMode .RP ),
44- support_flush_to_zero = True ),
45- "sub" : MathOpDef (lambda x , y : x - y ,
46- (RoundingMode .RN , RoundingMode .RZ , RoundingMode .RM , RoundingMode .RP ),
47- support_flush_to_zero = True ),
48- "mul" : MathOpDef (lambda x , y : x * y ,
49- (RoundingMode .RN , RoundingMode .RZ , RoundingMode .RM , RoundingMode .RP ),
50- support_flush_to_zero = True ),
50+ "add" : MathOpDef (lambda x , y : x + y , _RD_BASIC , support_flush_to_zero = True ),
51+ "sub" : MathOpDef (lambda x , y : x - y , _RD_BASIC , support_flush_to_zero = True ),
52+ "mul" : MathOpDef (lambda x , y : x * y , _RD_BASIC , support_flush_to_zero = True ),
5153 "floordiv" : MathOpDef (lambda x , y : x // y ),
5254 "cdiv" : MathOpDef (lambda x , y : (x + y - 1 ) // y ),
53- "truediv" : MathOpDef (lambda x , y : x / y ,
54- (RoundingMode .RN , RoundingMode .RZ , RoundingMode .RM , RoundingMode .RP ,
55- RoundingMode .FULL , RoundingMode .APPROX ),
56- support_flush_to_zero = True ),
55+ "truediv" : MathOpDef (lambda x , y : x / y , _RD_TRUEDIV , support_flush_to_zero = True ),
5756 "mod" : MathOpDef (lambda x , y : x % y ),
5857 "pow" : MathOpDef (lambda x , y : x ** y ),
59- "max" : MathOpDef (max , (), support_flush_to_zero = True ),
60- "min" : MathOpDef (min , (), support_flush_to_zero = True ),
58+ "atan2" : MathOpDef (math .atan2 ),
59+ "max" : MathOpDef (max , support_flush_to_zero = True ),
60+ "min" : MathOpDef (min , support_flush_to_zero = True ),
6161 "and_" : MathOpDef (lambda x , y : x & y ),
6262 "or_" : MathOpDef (lambda x , y : x | y ),
6363 "xor" : MathOpDef (lambda x , y : x ^ y ),
@@ -81,30 +81,26 @@ class MathOpDef:
8181 "abs" : MathOpDef (abs ),
8282 "neg" : MathOpDef (lambda x : - x ),
8383 "exp" : MathOpDef (math .exp ),
84- "exp2" : MathOpDef (lambda x : 2 ** x , (), support_flush_to_zero = True ),
84+ "exp2" : MathOpDef (lambda x : 2 ** x , support_flush_to_zero = True ),
8585 "sin" : MathOpDef (math .sin ),
8686 "sinh" : MathOpDef (math .sinh ),
8787 "cos" : MathOpDef (math .cos ),
8888 "cosh" : MathOpDef (math .cosh ),
8989 "tan" : MathOpDef (math .tan ),
90- # TODO: RoundingMode support dependent on bytecode version
91- "tanh" : MathOpDef (math .tanh ),
90+ "tanh" : MathOpDef (math .tanh , _RD_TANH ),
9291 "log" : MathOpDef (math .log ),
9392 "log2" : MathOpDef (math .log2 ),
94- "sqrt" : MathOpDef (math .sqrt ,
95- (RoundingMode .RN , RoundingMode .RZ , RoundingMode .RM , RoundingMode .RP ,
96- RoundingMode .APPROX ),
97- support_flush_to_zero = True ),
98- "rsqrt" : MathOpDef (lambda x : x ** - 0.5 , (), support_flush_to_zero = True ),
93+ "sqrt" : MathOpDef (math .sqrt , _RD_SQRT , support_flush_to_zero = True ),
94+ "rsqrt" : MathOpDef (lambda x : x ** - 0.5 , support_flush_to_zero = True ),
9995 "invert" : MathOpDef (lambda x : ~ x ),
10096 "not_" : MathOpDef (lambda x : not x ),
10197 "floor" : MathOpDef (math .floor ),
10298 "ceil" : MathOpDef (math .ceil ),
10399}
104100
105101
106- def get_default_rounding_mode ():
107- return RoundingMode .RN
102+ def get_default_rounding_mode (opname : Optional [ str ] = None ):
103+ return RoundingMode .FULL if opname == 'tanh' else RoundingMode . RN
108104
109105
110106rounding_mode_to_bytecode = {
@@ -117,8 +113,6 @@ def get_default_rounding_mode():
117113 RoundingMode .RZI : bc .RoundingMode .NEAREST_INT_TO_ZERO
118114}
119115
120- rounding_mode_to_bytecode [None ] = rounding_mode_to_bytecode [get_default_rounding_mode ()]
121-
122116
123117def get_rounding_mode (op : Operation , constants : Dict [str , Any ]) -> Optional [RoundingMode ]:
124118 return (
@@ -146,6 +140,14 @@ def check_rd_and_ftz(fn: str, rounding_mode: Optional[RoundingMode], flush_to_ze
146140 if rounding_mode not in math_op_def .supported_rounding_modes :
147141 raise TileTypeError (
148142 f'Rounding mode { rounding_mode .value } is not supported for { fn } ' )
143+ min_version = math_op_def .supported_rounding_modes [rounding_mode ]
144+ if min_version is not None :
145+ cur_version = Builder .get_current ().ir_ctx .tileiras_version
146+ if cur_version < min_version :
147+ raise TileUnsupportedFeatureError (
148+ f'{ fn } rounding_mode={ rounding_mode .value } requires tileiras '
149+ f'{ min_version .major ()} .{ min_version .minor ()} or later. '
150+ f'Current version is { cur_version .major ()} .{ cur_version .minor ()} .' )
149151 if not datatype .is_float (dtype ):
150152 raise TileTypeError (
151153 f'Rounding mode can only be used for float types, '
0 commit comments