Skip to content

Commit c23916a

Browse files
committed
Move MemorySpace to tile
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 6c3b5f8 commit c23916a

5 files changed

Lines changed: 25 additions & 27 deletions

File tree

experimental/cuda-lang/src/cuda/lang/_compile.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from cuda.tile._exception import TileCompilerExecutionError
2020
from cuda.lang._logging import get_log_flags
2121
from cuda.lang._ir import ir, hir
22-
from cuda.lang._ir.type import MemorySpace
2322
from cuda.lang._passes.ast2hir import get_function_hir
2423
from cuda.lang._passes.ir2mlir import ir2mlir
2524
from cuda.lang._passes.flatten_cfg import flatten_cfg
@@ -103,8 +102,7 @@ def get_function_ir(
103102
constant_mask,
104103
parameter_names,
105104
function.param_locs,
106-
ctx,
107-
array_memory_space=MemorySpace.GENERIC
105+
ctx
108106
)
109107
canonicalize_parameters(params, builder)
110108
with cuda_lang_impl_registry.as_current():

experimental/cuda-lang/src/cuda/lang/_datatype.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,6 @@ def satisfies_pointer_constraint(value, constraint: OpaquePointerSpec):
158158
if isinstance(pointer_ty, TilePointerTy):
159159
if constraint == opaque_ptr:
160160
return True
161-
if pointer_ty.memory_space is None:
162-
return False
163161
return pointer_ty.memory_space.value == constraint.value
164162

165163
return pointer_ty.memory_space.value == constraint.value

experimental/cuda-lang/src/cuda/lang/_ir/type.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
from dataclasses import dataclass
5-
from enum import Enum
65

76
from cuda.lang._ir.ir import LocalArrayContextManagerValue
87
from cuda.lang._enums import TensorMapSwizzle
@@ -22,24 +21,14 @@
2221
EnumTy,
2322
make_tile_ty,
2423
ContextManagerTy,
25-
ContextManagerState,
24+
ContextManagerState, MemorySpace,
2625
)
2726
import cuda.tile._datatype as datatype
2827
from cuda.tile._datatype import DType
2928
from cuda.tile._ir.ir import Var, AggregateValue
3029
from cuda.lang._exception import TileTypeError
3130

3231

33-
class MemorySpace(Enum):
34-
GENERIC = 0
35-
GLOBAL = 1
36-
SHARED = 3
37-
CONSTANT = 4
38-
LOCAL = 5
39-
TENSOR = 6
40-
SHARED_CLUSTER = 7
41-
42-
4332
def _is_power_of_2(value: int) -> bool:
4433
assert isinstance(value, int)
4534
return value > 0 and value & (value - 1) == 0

src/cuda/tile/_compile.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ def _create_kernel_parameters(parameter_constraints: Sequence[ParameterConstrain
124124
constant_parameter_mask: Sequence[bool],
125125
parameter_names: Sequence[str],
126126
parameter_locations: Sequence[Loc],
127-
ir_ctx: ir.IRContext,
128-
array_memory_space=None) -> _KernelParameters:
127+
ir_ctx: ir.IRContext) -> _KernelParameters:
129128
aggregate_vars = []
130129
nonconstant_flat_vars = []
131130
for pos, (constraint, is_const, name, loc) in enumerate(
@@ -140,10 +139,10 @@ def _create_kernel_parameters(parameter_constraints: Sequence[ParameterConstrain
140139
if isinstance(constraint, ScalarConstraint):
141140
ty = TileTy(constraint.dtype, ())
142141
elif isinstance(constraint, ArrayConstraint):
143-
ty = _get_array_ty(constraint, array_memory_space)
142+
ty = _get_array_ty(constraint)
144143
elif isinstance(constraint, ListConstraint):
145144
assert isinstance(constraint.element, ArrayConstraint)
146-
array_ty = _get_array_ty(constraint.element, array_memory_space)
145+
array_ty = _get_array_ty(constraint.element)
147146
ty = ListTy(array_ty)
148147
else:
149148
raise TypeError(f"Unexpected parameter descriptor type"
@@ -157,7 +156,7 @@ def _create_kernel_parameters(parameter_constraints: Sequence[ParameterConstrain
157156
return _KernelParameters(aggregate_vars, nonconstant_flat_vars)
158157

159158

160-
def _get_array_ty(param: ArrayConstraint, memory_space):
159+
def _get_array_ty(param: ArrayConstraint):
161160
for static_stride, bound in zip(param.stride_constant, param.stride_lower_bound_incl,
162161
strict=True):
163162
if static_stride is not None:
@@ -169,8 +168,7 @@ def _get_array_ty(param: ArrayConstraint, memory_space):
169168
return ArrayTy(make_tile_ty(param.dtype, ()),
170169
shape=(None,) * param.ndim,
171170
strides=param.stride_constant,
172-
index_dtype=param.index_dtype,
173-
memory_space=memory_space)
171+
index_dtype=param.index_dtype)
174172

175173

176174
def _log_mlir(bytecode_buf):

src/cuda/tile/_ir/type.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,24 @@ def size_to_bytecode(s: Optional[int]) -> int:
290290
# ============== Pointer Type ===============
291291

292292

293+
class MemorySpace(enum.Enum):
294+
GENERIC = 0
295+
GLOBAL = 1
296+
SHARED = 3
297+
CONSTANT = 4
298+
LOCAL = 5
299+
TENSOR = 6
300+
SHARED_CLUSTER = 7
301+
302+
293303
@dataclass(frozen=True)
294304
class PointerTy(Type):
295305
pointee_type: Type
296-
memory_space: Any = None
306+
memory_space: MemorySpace = MemorySpace.GENERIC
307+
308+
def __str__(self):
309+
memspc = "" if self.memory_space == MemorySpace.GENERIC else f", {self.memory_space}"
310+
return f"Pointer[{self.pointee_type}{memspc}]"
297311

298312

299313
# ============== Tile Type ===============
@@ -344,7 +358,7 @@ def __init__(self,
344358
shape: Tuple[Optional[int], ...],
345359
strides: Tuple[Optional[int], ...],
346360
index_dtype=None,
347-
memory_space=None):
361+
memory_space: MemorySpace = MemorySpace.GENERIC):
348362
from .._datatype import int32
349363
assert isinstance(element_type, Type)
350364
self.element_type = element_type
@@ -403,7 +417,8 @@ def __str__(self):
403417
strides_str = ('?' if x is None else str(x) for x in self.strides)
404418
strides_str = "(" + ','.join(strides_str) + ")"
405419
indexty_str = "" if self.index_dtype == int32 else f",index_dtype={self.index_dtype}]"
406-
return f"Array[{type_str},{shape_str}:{strides_str}{indexty_str}]"
420+
memspc = "" if self.memory_space == MemorySpace.GENERIC else f", {self.memory_space}"
421+
return f"Array[{type_str},{shape_str}:{strides_str}{indexty_str}{memspc}]"
407422

408423

409424
# ============== PartitionView Type ===============

0 commit comments

Comments
 (0)