|
6 | 6 | import operator |
7 | 7 | from dataclasses import dataclass |
8 | 8 | from typing import ( |
9 | | - Literal, Sequence, Tuple, Optional, Any, List, Callable, Iterator, Iterable, |
| 9 | + Literal, Sequence, Tuple, Optional, Any, List, Callable, Iterable, |
10 | 10 | ) |
11 | 11 |
|
12 | 12 | from typing_extensions import override |
|
21 | 21 | PhiState, LoopVarState, make_aggregate, ConstantState, MemoryEffect, attribute, operand, |
22 | 22 | BlockRestriction, add_operation_variadic, |
23 | 23 | ) |
| 24 | +from .aggregate_support import flatten_block_parameters, expand_aggregate_var, \ |
| 25 | + flatten_aggregate_types, flatten_aggregates, unflatten_aggregates |
24 | 26 | from .arithmetic_ops import reshape, broadcast_to, astype, compare_tensorlike, \ |
25 | 27 | binary_bitwise_tensorlike, bitwise_shift_tensorlike, binary_arithmetic_tensorlike, \ |
26 | 28 | compare_tensorlike_raw, where, binary_bitwise_tensorlike_raw, where_raw, TileReshape, \ |
@@ -1218,113 +1220,6 @@ def generate_bytecode(self, ctx: "BytecodeContext"): |
1218 | 1220 | return bc.encode_MakePartitionViewOp(ctx.builder, pv_ty, tv) |
1219 | 1221 |
|
1220 | 1222 |
|
1221 | | -def flatten_aggregates(vars: Sequence[Var], types: Sequence[Type]) -> tuple[Var, ...]: |
1222 | | - ret = [] |
1223 | | - for x, ty in zip(vars, types, strict=True): |
1224 | | - item_types = tuple(ty.flatten_aggregate()) |
1225 | | - x_ty = x.get_type_allow_invalid() |
1226 | | - if isinstance(x_ty, InvalidType): |
1227 | | - for _ in item_types: |
1228 | | - t = x.ctx.make_temp(x.loc) |
1229 | | - t.set_type(x_ty) |
1230 | | - ret.append(t) |
1231 | | - else: |
1232 | | - items = tuple(x.flatten_aggregate()) |
1233 | | - assert len(items) == len(item_types) |
1234 | | - ret.extend(items) |
1235 | | - return tuple(ret) |
1236 | | - |
1237 | | - |
1238 | | -def flatten_aggregate_types(types: Sequence[Type]) -> tuple[Type, ...]: |
1239 | | - ret = [] |
1240 | | - for ty in types: |
1241 | | - ret.extend(ty.flatten_aggregate()) |
1242 | | - return tuple(ret) |
1243 | | - |
1244 | | - |
1245 | | -def unflatten_aggregates(flattened: Tuple[Var, ...], |
1246 | | - nominal: Sequence[Type], actual: Sequence[Type]) -> tuple[Var, ...]: |
1247 | | - it = iter(flattened) |
1248 | | - ret = tuple(_maybe_unflatten_aggregate(it, n, a) for n, a in zip(nominal, actual, strict=True)) |
1249 | | - assert next(it, None) is None |
1250 | | - return ret |
1251 | | - |
1252 | | - |
1253 | | -def _maybe_unflatten_aggregate(flattened_iter: Iterator[Var], nominal: Type, actual: Type) -> Var: |
1254 | | - if not nominal.is_aggregate(): |
1255 | | - return next(flattened_iter) |
1256 | | - return _unflatten_proper_aggregate(flattened_iter, nominal, actual, result_var=None) |
1257 | | - |
1258 | | - |
1259 | | -def expand_aggregate_var(var: Var) -> Tuple[Var, ...]: |
1260 | | - item_types = tuple(var.get_type().flatten_aggregate()) |
1261 | | - ret = tuple(var.ctx.make_var(f"{var.get_original_name()}_{i}", var.loc) |
1262 | | - for i in range(len(item_types))) |
1263 | | - for item, item_ty in zip(ret, item_types, strict=True): |
1264 | | - item.set_type(item_ty) |
1265 | | - return ret |
1266 | | - |
1267 | | - |
1268 | | -def flatten_block_parameters(vars: Sequence[Var]) -> list[tuple[Var, ...]]: |
1269 | | - ret = [] |
1270 | | - for v in vars: |
1271 | | - ty = v.get_type_allow_invalid() |
1272 | | - if ty.is_aggregate(): |
1273 | | - flattened_vars = expand_aggregate_var(v) |
1274 | | - ret.append(flattened_vars) |
1275 | | - it = iter(flattened_vars) |
1276 | | - _unflatten_proper_aggregate(it, ty, ty, v) |
1277 | | - assert next(it, None) is None |
1278 | | - else: |
1279 | | - ret.append((v,)) |
1280 | | - return ret |
1281 | | - |
1282 | | - |
1283 | | -def _unflatten_proper_aggregate(flattened_iter: Iterator[Var], nominal: Type, actual: Type, |
1284 | | - result_var: Var | None) -> Var: |
1285 | | - nominal_item_types = nominal.aggregate_item_types() |
1286 | | - if isinstance(actual, InvalidType): |
1287 | | - # Pop values from the iterator and throw them out |
1288 | | - for _ in nominal_item_types: |
1289 | | - next(flattened_iter) |
1290 | | - builder = Builder.get_current() |
1291 | | - t = builder.ir_ctx.make_temp(builder.loc) |
1292 | | - t.set_type(actual) |
1293 | | - return t |
1294 | | - |
1295 | | - items = tuple(_maybe_unflatten_aggregate(flattened_iter, item_nominal, item_actual) |
1296 | | - for item_nominal, item_actual |
1297 | | - in zip(nominal_item_types, actual.aggregate_item_types(), strict=True)) |
1298 | | - val = nominal.make_aggregate_value(items) |
1299 | | - |
1300 | | - builder = Builder.get_current() |
1301 | | - if isinstance(nominal, ArrayTy): |
1302 | | - assert isinstance(val, ArrayValue) |
1303 | | - base_ptr = val.base_ptr |
1304 | | - shape = tuple(assume_bounded(x, 0, None) for x in val.shape) |
1305 | | - |
1306 | | - all_strides = [] |
1307 | | - dynamic_strides = [] |
1308 | | - for x, s in zip(val.strides, nominal.strides, strict=True): |
1309 | | - if s is None: |
1310 | | - x = assume_bounded(x, 0, None) |
1311 | | - dynamic_strides.append(x) |
1312 | | - all_strides.append(x) |
1313 | | - |
1314 | | - operands = dict(base_ptr=base_ptr, shape=shape, dynamic_strides=tuple(dynamic_strides)) |
1315 | | - ret = builder.add_operation(MakeTensorView, nominal, operands, result_var) |
1316 | | - ret.set_aggregate(ArrayValue(base_ptr, shape, tuple(all_strides))) |
1317 | | - return ret |
1318 | | - elif isinstance(nominal, ListTy): |
1319 | | - assert isinstance(val, ListValue) |
1320 | | - operands = dict(base_ptr=val.base_ptr, length=val.length) |
1321 | | - ret = builder.add_operation(MakeListView, nominal, operands, result_var) |
1322 | | - ret.set_aggregate(val) |
1323 | | - return ret |
1324 | | - else: |
1325 | | - return builder.make_aggregate(val, nominal, result_var=result_var) |
1326 | | - |
1327 | | - |
1328 | 1223 | @dataclass(eq=False) |
1329 | 1224 | class TileNumBlocks(Operation, opcode="tile_num_blocks"): |
1330 | 1225 | axis: int = attribute() |
@@ -3921,6 +3816,35 @@ def store_advanced_impl(array: Var, indices: Var, tile: Var, |
3921 | 3816 | latency=latency_val, allow_tma=allow_tma_val) |
3922 | 3817 |
|
3923 | 3818 |
|
| 3819 | +@tile_impl_registry.unflatten_aggregate_impl(ArrayTy) |
| 3820 | +def _unflatten_aggregate_array_impl(val: ArrayValue, ty: ArrayTy, result_var: Var): |
| 3821 | + assert isinstance(val, ArrayValue) |
| 3822 | + base_ptr = val.base_ptr |
| 3823 | + shape = tuple(assume_bounded(x, 0, None) for x in val.shape) |
| 3824 | + |
| 3825 | + all_strides = [] |
| 3826 | + dynamic_strides = [] |
| 3827 | + for x, s in zip(val.strides, ty.strides, strict=True): |
| 3828 | + if s is None: |
| 3829 | + x = assume_bounded(x, 0, None) |
| 3830 | + dynamic_strides.append(x) |
| 3831 | + all_strides.append(x) |
| 3832 | + |
| 3833 | + operands = dict(base_ptr=base_ptr, shape=shape, dynamic_strides=tuple(dynamic_strides)) |
| 3834 | + ret = Builder.get_current().add_operation(MakeTensorView, ty, operands, result_var) |
| 3835 | + ret.set_aggregate(ArrayValue(base_ptr, shape, tuple(all_strides))) |
| 3836 | + return ret |
| 3837 | + |
| 3838 | + |
| 3839 | +@tile_impl_registry.unflatten_aggregate_impl(ListTy) |
| 3840 | +def _unflatten_aggregate_list_impl(val: ListValue, ty: ListTy, result_var: Var): |
| 3841 | + assert isinstance(val, ListValue) |
| 3842 | + operands = dict(base_ptr=val.base_ptr, length=val.length) |
| 3843 | + ret = Builder.get_current().add_operation(MakeListView, ty, operands, result_var) |
| 3844 | + ret.set_aggregate(val) |
| 3845 | + return ret |
| 3846 | + |
| 3847 | + |
3924 | 3848 | def _add_dummy_op_to_invalid_vars(vars: Sequence[Var], |
3925 | 3849 | actual_types: Sequence[Type]) -> tuple[Var, ...]: |
3926 | 3850 | return tuple(add_operation(MakeDummy, actual) |
|
0 commit comments