Skip to content

Commit c3d3405

Browse files
committed
Use TensorLikeTy instead of TileTy in binary elementwise ops
This should be a non-functional change. It is the first step toward separating the type systems of cuda.tile and cuda.lang while still sharing some common operation implementations.

Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent abcdc7c commit c3d3405

12 files changed

Lines changed: 326 additions & 201 deletions

File tree

experimental/cuda-lang/src/cuda/lang/_ir/_host_program.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,11 @@
77
from typing import Literal
88

99
from cuda.lang._ir import ir
10-
from cuda.tile._ir.ops import loosely_typed_const, raw_binary_arithmetic, strictly_typed_const, \
11-
unary, _UNARY_BOOL_INT, raw_binary_bitwise, TypedConst, AssumeBounded, AssumeDivBy, Assign
10+
from cuda.tile._ir.ops import (
11+
loosely_typed_const, binary_arithmetic_tensorlike_raw,
12+
strictly_typed_const, unary, _UNARY_BOOL_INT, TypedConst, AssumeBounded, AssumeDivBy, Assign,
13+
binary_bitwise_tensorlike_raw
14+
)
1215
from cuda.tile._ir.typing_support import I32_TY, I64_TY
1316

1417
HostOpcode = Literal["Const", "KernelArgI32", "KernelArgI64", "Mul", "Add", "RoundUpToPow2"]
@@ -70,10 +73,10 @@ def host_program_to_ir(program: HostProgram, kernel_params: tuple[ir.Var, ...])
7073
case "KernelArgI64": stack.append(kernel_params[next(attrs)])
7174
case "Mul":
7275
b = stack.pop()
73-
stack[-1] = raw_binary_arithmetic("mul", stack[-1], b)
76+
stack[-1] = binary_arithmetic_tensorlike_raw("mul", stack[-1], b)
7477
case "Add":
7578
b = stack.pop()
76-
stack[-1] = raw_binary_arithmetic("add", stack[-1], b)
79+
stack[-1] = binary_arithmetic_tensorlike_raw("add", stack[-1], b)
7780
case "RoundUpToPow2":
7881
stack[-1] = _round_up_ir(stack[-1], next(attrs))
7982
case _:
@@ -86,7 +89,7 @@ def host_program_to_ir(program: HostProgram, kernel_params: tuple[ir.Var, ...])
8689
def _round_up_ir(value: ir.Var, alignment: int) -> ir.Var:
8790
value_ty = value.get_type()
8891
mask = strictly_typed_const(alignment - 1, value_ty)
89-
value_plus_mask = raw_binary_arithmetic("add", value, mask)
92+
value_plus_mask = binary_arithmetic_tensorlike_raw("add", value, mask)
9093
neg_mask = unary('neg', _UNARY_BOOL_INT, mask)
91-
rounded = raw_binary_bitwise('and_', value_plus_mask, neg_mask)
94+
rounded = binary_bitwise_tensorlike_raw('and_', value_plus_mask, neg_mask)
9295
return rounded

experimental/cuda-lang/src/cuda/lang/_ir/ir.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,11 @@
66
from dataclasses import dataclass, field
77
import itertools
88
from collections import defaultdict
9+
from typing import Sequence
910

11+
from typing_extensions import override
12+
13+
from cuda.tile._datatype import DType
1014
from cuda.tile._ir.ir import (
1115
Block as TileBlock,
1216
Builder as TileBuilder,
@@ -18,8 +22,9 @@
1822
attribute,
1923
add_operation,
2024
format_var,
21-
AggregateValue,
25+
AggregateValue, TypingHooks,
2226
)
27+
from cuda.tile._ir.type import TensorLikeTy, TileTy
2328

2429

