Skip to content

Commit 8c3c4b6

Browse files
committed
Add nested block to the TileReduce op
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 48ab3e3 commit 8c3c4b6

File tree

8 files changed

+263
-263
lines changed

8 files changed

+263
-263
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Fixed reductions with multiple axes specified in non-increasing order.

src/cuda/tile/_ir/ir.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,9 @@ def __init__(
527527
self,
528528
op: str,
529529
operands: dict[str, Optional[Var | Tuple[Var, ...]]],
530-
result_vars: List[Var],
530+
result_vars: Sequence[Var],
531531
attributes: Optional[Dict[str, Any]] = None,
532-
nested_blocks: Optional[List[Block]] = None,
532+
nested_blocks: Optional[Sequence[Block]] = None,
533533
loc: Loc = Loc.unknown(),
534534
):
535535
self.op = op

src/cuda/tile/_ir/ops.py

Lines changed: 223 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import operator
77
from contextlib import contextmanager
88
from dataclasses import dataclass
9-
from typing import Sequence, Tuple, Optional, Union, Any, List, Callable, Iterator
9+
from typing import Sequence, Tuple, Optional, Union, Any, List, Callable, Iterator, Iterable
1010

1111
from typing_extensions import override
1212

@@ -61,9 +61,7 @@
6161
DType, is_integral, is_float, is_signed, is_boolean, is_restricted_float,
6262
)
6363
from cuda.tile._ir2bytecode import (
64-
lower_reduce,
65-
lower_reduce_argmax_argmin, lower_scan,
66-
BytecodeContext, typeid,
64+
lower_scan, BytecodeContext, typeid,
6765
generate_bytecode_for_block, convert_dtype, get_list_item_repr_size_in_words,
6866
get_list_partition_view_tile_size, tensor_view_typeid, tensor_view_typeid_for_list
6967
)
@@ -2894,40 +2892,153 @@ def matmul(x: Var, y: Var) -> Var:
28942892

28952893

28962894
class TileReduce(TypedOperation):
2897-
def __init__(self, fn: str, x: Var, axis: int,
2898-
rounding_mode: Optional[RoundingMode], flush_to_zero: bool,
2899-
result_var: Var, loc: Loc):
2895+
def __init__(self, xs: tuple[Var, ...], identities: tuple[bool | int | float, ...], axis: int,
2896+
body: Block, result_vars: tuple[Var, ...], loc: Loc):
29002897
super().__init__(
29012898
"tile_reduce",
2902-
operands={"x": x},
2903-
attributes={
2904-
"fn": fn, "axis": axis,
2905-
"rounding_mode": rounding_mode,
2906-
"flush_to_zero": flush_to_zero,
2907-
},
2908-
result_vars=[result_var],
2899+
operands={"xs": xs},
2900+
attributes={"identities": identities, "axis": axis},
2901+
nested_blocks=[body],
2902+
result_vars=result_vars,
29092903
loc=loc,
29102904
)
29112905

