Skip to content

Commit 8959f49

Browse files
committed
refactor partition view
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent 72af28b commit 8959f49

File tree

8 files changed

+170
-75
lines changed

8 files changed

+170
-75
lines changed

src/cuda/tile/_compile.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from cuda.tile._ir.typing_support import typeof_pyval, get_constant_value
3939
from cuda.tile._passes.ast2hir import get_function_hir
4040
from cuda.tile._passes.code_motion import hoist_loop_invariants
41+
from cuda.tile._passes.unhoist_partition_views import unhoist_partition_views
4142
from cuda.tile._passes.eliminate_assign_ops import eliminate_assign_ops
4243
from cuda.tile._passes.hir2ir import hir2ir
4344
from cuda.tile._passes.loop_split import split_loops
@@ -103,6 +104,11 @@ def _get_final_ir(pyfunc,
103104
# Otherwise, it may incorrectly hoist load operations out of the loop.
104105
hoist_loop_invariants(func_body)
105106

107+
# For version < V_13_3, MakePartitionView must be emitted inline before its consumer.
108+
# Code motion may hoist it to an outer block; copy it back where needed.
109+
if tileiras_version < BytecodeVersion.V_13_3:
110+
unhoist_partition_views(func_body)
111+
106112
split_loops(func_body)
107113
dead_code_elimination_pass(func_body)
108114
return ir.Function(func_body, func_hir.desc.name, func_hir.body.loc)

src/cuda/tile/_ir/ops.py

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
BYTE_BITWIDTH, typeof_pyval, dtype_registry, loose_type_of_pyval, get_constant_value
5454
)
5555
from .type import (
56-
TupleTy, TileTy, NoneType, BoundMethodTy, ArrayTy,
56+
PartitionViewTy, TupleTy, TileTy, NoneType, BoundMethodTy, ArrayTy,
5757
ListTy, make_tile_ty, SliceType, DTypeConstructor, RangeIterType, Type,
5858
NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType,
5959
array_size_type, ClosureTy, LiveCapturedScope, TokenTy,
@@ -2012,26 +2012,43 @@ def _check_load_store_hints(latency_value: int | None, allow_tma_value: bool | N
20122012
raise TileTypeError(f"Allow TMA must be a boolean, got {allow_tma_value}")
20132013

20142014

2015+
@dataclass(eq=False)
class MakePartitionView(Operation, opcode="make_partition_view"):
    """IR op that wraps an array in a partition view.

    The view's tile shape, axis order and padding mode are not op attributes:
    they live on the result variable's PartitionViewTy (constructed by
    make_partition_view below), so the op only carries the source array.
    """
    array: Var = operand()  # the array being partitioned into tiles

    @override
    def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
        # The result type (a PartitionViewTy) fully describes the view, so
        # lowering is a single MakePartitionViewOp over the array value.
        partition_view_ty = self.result_var.get_type()
        return bc.encode_MakePartitionViewOp(ctx.builder,
                                             typeid(ctx.type_table, partition_view_ty),
                                             ctx.get_value(self.array))
2025+
2026+
2027+
def make_partition_view(array: Var, tile_shape: Sequence[int],
                        order: Sequence[int],
                        padding_mode: PaddingMode) -> Var:
    """Emit a MakePartitionView op over *array* and return the view variable.

    Tile shape, axis order and padding mode are baked into the result's
    PartitionViewTy rather than stored as attributes on the op itself.
    """
    src_ty = array.get_type()
    assert isinstance(src_ty, ArrayTy)
    result_ty = PartitionViewTy(src_ty, tuple(tile_shape), tuple(order), padding_mode)
    return add_operation(MakePartitionView, result_ty, array=array)
2034+
2035+
20152036
@dataclass(eq=False)
20162037
class TileLoad(Operation, opcode="tile_load", memory_effect=MemoryEffect.LOAD):
2017-
order: tuple[int, ...] = attribute()
2018-
padding_mode: PaddingMode = attribute()
20192038
latency: Optional[int] = attribute()
20202039
allow_tma: Optional[bool] = attribute()
2021-
array: Var = operand()
2040+
view: Var = operand()
20222041
index: tuple[Var, ...] = operand()
20232042
token: Optional[Var] = operand(default=None)
20242043

20252044
@override
20262045
def generate_bytecode(self, ctx: BytecodeContext) -> tuple[bc.Value, bc.Value]:
20272046
tile_type: TileTy = self.result_vars[0].get_type()
2028-
shape = tile_type.shape
2029-
partition = ctx.make_partition_view(self.array, self.order, shape, self.padding_mode)
20302047
res, res_token = bc.encode_LoadViewTkoOp(
20312048
ctx.builder,
20322049
tile_type=typeid(ctx.type_table, tile_type),
20332050
result_token_type=ctx.type_table.Token,
2034-
view=partition,
2051+
view=ctx.get_value(self.view),
20352052
index=ctx.index_tuple(self.index),
20362053
token=None if self.token is None else ctx.get_value(self.token),
20372054
memory_ordering_semantics=bc.MemoryOrderingSemantics.WEAK,
@@ -2041,16 +2058,6 @@ def generate_bytecode(self, ctx: BytecodeContext) -> tuple[bc.Value, bc.Value]:
20412058
return res, res_token
20422059

20432060

2044-
def tile_load(array: Var, index: tuple[Var, ...], shape: Sequence[int], order: Sequence[int],
2045-
padding_mode: PaddingMode, latency: Optional[int],
2046-
allow_tma: Optional[bool]) -> tuple[Var, Var]:
2047-
res_ty = make_tile_ty(array.get_type().dtype, shape)
2048-
return add_operation(TileLoad, (res_ty, TokenTy()),
2049-
array=array, index=index, order=tuple(order),
2050-
padding_mode=padding_mode, latency=latency,
2051-
allow_tma=allow_tma)
2052-
2053-
20542061
@impl(ct.load)
20552062
def tile_load_impl(array: Var, index: Var, shape: Var, order: Var,
20562063
padding_mode: Var, latency: Var, allow_tma: Var) -> Var:
@@ -2070,32 +2077,31 @@ def tile_load_impl(array: Var, index: Var, shape: Var, order: Var,
20702077
latency = require_optional_constant_int(latency)
20712078
allow_tma = require_optional_constant_bool(allow_tma)
20722079
_check_load_store_hints(latency, allow_tma)
2073-
result, _token = tile_load(array, index_items, broadcasted_shape, order, padding_mode, latency,
2074-
allow_tma)
2080+
2081+
view = make_partition_view(array, broadcasted_shape, order, padding_mode)
2082+
res_ty = make_tile_ty(array_ty.dtype, broadcasted_shape)
2083+
result, _token = add_operation(TileLoad, (res_ty, TokenTy()),
2084+
view=view, index=index_items, latency=latency,
2085+
allow_tma=allow_tma)
20752086
return reshape(result, shape)
20762087

20772088

20782089
@dataclass(eq=False)
20792090
class TileStore(Operation, opcode="tile_store", memory_effect=MemoryEffect.STORE):
2080-
order: tuple[int, ...] = attribute()
20812091
latency: Optional[int] = attribute()
20822092
allow_tma: Optional[bool] = attribute()
2083-
array: Var = operand()
2093+
view: Var = operand()
20842094
index: tuple[Var, ...] = operand()
20852095
tile: Var = operand()
20862096
token: Optional[Var] = operand(default=None)
20872097

20882098
@override
20892099
def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
2090-
tile_ty = self.tile.get_type()
2091-
tile_shape = tile_ty.shape
2092-
partition = ctx.make_partition_view(self.array, self.order, tile_shape,
2093-
padding_mode=PaddingMode.UNDETERMINED)
20942100
return bc.encode_StoreViewTkoOp(
20952101
ctx.builder,
20962102
result_token_type=ctx.type_table.Token,
20972103
tile=ctx.get_value(self.tile),
2098-
view=partition,
2104+
view=ctx.get_value(self.view),
20992105
index=ctx.index_tuple(self.index),
21002106
token=None if self.token is None else ctx.get_value(self.token),
21012107
memory_ordering_semantics=bc.MemoryOrderingSemantics.WEAK,
@@ -2104,17 +2110,6 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
21042110
)
21052111

21062112

2107-
def tile_store(array: Var, index: tuple[Var, ...], tile: Var, order: Sequence[int],
2108-
latency: Optional[int], allow_tma: Optional[bool]) -> Var:
2109-
array_ty = array.get_type()
2110-
ty = require_tile_type(tile)
2111-
tile = astype(tile, array_ty.dtype)
2112-
if ty.ndim == 0:
2113-
tile = reshape(tile, (1,) * array_ty.ndim)
2114-
return add_operation(TileStore, (TokenTy(),), array=array, index=index, tile=tile,
2115-
order=tuple(order), latency=latency, allow_tma=allow_tma)
2116-
2117-
21182113
def _implicit_cast(src: Var, target_dtype: DType, error_context: str) -> Var:
21192114
ty = require_tile_maybe_loose_type(src)
21202115
try:
@@ -2130,20 +2125,25 @@ def _implicit_cast(src: Var, target_dtype: DType, error_context: str) -> Var:
21302125
def tile_store_impl(array: Var, index: Var, tile: Var, order: Var,
                    latency: Var, allow_tma: Var):
    """Implementation of ct.store: lower a tile store to
    MakePartitionView + TileStore.

    Validates the index against the array rank, checks the load/store hints,
    casts the stored tile to the array's dtype, and emits the view/store ops.
    """
    array_ty = require_array_type(array)
    index_ty = require_index_or_index_tuple_type(index)
    index_items = index.get_aggregate().items if isinstance(index_ty, TupleTy) else (index,)
    if array_ty.ndim != len(index_items):
        raise TileTypeError(f"Index size {len(index_items)}"
                            f" does not match the array rank {array_ty.ndim}")

    order = require_constant_axis_order(order, array_ty.ndim)
    latency = require_optional_constant_int(latency)
    allow_tma = require_optional_constant_bool(allow_tma)
    _check_load_store_hints(latency, allow_tma)

    # Cast BEFORE requiring a strict tile type: _implicit_cast accepts
    # loosely-typed scalars (it uses require_tile_maybe_loose_type) and
    # converts them to tiles, matching the pre-refactor order where
    # require_tile_type ran on the cast result.  Requiring a strict tile
    # type on the raw argument would reject loose scalars that used to be
    # accepted.
    tile = _implicit_cast(tile, array_ty.dtype, "Stored tile is incompatible with array's dtype")
    tile_ty = require_tile_type(tile)
    shape = tile_ty.shape
    # A 0-d tile is stored as a (1, ..., 1) tile of the array's rank.
    broadcasted_shape = (1,) * array_ty.ndim if len(shape) == 0 else shape
    tile = reshape(tile, broadcasted_shape)

    view = make_partition_view(array, broadcasted_shape, order, PaddingMode.UNDETERMINED)
    [_token] = add_operation(TileStore, (TokenTy(),), view=view, index=index_items, tile=tile,
                             latency=latency, allow_tma=allow_tma)
21472147

21482148

21492149
@dataclass(eq=False)
@@ -2618,35 +2618,30 @@ def join_tokens(tokens: Tuple[Var, ...], *, block: Block, res: Var, loc: Loc) ->
26182618
@dataclass(eq=False)
class NumTiles(Operation, opcode="num_tiles"):
    """IR op returning the number of tiles along one axis of a partition view."""
    axis: int = attribute()  # axis of the view's index space to query
    view: Var = operand()    # value of PartitionViewTy

    @override
    def generate_bytecode(self, ctx: BytecodeContext):
        # The index-space rank equals the view's tile rank: query the whole
        # index-space shape (one scalar i32 tile per axis) and pick the
        # requested axis.
        view_ty: PartitionViewTy = self.view.get_type()
        result_types = [ctx.type_table.tile(ctx.type_table.I32, ())] * len(view_ty.tile_shape)
        values = bc.encode_GetIndexSpaceShapeOp(ctx.builder, result_types,
                                                src=ctx.get_value(self.view))
        return values[self.axis]
26322630

26332631

2634-
def num_tiles(array: Var, axis: int, shape: Sequence[int], order: Sequence[int]) -> Var:
2635-
return add_operation(NumTiles, make_tile_ty(datatype.default_int_type, ()), array=array,
2636-
axis=axis, shape=tuple(shape), order=tuple(order))
2637-
2638-
26392632
@impl(ct.num_tiles)
def num_tiles_impl(array: Var, axis: Var, shape: Var, order: Var) -> Var:
    """Implementation of ct.num_tiles: count tiles along one axis.

    Builds a partition view over the array with the requested tile shape and
    order, then emits a NumTiles op querying the given (normalized) axis.
    """
    arr_ty = require_array_type(array)
    rank = arr_ty.ndim
    norm_axis = normalize_axis(require_constant_int(axis), rank)
    tile_shape = require_constant_shape(shape, allow_single_int=True, expected_rank=rank,
                                        allow_0d_shape=True)
    # A 0-d shape means one element per tile along every axis.
    if len(tile_shape) == 0:
        tile_shape = (1,) * rank
    axis_order = require_constant_axis_order(order, rank)

    view = make_partition_view(array, tile_shape, axis_order, PaddingMode.UNDETERMINED)
    result_ty = make_tile_ty(datatype.default_int_type, ())
    return add_operation(NumTiles, result_ty, view=view, axis=norm_axis)
26502645

26512646

26522647
def full_const(shape: Sequence[int], fill_value: int | float, dtype: DType) -> Var:

src/cuda/tile/_ir/type.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import TYPE_CHECKING
1414

1515
from cuda.tile._exception import Loc
16+
from cuda.tile._numeric_semantics import PaddingMode
1617

1718
if TYPE_CHECKING:
1819
from cuda.tile._datatype import DType
@@ -358,6 +359,25 @@ def __str__(self):
358359
return f"Array[{self.dtype},{shape_str}:{strides_str}]"
359360

360361

362+
# ============== PartitionView Type ===============
363+
364+
365+
@dataclass(frozen=True)
class PartitionViewTy(Type):
    """Type of a partition view: an array tiled into equally-shaped tiles.

    Frozen so instances are hashable and compared by value; field order is
    part of the generated __eq__/__hash__, so do not reorder fields.
    """
    # The array being partitioned.
    array_ty: ArrayTy
    # Shape of each tile.
    tile_shape: tuple[int, ...]
    # Axis order of the partition (as accepted by require_constant_axis_order).
    order: tuple[int, ...]
    # Padding behavior for out-of-bounds tile elements.
    padding_mode: PaddingMode

    @property
    def dtype(self):
        # Element type is inherited from the underlying array.
        return self.array_ty.dtype

    def __str__(self):
        return (f"PartitionView[{self.array_ty},tile_shape={self.tile_shape},order={self.order},"
                f"padding_mode={self.padding_mode}]")
379+
380+
361381
# ============== List Type ===============
362382

363383

src/cuda/tile/_ir2bytecode.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import functools
66
import os
77
from contextlib import contextmanager
8-
from typing import Dict, Tuple, Sequence, Any, Optional
8+
from typing import Dict, Tuple, Any, Optional
99

1010
from cuda.tile import _datatype as datatype
1111
from cuda.tile._bytecode.attribute import make_load_store_hints
1212
from cuda.tile._datatype import get_signedness
13-
from cuda.tile import DType, PaddingMode
13+
from cuda.tile import DType
1414
import cuda.tile._bytecode as bc
1515
from cuda.tile._compiler_options import CompilerOptions
1616
from cuda.tile._exception import TileInternalError, TileError, FunctionDesc
@@ -19,7 +19,10 @@
1919
padding_mode_to_bytecode, rounding_mode_to_bytecode,
2020
get_default_rounding_mode,
2121
)
22-
from cuda.tile._ir.type import Type, TileTy, PointerTy, TokenTy, TupleTy, ArrayTy, size_to_bytecode
22+
from cuda.tile._ir.type import (
23+
PartitionViewTy, Type, TileTy, PointerTy, TokenTy, TupleTy, ArrayTy,
24+
size_to_bytecode,
25+
)
2326

2427

2528
def dtype_typeid(tt: bc.TypeTable, dtype: datatype.DType | PointerTy) -> bc.TypeId:
@@ -49,6 +52,11 @@ def typeid(tt: bc.TypeTable, ty: Type) -> bc.TypeId:
4952
return tt.tile(dtype, shape)
5053
elif isinstance(ty, TokenTy):
5154
return tt.Token
55+
elif isinstance(ty, PartitionViewTy):
56+
padding_value = padding_mode_to_bytecode[ty.padding_mode]
57+
assert isinstance(ty.array_ty, ArrayTy)
58+
tv_id = tensor_view_typeid(tt, ty.array_ty)
59+
return tt.partition_view(ty.tile_shape, tv_id, ty.order, padding_value)
5260
else:
5361
raise NotImplementedError(f"Lowering type '{ty}' is not supported")
5462

@@ -387,20 +395,6 @@ def load_store_hints(self,
387395
load_store_hints = bc.LoadStoreHints(latency=latency, allow_tma=allow_tma)
388396
return make_load_store_hints({self.sm_arch: load_store_hints})
389397

390-
def make_partition_view(self,
391-
array: Var,
392-
order: Sequence[int],
393-
tile_shape: Sequence[int],
394-
padding_mode: PaddingMode) -> bc.Value:
395-
padding_value = padding_mode_to_bytecode[padding_mode]
396-
array_ty = self.typeof(array)
397-
assert isinstance(array_ty, ArrayTy)
398-
view_ty_id = tensor_view_typeid(self.type_table, array_ty)
399-
partition_ty_id = self.type_table.partition_view(
400-
tile_shape, view_ty_id, order, padding_value)
401-
view = self.get_value(array)
402-
return bc.encode_MakePartitionViewOp(self.builder, partition_ty_id, view)
403-
404398

405399
def generate_bytecode_for_block(ctx: BytecodeContext, block: Block):
406400
for op in block.operations:

src/cuda/tile/_passes/alias_analysis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from cuda.tile._ir.ir import Var, Block
99
from cuda.tile._ir.ops import Assign, GetArrayListItem, \
10-
Loop, IfElse, Continue, Break, EndBranch, PointerOffset, \
10+
Loop, IfElse, Continue, Break, EndBranch, MakePartitionView, PointerOffset, \
1111
TileBroadcast, TileReshape, MakeTensorView, MakeListView, AssumeDivBy, TileReduce, TileScan
1212

1313

@@ -108,6 +108,8 @@ def _analyze_aliases_in_block(block: Block,
108108
alias_tracker[v.name] = ALIAS_UNIVERSE
109109
elif isinstance(op, MakeTensorView):
110110
_propagate(alias_tracker, op.base_ptr, op.result_var)
111+
elif isinstance(op, MakePartitionView):
112+
_propagate(alias_tracker, op.array, op.result_var)
111113
elif isinstance(op, MakeListView):
112114
_propagate(alias_tracker, op.base_ptr, op.result_var)
113115
elif isinstance(op, PointerOffset):

src/cuda/tile/_passes/token_order.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,13 @@ def token_order_pass(root_block: Block, alias_result: AliasResult):
112112

113113

114114
def _get_input_var(op: Operation):
    """Return the memory-input operand of *op*: its 'view' or its 'pointer'."""
    # NOTE(review): assumes TileInternalError is imported in this module — confirm.
    for operand_name in ("view", "pointer"):
        if operand_name in op.operands:
            return getattr(op, operand_name)
    raise TileInternalError(f"Cannot determine input var for op {op}: "
                            f"expected 'view' or 'pointer' operand")
116122

117123

118124
def _get_block_memory_effects(block: Block,
@@ -484,7 +490,7 @@ def is_idx_injective(idx_var: Var) -> bool:
484490
return loop_op.is_for_loop and idx_var.name == loop_op.induction_var.name
485491

486492
return set(store_op for store_op in tile_store_candidates
487-
if _get_input_var(store_op).get_type().elements_disjoint
493+
if _get_input_var(store_op).get_type().array_ty.elements_disjoint
488494
and any(is_idx_injective(idx_var) for idx_var in store_op.index))
489495

490496

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dataclasses
6+
from cuda.tile._ir.ir import Block, Mapper, Operation
7+
from cuda.tile._ir.ops import MakePartitionView, TileLoad, TileStore, NumTiles
8+
9+
10+
def unhoist_partition_views(root_block: Block):
    """Copy hoisted MakePartitionView ops back next to their consumers.

    Needed when the target bytecode expects the view to be materialized
    inline immediately before each load/store/num_tiles that uses it
    (see the pre-V_13_3 handling in _compile).
    """
    _unhoist(root_block, {})
13+
14+
15+
def _unhoist(block: Block, def_info: dict[str, tuple[Operation, Block]]):
    """Rebuild *block*, cloning a MakePartitionView directly in front of any
    consumer (TileLoad/TileStore/NumTiles) whose view was defined in a
    different block.

    def_info maps a variable name to the (operation, block) that defined it.
    It is a single dict shared across the whole traversal, so definitions in
    outer blocks are visible while nested blocks are processed.
    """
    new_block = block.empty_like_self()
    for op in block:
        if isinstance(op, (TileLoad, TileStore, NumTiles)):
            view_def, def_block = def_info[op.view.name]
            if isinstance(view_def, MakePartitionView) and def_block is not block:
                # The view was hoisted out of this block: clone its defining
                # op here and retarget the consumer at the cloned result.
                mapper = Mapper(block.ctx)
                new_block.append(view_def.clone(mapper))
                op = dataclasses.replace(op, view=mapper.get_var(op.view))

        # Recurse before recording this op's results: a nested block can only
        # reference vars produced by earlier ops, which are already recorded.
        for nested in op.nested_blocks:
            _unhoist(nested, def_info)

        new_block.append(op)
        for v in op.result_vars:
            def_info[v.name] = (op, block)

    # Replace the block's contents in place so references held by the parent
    # op remain valid.
    block[:] = new_block.detach_all()

0 commit comments

Comments
 (0)