Skip to content

Commit 64ad5db

Browse files
committed
[lang] Use Scalar/Pointer/Vector types instead of TileTy in cuda.lang
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent a69f4fe commit 64ad5db

33 files changed

Lines changed: 845 additions & 992 deletions

experimental/cuda-lang/src/cuda/lang/__init__.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from cuda.tile._memory_model import (
66
MemoryScope,
77
MemoryOrder,
8+
MemorySpace
89
)
910

1011
from ._execution import (
@@ -15,11 +16,19 @@
1516

1617
from ._compile import compile_simt
1718

18-
from ._stub import (
19+
from cuda.tile._datatype import (
1920
pointer_dtype,
2021
opaque_pointer_dtype,
2122
is_pointer_dtype,
2223
PointerInfo,
24+
)
25+
from ._stub.types import (
26+
Scalar,
27+
Vector,
28+
Pointer,
29+
)
30+
from ._stub.core_api import (
31+
dtype_of,
2332
warp_size,
2433
full_mask,
2534
block_idx,
@@ -46,11 +55,7 @@
4655
setmaxregister_increase,
4756
setmaxregister_decrease,
4857
elect_sync,
49-
Constant,
5058
Array,
51-
Pointer,
52-
Vector,
53-
printf,
5459
shared_array,
5560
local_array,
5661
address_space_cast,
@@ -65,17 +70,33 @@
6570
syncthreads,
6671
syncwarp,
6772
nvvm,
68-
libdevice,
73+
nanosleep,
74+
griddepcontrol_wait,
75+
griddepcontrol_launch_dependents,
76+
memory_barrier,
77+
)
78+
from cuda.tile._stub import (
79+
Constant
80+
)
81+
82+
from cuda.lang._stub import libdevice
83+
84+
from cuda.lang._stub.tensor_map import (
6985
TensorMapSwizzle,
7086
TensorMap,
7187
tensor_map_tiled,
88+
)
89+
90+
from cuda.lang._stub.tcgen05 import (
7291
CTAGroup,
7392
Tcgen05LdStShape,
7493
tcgen05_alloc,
7594
tcgen05_dealloc,
7695
tcgen05_commit,
7796
tcgen05_ld,
78-
nanosleep,
97+
)
98+
99+
from cuda.lang._stub.mbarrier import (
79100
MbarrierScope,
80101
mbarrier_init,
81102
mbarrier_invalidate,
@@ -87,12 +108,12 @@
87108
mbarrier_test_wait_parity,
88109
mbarrier_try_wait,
89110
mbarrier_try_wait_parity,
111+
)
112+
113+
from cuda.lang._stub.cluster_launch_control import (
90114
clusterlaunchcontrol_try_cancel,
91115
clusterlaunchcontrol_is_canceled,
92116
clusterlaunchcontrol_get_first_block_idx,
93-
griddepcontrol_wait,
94-
griddepcontrol_launch_dependents,
95-
memory_barrier,
96117
)
97118

98119
from ._datatype import (
@@ -114,10 +135,10 @@
114135
uint64,
115136
mbarrier,
116137
clusterlaunchcontrol_token,
117-
MemorySpace,
118138
)
119139

120140
__all__ = (
141+
"dtype_of",
121142
"pointer_dtype",
122143
"opaque_pointer_dtype",
123144
"is_pointer_dtype",
@@ -172,7 +193,6 @@
172193
"uint32",
173194
"uint64",
174195
"Constant",
175-
"printf",
176196
"shared_array",
177197
"local_array",
178198
"address_space_cast",
@@ -184,6 +204,7 @@
184204
"syncwarp",
185205
"Array",
186206
"Pointer",
207+
"Scalar",
187208
"Vector",
188209
"nvvm",
189210
"libdevice",

experimental/cuda-lang/src/cuda/lang/_compile.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from cuda.lang._passes.ir2mlir import ir2mlir
2525
from cuda.lang._passes.flatten_cfg import flatten_cfg
2626
from cuda.lang._passes.simt_semantics import simt_semantic_analysis
27-
from cuda.lang._passes.canonicalize_parameters import canonicalize_parameters
2827
from cuda.lang._passes.handle_dyn_shared_mem import handle_dynamic_shared_memory
2928
from cuda.lang._passes.hoist_tensor_map import hoist_tensor_maps, HoistedTensorMap
3029
from cuda.lang.compilation import (
@@ -137,7 +136,6 @@ def get_function_ir(
137136
function.param_locs,
138137
ctx
139138
)
140-
canonicalize_parameters(params, builder)
141139
hir2ir(function, params.aggregate_vars, ctx)
142140
func_body = ctx.make_block("entry", function.body.loc)
143141
func_body.params = sum((vars for vars, _ in params.nonconstant_flat_vars), ())

experimental/cuda-lang/src/cuda/lang/_datatype.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import TypeAlias, Union
5+
from typing import TypeAlias
66

7-
from cuda.lang._ir.type import (
8-
MemorySpace,
9-
TileTy,
10-
)
11-
from cuda.tile._stub import Tile
127
from cuda.tile._datatype import (
138
DType,
149
bfloat16,
@@ -52,11 +47,11 @@
5247
)
5348

5449

55-
def to_torch_dtype(dtype: DType | TileTy):
56-
import torch
50+
def to_torch_dtype(dtype: DType, /):
51+
if not isinstance(dtype, DType):
52+
raise TypeError("Expected a DType object")
5753

58-
if isinstance(dtype, TileTy):
59-
dtype = dtype.dtype
54+
import torch
6055

6156
dtype_map = {
6257
bool_: torch.bool,
@@ -87,11 +82,7 @@ def to_torch_dtype(dtype: DType | TileTy):
8782
raise NotImplementedError(f"No torch dtype mapping for {dtype}")
8883

8984

90-
TypeSpec: TypeAlias = Union[DType | TileTy]
91-
92-
93-
def is_any_pointer(value):
94-
return isinstance(value, Tile) and value.ndim == 0 and is_pointer_dtype(value.dtype)
85+
TypeSpec: TypeAlias = DType
9586

9687

9788
__all__ = [
@@ -102,7 +93,6 @@ def is_any_pointer(value):
10293
"is_boolean",
10394
"is_integral",
10495
"is_signed",
105-
"is_any_pointer",
10696
"is_pointer_dtype",
10797
"pointer_dtype",
10898
"opaque_pointer_dtype",
@@ -131,5 +121,4 @@ def is_any_pointer(value):
131121
"DType",
132122
"to_torch_dtype",
133123
"default_int_type",
134-
"MemorySpace",
135124
]

experimental/cuda-lang/src/cuda/lang/_execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class kernel(_cext.TileDispatcher):
3333
3434
@cl.kernel
3535
def kernel():
36-
cl.printf("Hello!\\n")
36+
print("Hello!")
3737
3838
cl.launch(stream, (1,), (3,), kernel, ())
3939

experimental/cuda-lang/src/cuda/lang/_ir/_host_program.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
binary_bitwise_tensorlike_raw
1515
)
1616
from cuda.tile._ir.core_ops import TypedConst, Assign, loosely_typed_const, strictly_typed_const
17+
from cuda.lang._ir.type import ScalarTy
1718
from cuda.tile._ir.ops import (
1819
AssumeBounded, AssumeDivBy,
1920
)
20-
from cuda.tile._ir.typing_support import I32_TY, I64_TY
21+
from cuda.tile._datatype import int32, int64
2122

2223
HostOpcode = Literal["Const", "KernelArgI32", "KernelArgI64", "Mul", "Add", "RoundUpToPow2"]
2324

@@ -50,9 +51,9 @@ def get_host_programs_by_var(kernel_body: ir.Block) -> dict[str, HostProgram]:
5051
ret = dict()
5152
for i, p in enumerate(kernel_body.params):
5253
ty = p.get_type()
53-
if ty == I32_TY:
54+
if ty == ScalarTy(int32):
5455
ret[p.name] = HostProgram(opcodes=["KernelArgI32"], op_attrs=[i])
55-
elif ty == I64_TY:
56+
elif ty == ScalarTy(int64):
5657
ret[p.name] = HostProgram(opcodes=["KernelArgI64"], op_attrs=[i])
5758

5859
for op in kernel_body:

experimental/cuda-lang/src/cuda/lang/_ir/ir.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@
66
from dataclasses import dataclass, field
77
import itertools
88
from collections import defaultdict
9-
from typing import Sequence
10-
11-
from typing_extensions import override
12-
13-
from cuda.tile._datatype import DType
149
from cuda.tile._ir.ir import (
1510
Block as TileBlock,
1611
Builder as TileBuilder,
@@ -22,9 +17,8 @@
2217
attribute,
2318
add_operation,
2419
format_var,
25-
AggregateValue, TypingHooks,
20+
AggregateValue,
2621
)
27-
from cuda.tile._ir.type import TensorLikeTy, TileTy
2822

2923

3024
class Builder:
@@ -91,18 +85,13 @@ def to_string(
9185
return f"{' ' * indent}^{self._name}({params}):\n{ops}"
9286

9387

94-
class _LangTypingHooks(TypingHooks):
95-
@override
96-
def get_tensor_like_type(self, dtype: DType, shape: Sequence[int]) -> TensorLikeTy:
97-
return TileTy(dtype, shape)
98-
99-
10088
class IRContext(TileIRContext):
10189
def __init__(self, log_ir_on_error: bool = True):
90+
from cuda.lang._ir.type import LangTypingHooks
10291
self._block_names: dict[int, str] = {}
10392
self._block_counter: dict[str, itertools.count] = defaultdict(itertools.count)
10493
super().__init__(log_ir_on_error, tileiras_version=None,
105-
typing_hooks=_LangTypingHooks())
94+
typing_hooks=LangTypingHooks())
10695

10796
def make_block(self, name: str, loc: Loc, params: tuple[Var, ...] = ()) -> Block:
10897
block = Block(self, loc)

0 commit comments

Comments
 (0)