2906+
@property
2907+
def body(self):
2908+
return self.nested_blocks[0]
2909+
29122910
@override
2913-
def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
2914-
x_type = ctx.typeof(self.x)
2915-
x_value = ctx.get_value(self.x)
2916-
res_type = ctx.typeof(self.result_var)
2917-
return lower_reduce(
2918-
ctx, x_value, x_type, self.axis, res_type, self.fn,
2919-
self.rounding_mode, self.flush_to_zero
2911+
def _to_string_block_prefixes(self) -> List[str]:
2912+
return ["do"]
2913+
2914+
@override
2915+
def generate_bytecode(self, ctx: BytecodeContext) -> tuple[bc.Value, ...]:
2916+
xs = tuple(ctx.get_value(x) for x in self.xs)
2917+
res_typeids = tuple(ctx.typeid_of(v) for v in self.result_vars)
2918+
2919+
identities = []
2920+
param_type_ids = []
2921+
for id_val, x in zip(self.identities, self.xs, strict=True):
2922+
x_dtype = get_dtype(x.get_type())
2923+
x_dtype_id = typeid(ctx.type_table, x_dtype, wrap_scalars=False)
2924+
if datatype.is_float(x_dtype):
2925+
x_dtype_bc = x_dtype._bytecode_type
2926+
attr = bc.Float(float(id_val), x_dtype_bc, ctx.type_table)
2927+
elif datatype.is_boolean(x_dtype):
2928+
attr = bc.Bool(bool(id_val))
2929+
else:
2930+
assert datatype.is_integral(x_dtype)
2931+
attr = bc.Integer(x_dtype_id, x_dtype.bitwidth, int(id_val))
2932+
identities.append(attr)
2933+
2934+
x_tile_typeid = ctx.type_table.tile(x_dtype_id, ())
2935+
param_type_ids.append(x_tile_typeid)
2936+
param_type_ids.append(x_tile_typeid)
2937+
2938+
nested_builder = bc.encode_ReduceOp(
2939+
ctx.builder,
2940+
result_types=res_typeids,
2941+
operands=xs,
2942+
dim=self.axis,
2943+
identities=identities
29202944
)
29212945

2946+
with nested_builder.new_block(param_type_ids) as block_args:
2947+
for var, value in zip(self.body.params, block_args, strict=True):
2948+
ctx.set_value(var, value)
2949+
generate_bytecode_for_block(ctx, self.body)
2950+
2951+
return nested_builder.done()
2952+
2953+
2954+
def raw_reduce(xs: tuple[Var, ...], identities: tuple[bool | int | float], axis: int,
2955+
body: Callable[[tuple[Var, ...], tuple[Var, ...]], tuple[Var, ...]]
2956+
) -> tuple[Var, ...]:
2957+
builder = Builder.get_current()
2958+
2959+
block_params = []
2960+
lhs_vars = []
2961+
rhs_vars = []
2962+
input_shape = ()
2963+
for i, x in enumerate(xs):
2964+
x_ty = x.get_type()
2965+
assert isinstance(x_ty, TileTy)
2966+
if i == 0:
2967+
input_shape = x_ty.shape_value
2968+
else:
2969+
assert input_shape == x_ty.shape_value
2970+
tile_0d_ty = make_tile_ty(x_ty.dtype, ())
2971+
for _ in range(2):
2972+
var = builder.ir_ctx.make_temp(builder.loc)
2973+
var.set_type(tile_0d_ty)
2974+
block_params.append(var)
2975+
lhs_vars.append(block_params[-2])
2976+
rhs_vars.append(block_params[-1])
2977+
2978+
assert 0 <= axis < len(input_shape)
2979+
result_shape = input_shape[:axis] + input_shape[axis+1:]
2980+
result_types = tuple(make_tile_ty(x.get_type().dtype, result_shape) for x in xs)
2981+
2982+
assert len(xs) == len(identities)
2983+
2984+
with nested_block(builder.loc) as body_block:
2985+
body_block.params = tuple(block_params)
2986+
body_results = body(tuple(lhs_vars), tuple(rhs_vars))
2987+
for body_res, x in zip(body_results, xs, strict=True):
2988+
body_res_ty = body_res.get_type()
2989+
assert isinstance(body_res_ty, TileTy)
2990+
assert body_res_ty.shape_value == ()
2991+
assert body_res_ty.dtype == x.get_type().dtype
2992+
2993+
add_operation(EndBranch, (), outputs=body_results)
2994+
2995+
return add_operation(TileReduce, result_types, xs=xs, identities=identities, axis=axis,
2996+
body=body_block)
2997+
2998+
2999+
def reduce(xs: tuple[Var, ...], identities: tuple[bool | int | float, ...],
3000+
axis: int | None | Iterable[int], keepdims: bool,
3001+
body: Callable[[tuple[Var, ...], tuple[Var, ...]], tuple[Var, ...]]
3002+
) -> tuple[Var, ...]:
3003+
if len(xs) == 0:
3004+
raise TileTypeError("Need at least one input value to reduce")
3005+
3006+
if len(xs) != len(identities):
3007+
raise TileTypeError(f"Number of input values ({len(xs)}) doesn't match the"
3008+
f" number of identities ({len(identities)})")
3009+
3010+
common_input_shape = ()
3011+
3012+
x_types = tuple(require_tile_type(x) for x in xs)
3013+
for x_ty in x_types:
3014+
try:
3015+
common_input_shape = broadcast_shapes2(common_input_shape, x_ty.shape_value)
3016+
except BroadcastError:
3017+
all_shapes = ", ".join(str(ty.shape_value) for ty in x_types)
3018+
raise TileTypeError(f"Input shapes {all_shapes}"
3019+
f" are not broadcastable to a common shape")
3020+
3021+
if axis is None:
3022+
axis = tuple(range(len(common_input_shape)))
3023+
else:
3024+
if isinstance(axis, int):
3025+
axis = (axis,)
3026+
axis = sorted(normalize_axis(a, len(common_input_shape)) for a in axis)
3027+
for a1, a2 in zip(axis, axis[1:]):
3028+
if a1 == a2:
3029+
raise TileTypeError(f"Repeated reduction axis {a1}")
3030+
3031+
xs = tuple(broadcast_to(x, common_input_shape) for x in xs)
3032+
for i, a in enumerate(axis):
3033+
xs = raw_reduce(xs, identities, a - i, body)
3034+
3035+
result_shape = _get_reduction_shape(common_input_shape, axis, keepdims)
3036+
return tuple(reshape(x, result_shape) for x in xs)
3037+
29223038

29233039
def _get_reduction_shape(shape: Tuple[int, ...],
2924-
normalized_axis: int | Tuple[int, ...] | None,
3040+
normalized_axis: Tuple[int, ...],
29253041
keepdims: bool) -> Tuple[int, ...]:
2926-
if normalized_axis is None:
2927-
normalized_axis = tuple(range(len(shape)))
2928-
if isinstance(normalized_axis, int):
2929-
normalized_axis = (normalized_axis,)
2930-
normalized_axis = set(normalized_axis)
29313042
ret = []
29323043
for i, size in enumerate(shape):
29333044
if i in normalized_axis:
@@ -2938,29 +3049,46 @@ def _get_reduction_shape(shape: Tuple[int, ...],
29383049
return tuple(ret)
29393050

29403051

2941-
def reduce(fn: str, x: Var, axis: Optional[tuple[int, ...]], keepdims: bool,
2942-
rounding_mode: Optional[RoundingMode] = None,
2943-
flush_to_zero: bool = False) -> Var:
3052+
def reduce_simple(fn: str, x: Var, axis: int | None | tuple[int, ...], keepdims: bool,
3053+
rounding_mode: Optional[RoundingMode] = None,
3054+
flush_to_zero: bool = False) -> Var:
29443055
x_type = require_tile_type(x)
29453056
check_rd_and_ftz(fn, rounding_mode, flush_to_zero, x_type.dtype)
2946-
x_shape = x_type.shape
2947-
rank = len(x_shape)
2948-
if axis is None:
2949-
axis = tuple(range(rank))
2950-
else:
2951-
axis = tuple([normalize_axis(axis_value, rank) for axis_value in axis])
29523057

2953-
x_dtype = datatype.default_int_type if datatype.is_boolean(x_type.dtype) else x_type.dtype
2954-
x = _promote_and_broadcast_to(x, TileTy(x_dtype, x_shape))
2955-
for i, axis_value in enumerate(axis):
2956-
axis_value -= i
2957-
x_shape = x_shape[:axis_value] + x_shape[axis_value + 1:]
2958-
x = add_operation(
2959-
TileReduce, TileTy(x_dtype, TupleTy(x_shape)),
2960-
fn=fn, x=x, axis=axis_value,
2961-
rounding_mode=rounding_mode, flush_to_zero=flush_to_zero
2962-
)
2963-
return reshape(x, _get_reduction_shape(x_type.shape_value, axis, keepdims))
3058+
if datatype.is_boolean(x_type.dtype):
3059+
x = astype(x, datatype.default_int_type)
3060+
3061+
match fn:
3062+
case "add": id_val = 0
3063+
case "mul": id_val = 1
3064+
case "min": id_val = _get_min_max(x_type.dtype)[1]
3065+
case "max": id_val = _get_min_max(x_type.dtype)[0]
3066+
case _: assert False
3067+
3068+
def body(lhs: tuple[Var], rhs: tuple[Var]) -> tuple[Var]:
3069+
[lhs], [rhs] = lhs, rhs
3070+
ret = raw_binary_arithmetic(fn, lhs, rhs,
3071+
rounding_mode=rounding_mode, flush_to_zero=flush_to_zero)
3072+
return (ret,)
3073+
3074+
[ret] = reduce((x,), (id_val,), axis, keepdims, body)
3075+
return ret
3076+
3077+
3078+
Limits = Tuple[float, float] | Tuple[int, int]
3079+
3080+
3081+
def _get_min_max(dtype: datatype.DType) -> Limits:
3082+
use_float = datatype.is_float(dtype)
3083+
if use_float:
3084+
if dtype in [datatype.float16, datatype.bfloat16, datatype.float32, datatype.float64]:
3085+
return -float("inf"), float("inf")
3086+
else:
3087+
raise NotImplementedError(f"Unsupported float dtype: {dtype}")
3088+
elif datatype.is_signed(dtype):
3089+
return -(1 << (dtype.bitwidth-1)), (1 << (dtype.bitwidth-1)) - 1
3090+
else:
3091+
return 0, (1 << dtype.bitwidth) - 1
29643092

29653093

29663094
def _parse_reduce_axis(axis: Var) -> Optional[tuple[int, ...]]:
@@ -2981,7 +3109,8 @@ def reduce_impl_with_rd_and_ftz(fn: str, x: Var, axis: Var, keepdims: Var, round
29813109
keepdims = require_constant_bool(keepdims)
29823110
rounding_mode = require_optional_constant_enum(rounding_mode, RoundingMode)
29833111
flush_to_zero = require_constant_bool(flush_to_zero)
2984-
return reduce(fn, x, axis, keepdims, rounding_mode=rounding_mode, flush_to_zero=flush_to_zero)
3112+
return reduce_simple(fn, x, axis, keepdims,
3113+
rounding_mode=rounding_mode, flush_to_zero=flush_to_zero)
29853114

29863115

29873116
@impl(ct.max, fixed_args=["max"])
@@ -2990,53 +3119,63 @@ def reduce_impl_with_ftz(fn: str, x: Var, axis: Var, keepdims: Var, flush_to_zer
29903119
axis = _parse_reduce_axis(axis)
29913120
keepdims = require_constant_bool(keepdims)
29923121
flush_to_zero = require_constant_bool(flush_to_zero)
2993-
return reduce(fn, x, axis, keepdims, flush_to_zero=flush_to_zero)
3122+
return reduce_simple(fn, x, axis, keepdims, flush_to_zero=flush_to_zero)
29943123

29953124

2996-
class TileArgReduce(TypedOperation):
2997-
def __init__(self, fn: str, x: Var, axis: Optional[int],
2998-
result_var: Var, loc: Loc):
2999-
super().__init__(
3000-
"tile_arg_reduce",
3001-
operands={"x": x},
3002-
attributes={"fn": fn, "axis": axis},
3003-
result_vars=[result_var],
3004-
loc=loc,
3005-
)
3006-
3007-
@override
3008-
def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
3009-
x_type = ctx.typeof(self.x)
3010-
x_value = ctx.get_value(self.x)
3011-
res_type = ctx.typeof(self.result_var)
3012-
return lower_reduce_argmax_argmin(
3013-
ctx, x_value, x_type, self.axis, res_type, self.fn
3014-
)
3125+
def argmax_argmin(fn: str, x: Var, axis: Optional[int], keepdims: bool) -> Var:
3126+
require_tile_type(x)
3127+
final_shape = None
3128+
if axis is None:
3129+
if keepdims:
3130+
final_shape = (1,) * x.get_type().ndim
3131+
keepdims = False
3132+
x = reshape(x, (-1,))
3133+
axis = 0
3134+
else:
3135+
axis = normalize_axis(axis, x.get_type().ndim)
30153136

3137+
if datatype.is_boolean(x.get_type().dtype):
3138+
x = astype(x, datatype.default_int_type)
30163139

3017-
def argreduce(fn: str, x: Var, axis: Optional[int], keepdims: bool) -> Var:
3018-
x_type = require_tile_type(x)
3019-
x_shape = x_type.shape
3020-
if axis is not None:
3021-
axis = normalize_axis(axis, len(x_shape))
3140+
x_type = x.get_type()
3141+
indices = arange(x_type.shape_value[axis], datatype.default_int_type)
3142+
indices = reshape(indices, tuple(-1 if i == axis else 1 for i in range(x_type.ndim)))
3143+
3144+
match fn:
3145+
case "argmin":
3146+
id_val = _get_min_max(x_type.dtype)[1]
3147+
cmp = "lt"
3148+
case "argmax":
3149+
id_val = _get_min_max(x_type.dtype)[0]
3150+
cmp = "gt"
3151+
case _: assert False
3152+
3153+
def body(lhs: tuple[Var, Var], rhs: tuple[Var, Var]) -> tuple[Var, Var]:
3154+
lhs_val, lhs_idx = lhs
3155+
rhs_val, rhs_idx = rhs
3156+
val_strict = raw_comparison(cmp, lhs_val, rhs_val)
3157+
val_equal = raw_comparison("eq", lhs_val, rhs_val)
3158+
index_lt = raw_comparison("lt", lhs_idx, rhs_idx)
3159+
val_equal_and_index_lt = raw_binary_bitwise("and_", val_equal, index_lt)
3160+
cond = raw_binary_bitwise("or_", val_strict, val_equal_and_index_lt)
3161+
res = raw_where(cond, lhs_val, rhs_val)
3162+
idx = raw_where(cond, lhs_idx, rhs_idx)
3163+
return res, idx
3164+
3165+
[_, ret] = reduce((x, indices), (id_val, 0), axis, keepdims, body)
3166+
3167+
if final_shape is not None:
3168+
ret = reshape(ret, final_shape)
30223169

3023-
x_dtype = datatype.default_int_type if datatype.is_boolean(x_type.dtype) else x_type.dtype
3024-
x = _promote_and_broadcast_to(x, TileTy(x_dtype, x_shape))
3025-
output_dtype = datatype.default_int_type
3026-
output_shape = TupleTy([]) if axis is None else TupleTy(x_shape[:axis] + x_shape[axis + 1:])
3027-
x = add_operation(
3028-
TileArgReduce, TileTy(output_dtype, output_shape),
3029-
fn=fn, x=x, axis=axis
3030-
)
3031-
return reshape(x, _get_reduction_shape(x_type.shape_value, axis, keepdims))
3170+
return ret
30323171

30333172

30343173
@impl(ct.argmax, fixed_args=["argmax"])
30353174
@impl(ct.argmin, fixed_args=["argmin"])
3036-
def argreduce_impl(fn: str, x: Var, axis: Var, keepdims: Var) -> Var:
3175+
def argmax_argmin_impl(fn: str, x: Var, axis: Var, keepdims: Var) -> Var:
30373176
axis = require_optional_constant_int(axis)
30383177
keepdims = require_constant_bool(keepdims)
3039-
return argreduce(fn, x, axis, keepdims)
3178+
return argmax_argmin(fn, x, axis, keepdims)
30403179

30413180

30423181
class TileScan(TypedOperation):

src/cuda/tile/_ir/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def __str__(self):
283283
return f"Tile[{self.dtype},{shape_str}]"
284284

285285

286-
def make_tile_ty(dtype, shape: Sequence[int]):
286+
def make_tile_ty(dtype, shape: Sequence[int]) -> TileTy:
287287
shape = TupleTy(tuple(SizeTy(x) for x in shape))
288288
return TileTy(dtype, shape)
289289

0 commit comments

Comments
 (0)