2530
class Builder:
@@ -86,11 +91,18 @@ def to_string(
8691
return f"{' ' * indent}^{self._name}({params}):\n{ops}"
8792

8893

94+
class _LangTypingHooks(TypingHooks):
95+
@override
96+
def get_tensor_like_type(self, dtype: DType, shape: Sequence[int]) -> TensorLikeTy:
97+
return TileTy(dtype, shape)
98+
99+
89100
class IRContext(TileIRContext):
90101
def __init__(self, log_ir_on_error: bool = True):
91102
self._block_names: dict[int, str] = {}
92103
self._block_counter: dict[str, itertools.count] = defaultdict(itertools.count)
93-
super().__init__(log_ir_on_error, tileiras_version=None)
104+
super().__init__(log_ir_on_error, tileiras_version=None,
105+
typing_hooks=_LangTypingHooks())
94106

95107
def make_block(self, name: str, loc: Loc, params: tuple[Var, ...] = ()) -> Block:
96108
block = Block(self, loc)

experimental/cuda-lang/src/cuda/lang/_ir/ops.py

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -24,19 +24,16 @@
2424
require_tile_type, require_constant_bool, require_constant_pointer_info,
2525
require_scalar_pointer_type,
2626
)
27-
from cuda.lang._ir.type import (
28-
LocalArrayContextManagerTy, ContextManagerState, TensorMapTy,
29-
dtype_to_tensor_map_type, ArrayValue, PointerInfoTy
30-
)
27+
from cuda.tile._ir.type import TensorLikeTy
3128
from cuda.tile._ir.ops import (
32-
binary_arithmetic,
29+
binary_arithmetic_tensorlike,
30+
binary_arithmetic_tensorlike_raw,
3331
loosely_typed_const,
3432
tile_impl_registry,
3533
bind_method,
3634
build_tuple,
3735
strictly_typed_const,
3836
astype,
39-
raw_binary_arithmetic,
4037
Return,
4138
return_,
4239
Assign,
@@ -76,6 +73,8 @@
7673
from .. import _stub as stub
7774

7875
from .type import (
76+
LocalArrayContextManagerTy, ContextManagerState, TensorMapTy,
77+
dtype_to_tensor_map_type, ArrayValue, PointerInfoTy,
7978
MemorySpace,
8079
Type,
8180
make_vector_ty,
@@ -285,8 +284,8 @@ def _array_linear_offset(array: Var, indices: tuple[Var, ...]) -> Var:
285284
for index, stride in zip(indices, array_val.strides, strict=True):
286285
index = astype(index, datatype.uint64)
287286
stride = astype(stride, datatype.uint64)
288-
scaled = raw_binary_arithmetic("mul", index, stride)
289-
offset = raw_binary_arithmetic("add", offset, scaled)
287+
scaled = binary_arithmetic_tensorlike_raw("mul", index, stride)
288+
offset = binary_arithmetic_tensorlike_raw("add", offset, scaled)
290289
return offset
291290

292291

@@ -538,7 +537,7 @@ def _is_pointer_type(ty):
538537
return isinstance(ty, TileTy) and is_pointer_dtype(ty.dtype)
539538

540539

541-
@impl(operator.add, overload=(TileTy, TileTy))
540+
@impl(operator.add, overload=(TensorLikeTy, TensorLikeTy))
542541
async def add_impl(x: Var, y: Var) -> Var:
543542
xty, yty = x.get_type(), y.get_type()
544543
if _is_pointer_type(yty):
@@ -554,7 +553,7 @@ async def add_impl(x: Var, y: Var) -> Var:
554553
return await call_function(operator.add, x, y)
555554

556555

557-
@impl(operator.sub, overload=(TileTy, TileTy))
556+
@impl(operator.sub, overload=(TensorLikeTy, TensorLikeTy))
558557
async def sub_impl(x: Var, y: Var) -> Var:
559558
xty, yty = x.get_type(), y.get_type()
560559
if _is_pointer_type(xty):
@@ -563,7 +562,7 @@ async def sub_impl(x: Var, y: Var) -> Var:
563562
raise TileTypeError(f"Expected integer pointer offset, got {offset_dtype}")
564563
y = astype(y, datatype.int64)
565564
c0 = loosely_typed_const(0)
566-
offset = binary_arithmetic('sub', c0, y)
565+
offset = binary_arithmetic_tensorlike('sub', c0, y)
567566
return _pointer_with_offset(x, offset)
568567
if _is_pointer_type(yty):
569568
raise TileTypeError('It is invalid to subtract a pointer from an integer')
@@ -868,7 +867,7 @@ def shared_array_impl(shape: Var, dtype: Var, dynamic: Var, alignment: Var) -> O
868867

869868
if size is None or total_size is None:
870869
total_size = None
871-
total_size_var = raw_binary_arithmetic("mul", total_size_var, size_var)
870+
total_size_var = binary_arithmetic_tensorlike_raw("mul", total_size_var, size_var)
872871
else:
873872
total_size *= size
874873
total_size_var = strictly_typed_const(total_size, size_ty)

src/cuda/tile/_compile.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,14 @@
2727
from cuda.tile._bytecode.version import BytecodeVersion
2828
from cuda.tile._cext import get_compute_capability, TileContext, default_tile_context
2929
from cuda.tile._compiler_options import CompilerOptions
30+
from cuda.tile._datatype import DType
3031
from cuda.tile._exception import (
3132
TileCompilerError,
3233
TileCompilerExecutionError,
3334
TileCompilerTimeoutError, FunctionDesc, Loc
3435
)
3536
from cuda.tile._ir import ir, hir
37+
from cuda.tile._ir.ir import TypingHooks
3638
from cuda.tile._ir.ops import loosely_typed_const, flatten_block_parameters, tile_impl_registry
3739
from cuda.tile._ir.type import TileTy, ArrayTy, ListTy
3840
from cuda.tile._passes.ast2hir import get_function_hir
@@ -234,6 +236,11 @@ def unique_path_from_func_desc(base_dir: str, desc: FunctionDesc, suffix: str, m
234236
yield f
235237

236238

239+
class _TileTypingHooks(TypingHooks):
240+
def get_tensor_like_type(self, dtype: DType, shape: Sequence[int]) -> TileTy:
241+
return TileTy(dtype, shape)
242+
243+
237244
class _IrKeeper:
238245
def __init__(self,
239246
ann_func: AnnotatedFunction,
@@ -260,7 +267,8 @@ def get_final_ir(self, signature_index: int) -> ir.Block:
260267
sig = self.signatures[signature_index]
261268
param_names = tuple(self.ann_func.pysig.parameters.keys())
262269
ir_ctx = ir.IRContext(log_ir_on_error=self._log_cutile_ir,
263-
tileiras_version=self.bytecode_version)
270+
tileiras_version=self.bytecode_version,
271+
typing_hooks=_TileTypingHooks())
264272
with ir.Builder(ir_ctx, self._func_hir.body.loc) as ir_builder:
265273
params = _create_kernel_parameters(sig.parameters,
266274
self.ann_func.constant_parameter_mask,

src/cuda/tile/_ir/ir.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,19 +18,26 @@
1818
)
1919

2020
from cuda.tile._ir.aggregate_value import AggregateValue
21-
from cuda.tile._ir.type import Type, InvalidType, LooselyTypedScalar
21+
from cuda.tile._ir.type import Type, InvalidType, LooselyTypedScalar, TensorLikeTy
2222
from cuda.tile._exception import (
2323
TileTypeError, Loc, TileInternalError
2424
)
25+
from cuda.tile._datatype import DType
2526
from cuda.tile._bytecode.version import BytecodeVersion
2627

2728

2829
if TYPE_CHECKING:
2930
from cuda.tile._ir2bytecode import BytecodeContext
3031

3132

33+
class TypingHooks:
34+
def get_tensor_like_type(self, dtype: DType, shape: Sequence[int]) -> TensorLikeTy:
35+
raise NotImplementedError()
36+
37+
3238
class IRContext:
33-
def __init__(self, log_ir_on_error: bool, tileiras_version: BytecodeVersion):
39+
def __init__(self, log_ir_on_error: bool, tileiras_version: BytecodeVersion,
40+
typing_hooks: TypingHooks):
3441
self._all_vars: Dict[str, str] = {}
3542
self._counter_by_name: Dict[str, Iterator[int]] = defaultdict(itertools.count)
3643
self._temp_counter = itertools.count()
@@ -41,6 +48,7 @@ def __init__(self, log_ir_on_error: bool, tileiras_version: BytecodeVersion):
4148
self._aggregate_values: Dict[str, Any] = dict()
4249
self.tileiras_version: BytecodeVersion = tileiras_version
4350
self._function_specialization_id_counter = itertools.count()
51+
self.typing_hooks = typing_hooks
4452

4553
def next_function_specialization_id(self) -> str:
4654
# Monotonic counter used as a unique id when creating concrete FunctionDescs

src/cuda/tile/_ir/op_impl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -452,11 +452,15 @@ def require_constant_axis_order(var: Var, rank: int) -> Tuple[int, ...]:
452452
return tuple(normalize_axis(x, rank, var) for x in value)
453453

454454

455-
def require_tile_type(var: Var) -> TileTy:
455+
def ensure_tile(var: Var) -> Var[TileTy]:
456456
ty = var.get_type()
457457
if not isinstance(ty, TileTy):
458458
raise _make_type_error(f"Expected a tile, but given value has type {ty}", var)
459-
return ty
459+
return var
460+
461+
462+
def require_tile_type(var: Var) -> TileTy:
463+
return ensure_tile(var).get_type()
460464

461465

462466
def require_tile_or_tile_tuple_type(var: Var) -> TileTy | TupleTy:

0 commit comments

Comments
 (0)