Skip to content

Commit e5d47f0

Browse files
committed
Virtualize unflatten_aggregate() via ImplRegistry
This gets rid of MakeTensorView ops in cuda.lang.

Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 4e04dcb commit e5d47f0

7 files changed

Lines changed: 142 additions & 119 deletions

File tree

experimental/cuda-lang/src/cuda/lang/_compile.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def get_function_ir(
129129
if constant_mask is None:
130130
constant_mask = [False] * len(signature.parameters)
131131
parameter_names = function.signature.parameters.keys()
132-
with ir.TileBuilder(ctx, function.body.loc) as builder:
132+
with ir.TileBuilder(ctx, function.body.loc) as builder, cuda_lang_impl_registry.as_current():
133133
params = _create_kernel_parameters(
134134
signature.parameters,
135135
constant_mask,
@@ -138,8 +138,7 @@ def get_function_ir(
138138
ctx
139139
)
140140
canonicalize_parameters(params, builder)
141-
with cuda_lang_impl_registry.as_current():
142-
hir2ir(function, params.aggregate_vars, ctx)
141+
hir2ir(function, params.aggregate_vars, ctx)
143142
func_body = ctx.make_block("entry", function.body.loc)
144143
func_body.params = sum((vars for vars, _ in params.nonconstant_flat_vars), ())
145144
func_body.extend(builder.ops)

experimental/cuda-lang/src/cuda/lang/_passes/ir2mlir/pass_definition.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -282,7 +282,6 @@ def _get_mlir_comparison_op(
282282
# These operations have aggregate results. The RHS's elements are stored to
283283
# the LHS's when lowering Assign operations and are no-ops at the MLIR level.
284284
_NOOP_LOWERINGS = frozenset([
285-
ops.MakeTensorView,
286285
ops.ReinterpretPointerAsArray,
287286
])
288287

experimental/cuda-lang/test/passes/test_flatten_cfg.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@ def test_kernel(A):
2222
else:
2323
A[0] = 0
2424

25-
# BEFORE: A{{.+}}: Array[int32,(?):(?)] = make_tensor_view
2625
# BEFORE: $[[ITEM:[0-9]+]]: Tile[int32,()] = load_pointer
2726
# BEFORE: $[[ITEM_CASTED:[0-9]+]]: Tile[bool_,()] = tile_astype(x=$[[ITEM]])
2827
# BEFORE: if(cond=$[[ITEM_CASTED]])
@@ -37,7 +36,6 @@ def test_kernel(A):
3736
filecheck(str(body), get_source(), ("BEFORE",))
3837

3938
# AFTER: ^entry({{.+}}):
40-
# AFTER: A{{.+}}: Array[int32,(?):(?)] = make_tensor_view
4139
# AFTER: $[[ITEM:[0-9]+]]: Tile[int32,()] = load_pointer
4240
# AFTER: $[[ITEM_CASTED:[0-9]+]]: Tile[bool_,()] = tile_astype(x=$[[ITEM]])
4341
# AFTER: cond_br $[[ITEM_CASTED]]: Tile[bool_,()] ^then() ^else()

src/cuda/tile/_compile.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -270,12 +270,12 @@ def get_final_ir(self, signature_index: int) -> ir.Block:
270270
tileiras_version=self.bytecode_version,
271271
typing_hooks=_TileTypingHooks())
272272
with ir.Builder(ir_ctx, self._func_hir.body.loc) as ir_builder:
273-
params = _create_kernel_parameters(sig.parameters,
274-
self.ann_func.constant_parameter_mask,
275-
param_names,
276-
self._func_hir.param_locs,
277-
ir_ctx)
278273
with tile_impl_registry.as_current():
274+
params = _create_kernel_parameters(sig.parameters,
275+
self.ann_func.constant_parameter_mask,
276+
param_names,
277+
self._func_hir.param_locs,
278+
ir_ctx)
279279
hir2ir(self._func_hir, params.aggregate_vars, ir_ctx)
280280

281281
func_body = ir.Block(ir_ctx, self._func_hir.body.loc)
src/cuda/tile/_ir/aggregate_support.py

Lines changed: 95 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,95 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Sequence, Iterator
6+
7+
from cuda.tile._ir.ir import Var, Builder
8+
from cuda.tile._ir.op_impl import ImplRegistry
9+
from cuda.tile._ir.type import Type, InvalidType
10+
11+
12+
def flatten_aggregates(vars: Sequence[Var], types: Sequence[Type]) -> tuple[Var, ...]:
13+
ret = []
14+
for x, ty in zip(vars, types, strict=True):
15+
item_types = tuple(ty.flatten_aggregate())
16+
x_ty = x.get_type_allow_invalid()
17+
if isinstance(x_ty, InvalidType):
18+
for _ in item_types:
19+
t = x.ctx.make_temp(x.loc)
20+
t.set_type(x_ty)
21+
ret.append(t)
22+
else:
23+
items = tuple(x.flatten_aggregate())
24+
assert len(items) == len(item_types)
25+
ret.extend(items)
26+
return tuple(ret)
27+
28+
29+
def flatten_aggregate_types(types: Sequence[Type]) -> tuple[Type, ...]:
30+
ret = []
31+
for ty in types:
32+
ret.extend(ty.flatten_aggregate())
33+
return tuple(ret)
34+
35+
36+
def unflatten_aggregates(flattened: tuple[Var, ...],
37+
nominal: Sequence[Type], actual: Sequence[Type]) -> tuple[Var, ...]:
38+
it = iter(flattened)
39+
ret = tuple(_maybe_unflatten_aggregate(it, n, a) for n, a in zip(nominal, actual, strict=True))
40+
assert next(it, None) is None
41+
return ret
42+
43+
44+
def _maybe_unflatten_aggregate(flattened_iter: Iterator[Var], nominal: Type, actual: Type) -> Var:
45+
if not nominal.is_aggregate():
46+
return next(flattened_iter)
47+
return _unflatten_proper_aggregate(flattened_iter, nominal, actual, result_var=None)
48+
49+
50+
def expand_aggregate_var(var: Var) -> tuple[Var, ...]:
51+
item_types = tuple(var.get_type().flatten_aggregate())
52+
ret = tuple(var.ctx.make_var(f"{var.get_original_name()}_{i}", var.loc)
53+
for i in range(len(item_types)))
54+
for item, item_ty in zip(ret, item_types, strict=True):
55+
item.set_type(item_ty)
56+
return ret
57+
58+
59+
def flatten_block_parameters(vars: Sequence[Var]) -> list[tuple[Var, ...]]:
60+
ret = []
61+
for v in vars:
62+
ty = v.get_type_allow_invalid()
63+
if ty.is_aggregate():
64+
flattened_vars = expand_aggregate_var(v)
65+
ret.append(flattened_vars)
66+
it = iter(flattened_vars)
67+
_unflatten_proper_aggregate(it, ty, ty, v)
68+
assert next(it, None) is None
69+
else:
70+
ret.append((v,))
71+
return ret
72+
73+
74+
def _unflatten_proper_aggregate(flattened_iter: Iterator[Var], nominal: Type, actual: Type,
75+
result_var: Var | None) -> Var:
76+
nominal_item_types = nominal.aggregate_item_types()
77+
if isinstance(actual, InvalidType):
78+
# Pop values from the iterator and throw them out
79+
for _ in nominal_item_types:
80+
next(flattened_iter)
81+
builder = Builder.get_current()
82+
t = builder.ir_ctx.make_temp(builder.loc)
83+
t.set_type(actual)
84+
return t
85+
86+
items = tuple(_maybe_unflatten_aggregate(flattened_iter, item_nominal, item_actual)
87+
for item_nominal, item_actual
88+
in zip(nominal_item_types, actual.aggregate_item_types(), strict=True))
89+
val = nominal.make_aggregate_value(items)
90+
91+
impl = ImplRegistry.get_current().unflatten_aggregate_implementations.get(type(nominal))
92+
if impl is None:
93+
return Builder.get_current().make_aggregate(val, nominal, result_var=result_var)
94+
else:
95+
return impl(val, nominal, result_var)

src/cuda/tile/_ir/op_impl.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@ class ImplRegistry:
5959
def __init__(self):
6060
self.op_implementations = dict()
6161
self._overloaded_implementations = defaultdict(dict)
62+
self.unflatten_aggregate_implementations = dict()
6263

6364
@staticmethod
6465
def get_current() -> "ImplRegistry":
@@ -79,6 +80,7 @@ def update(self, source: "ImplRegistry"):
7980
self.op_implementations.update(source.op_implementations)
8081
for stub, overloads in source._overloaded_implementations.items():
8182
self._overloaded_implementations[stub].update(overloads)
83+
self.unflatten_aggregate_implementations.update(source.unflatten_aggregate_implementations)
8284

8385
def overload_dispatcher(self, stub, *, fixed_args: Sequence[Any] = ()):
8486
"""
@@ -215,6 +217,12 @@ def _have_overload_matching_first_param(self, stub: Callable, first_param: Any)
215217
return any(predicates[0](first_param)
216218
for _priority, predicates, _impl in candidates.values())
217219

220+
def unflatten_aggregate_impl(self, aggregate_type_class: type[Type]):
221+
def decorate(func):
222+
self.unflatten_aggregate_implementations[aggregate_type_class] = func
223+
return func
224+
return decorate
225+
218226

219227
def _predicate_from_overload_pattern(pattern):
220228
if pattern == WILDCARD:

src/cuda/tile/_ir/ops.py

Lines changed: 32 additions & 108 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import operator
77
from dataclasses import dataclass
88
from typing import (
9-
Literal, Sequence, Tuple, Optional, Any, List, Callable, Iterator, Iterable,
9+
Literal, Sequence, Tuple, Optional, Any, List, Callable, Iterable,
1010
)
1111

1212
from typing_extensions import override
@@ -21,6 +21,8 @@
2121
PhiState, LoopVarState, make_aggregate, ConstantState, MemoryEffect, attribute, operand,
2222
BlockRestriction, add_operation_variadic,
2323
)
24+
from .aggregate_support import flatten_block_parameters, expand_aggregate_var, \
25+
flatten_aggregate_types, flatten_aggregates, unflatten_aggregates
2426
from .arithmetic_ops import reshape, broadcast_to, astype, compare_tensorlike, \
2527
binary_bitwise_tensorlike, bitwise_shift_tensorlike, binary_arithmetic_tensorlike, \
2628
compare_tensorlike_raw, where, binary_bitwise_tensorlike_raw, where_raw, TileReshape, \
@@ -1218,113 +1220,6 @@ def generate_bytecode(self, ctx: "BytecodeContext"):
12181220
return bc.encode_MakePartitionViewOp(ctx.builder, pv_ty, tv)
12191221

12201222

1221-
def flatten_aggregates(vars: Sequence[Var], types: Sequence[Type]) -> tuple[Var, ...]:
1222-
ret = []
1223-
for x, ty in zip(vars, types, strict=True):
1224-
item_types = tuple(ty.flatten_aggregate())
1225-
x_ty = x.get_type_allow_invalid()
1226-
if isinstance(x_ty, InvalidType):
1227-
for _ in item_types:
1228-
t = x.ctx.make_temp(x.loc)
1229-
t.set_type(x_ty)
1230-
ret.append(t)
1231-
else:
1232-
items = tuple(x.flatten_aggregate())
1233-
assert len(items) == len(item_types)
1234-
ret.extend(items)
1235-
return tuple(ret)
1236-
1237-
1238-
def flatten_aggregate_types(types: Sequence[Type]) -> tuple[Type, ...]:
1239-
ret = []
1240-
for ty in types:
1241-
ret.extend(ty.flatten_aggregate())
1242-
return tuple(ret)
1243-
1244-
1245-
def unflatten_aggregates(flattened: Tuple[Var, ...],
1246-
nominal: Sequence[Type], actual: Sequence[Type]) -> tuple[Var, ...]:
1247-
it = iter(flattened)
1248-
ret = tuple(_maybe_unflatten_aggregate(it, n, a) for n, a in zip(nominal, actual, strict=True))
1249-
assert next(it, None) is None
1250-
return ret
1251-
1252-
1253-
def _maybe_unflatten_aggregate(flattened_iter: Iterator[Var], nominal: Type, actual: Type) -> Var:
1254-
if not nominal.is_aggregate():
1255-
return next(flattened_iter)
1256-
return _unflatten_proper_aggregate(flattened_iter, nominal, actual, result_var=None)
1257-
1258-
1259-
def expand_aggregate_var(var: Var) -> Tuple[Var, ...]:
1260-
item_types = tuple(var.get_type().flatten_aggregate())
1261-
ret = tuple(var.ctx.make_var(f"{var.get_original_name()}_{i}", var.loc)
1262-
for i in range(len(item_types)))
1263-
for item, item_ty in zip(ret, item_types, strict=True):
1264-
item.set_type(item_ty)
1265-
return ret
1266-
1267-
1268-
def flatten_block_parameters(vars: Sequence[Var]) -> list[tuple[Var, ...]]:
1269-
ret = []
1270-
for v in vars:
1271-
ty = v.get_type_allow_invalid()
1272-
if ty.is_aggregate():
1273-
flattened_vars = expand_aggregate_var(v)
1274-
ret.append(flattened_vars)
1275-
it = iter(flattened_vars)
1276-
_unflatten_proper_aggregate(it, ty, ty, v)
1277-
assert next(it, None) is None
1278-
else:
1279-
ret.append((v,))
1280-
return ret
1281-
1282-
1283-
def _unflatten_proper_aggregate(flattened_iter: Iterator[Var], nominal: Type, actual: Type,
1284-
result_var: Var | None) -> Var:
1285-
nominal_item_types = nominal.aggregate_item_types()
1286-
if isinstance(actual, InvalidType):
1287-
# Pop values from the iterator and throw them out
1288-
for _ in nominal_item_types:
1289-
next(flattened_iter)
1290-
builder = Builder.get_current()
1291-
t = builder.ir_ctx.make_temp(builder.loc)
1292-
t.set_type(actual)
1293-
return t
1294-
1295-
items = tuple(_maybe_unflatten_aggregate(flattened_iter, item_nominal, item_actual)
1296-
for item_nominal, item_actual
1297-
in zip(nominal_item_types, actual.aggregate_item_types(), strict=True))
1298-
val = nominal.make_aggregate_value(items)
1299-
1300-
builder = Builder.get_current()
1301-
if isinstance(nominal, ArrayTy):
1302-
assert isinstance(val, ArrayValue)
1303-
base_ptr = val.base_ptr
1304-
shape = tuple(assume_bounded(x, 0, None) for x in val.shape)
1305-
1306-
all_strides = []
1307-
dynamic_strides = []
1308-
for x, s in zip(val.strides, nominal.strides, strict=True):
1309-
if s is None:
1310-
x = assume_bounded(x, 0, None)
1311-
dynamic_strides.append(x)
1312-
all_strides.append(x)
1313-
1314-
operands = dict(base_ptr=base_ptr, shape=shape, dynamic_strides=tuple(dynamic_strides))
1315-
ret = builder.add_operation(MakeTensorView, nominal, operands, result_var)
1316-
ret.set_aggregate(ArrayValue(base_ptr, shape, tuple(all_strides)))
1317-
return ret
1318-
elif isinstance(nominal, ListTy):
1319-
assert isinstance(val, ListValue)
1320-
operands = dict(base_ptr=val.base_ptr, length=val.length)
1321-
ret = builder.add_operation(MakeListView, nominal, operands, result_var)
1322-
ret.set_aggregate(val)
1323-
return ret
1324-
else:
1325-
return builder.make_aggregate(val, nominal, result_var=result_var)
1326-
1327-
13281223
@dataclass(eq=False)
13291224
class TileNumBlocks(Operation, opcode="tile_num_blocks"):
13301225
axis: int = attribute()
@@ -3921,6 +3816,35 @@ def store_advanced_impl(array: Var, indices: Var, tile: Var,
39213816
latency=latency_val, allow_tma=allow_tma_val)
39223817

39233818

3819+
@tile_impl_registry.unflatten_aggregate_impl(ArrayTy)
3820+
def _unflatten_aggregate_array_impl(val: ArrayValue, ty: ArrayTy, result_var: Var):
3821+
assert isinstance(val, ArrayValue)
3822+
base_ptr = val.base_ptr
3823+
shape = tuple(assume_bounded(x, 0, None) for x in val.shape)
3824+
3825+
all_strides = []
3826+
dynamic_strides = []
3827+
for x, s in zip(val.strides, ty.strides, strict=True):
3828+
if s is None:
3829+
x = assume_bounded(x, 0, None)
3830+
dynamic_strides.append(x)
3831+
all_strides.append(x)
3832+
3833+
operands = dict(base_ptr=base_ptr, shape=shape, dynamic_strides=tuple(dynamic_strides))
3834+
ret = Builder.get_current().add_operation(MakeTensorView, ty, operands, result_var)
3835+
ret.set_aggregate(ArrayValue(base_ptr, shape, tuple(all_strides)))
3836+
return ret
3837+
3838+
3839+
@tile_impl_registry.unflatten_aggregate_impl(ListTy)
3840+
def _unflatten_aggregate_list_impl(val: ListValue, ty: ListTy, result_var: Var):
3841+
assert isinstance(val, ListValue)
3842+
operands = dict(base_ptr=val.base_ptr, length=val.length)
3843+
ret = Builder.get_current().add_operation(MakeListView, ty, operands, result_var)
3844+
ret.set_aggregate(val)
3845+
return ret
3846+
3847+
39243848
def _add_dummy_op_to_invalid_vars(vars: Sequence[Var],
39253849
actual_types: Sequence[Type]) -> tuple[Var, ...]:
39263850
return tuple(add_operation(MakeDummy, actual)

0 commit comments

Comments
 (0)