Skip to content

Commit 4d4c491

Browse files
committed
Add RawArrayMemory to support direct array memory access by element offset
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent ca63835 commit 4d4c491

File tree

7 files changed

+516
-6
lines changed

7 files changed

+516
-6
lines changed

changelog.d/raw-array-memory.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Added `Array.get_raw_memory()` returning a `RawArrayMemory` object that supports `load_offset(offset)` and `store_offset(offset, value)` for direct memory access by element offset (no shape/stride index calculation).

src/cuda/tile/_ir/ir.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,14 @@ def as_tuple(self) -> tuple["Var", ...]:
288288
return (self.array,)
289289

290290

291+
@dataclass
class RawArrayMemoryValue(AggregateValue):
    """Aggregate value with a single component: the base pointer variable."""

    # The only component of this aggregate.
    base_ptr: Var

    def as_tuple(self) -> tuple[Var, ...]:
        """Return the aggregate's components as the 1-tuple ``(base_ptr,)``."""
        components = (self.base_ptr,)
        return components
297+
298+
291299
@dataclass
292300
class ListValue(AggregateValue):
293301
base_ptr: Var

src/cuda/tile/_ir/op_impl.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .ir import Var, TupleValue, Builder
2121
from .type import TiledViewTy, TupleTy, TileTy, DTypeSpec, EnumTy, StringTy, ArrayTy, SliceType, \
2222
ListTy, LooselyTypedScalar, RangeIterType, FunctionTy, ClosureTy, BoundMethodTy, \
23-
DTypeConstructor, Type
23+
DTypeConstructor, Type, RawArrayMemoryTy
2424

2525

2626
def _verify_params_match(stub_sig: inspect.Signature, func_sig: inspect.Signature):
@@ -365,6 +365,13 @@ def require_tiled_view_type(var: Var) -> TiledViewTy:
365365
return ty
366366

367367

368+
def require_raw_array_memory_type(var: Var) -> RawArrayMemoryTy:
    """Return the type of ``var`` after checking that it is a ``RawArrayMemoryTy``.

    Raises the error built by ``_make_type_error`` if ``var`` has any other type.
    """
    var_ty = var.get_type()
    if isinstance(var_ty, RawArrayMemoryTy):
        return var_ty
    raise _make_type_error(f"Expected a RawArrayMemory, but given value has type {var_ty}", var)
373+
374+
368375
def require_list_type(var: Var) -> ListTy:
369376
ty = var.get_type()
370377
if not isinstance(ty, ListTy):

src/cuda/tile/_ir/ops.py

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
enter_nested_block, nested_block, PhiState, LoopVarState,
2626
TupleValue, make_aggregate, RangeValue, BoundMethodValue, ArrayValue, ConstantState,
2727
ListValue, TiledViewValue, ClosureValue, MemoryEffect, attribute, operand,
28-
BlockRestriction, FormattedStringValue,
28+
BlockRestriction, FormattedStringValue, RawArrayMemoryValue
2929
)
3030
from .type import PointerTy
3131
from . import hir
@@ -42,7 +42,7 @@
4242
require_optional_constant_str, PrintfValidator, require_tile_maybe_loose_type,
4343
require_0d_tile_maybe_loose_type, require_bool, require_optional_range_type,
4444
require_tile_or_tile_tuple_type, require_constant_scalar_tuple, require_constant_scalar,
45-
require_callable_type)
45+
require_callable_type, require_raw_array_memory_type)
4646
from .ops_utils import (
4747
BINOP_REGISTRY, UNARYOP_REGISTRY,
4848
check_rd_and_ftz, PaddingMode, get_default_order,
@@ -60,7 +60,7 @@
6060
ListTy, make_tile_ty, SliceType, DTypeConstructor, RangeIterType, Type,
6161
NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType,
6262
array_size_type, ClosureTy, LiveCapturedScope, TokenTy, TiledViewTy, FormattedStringTy,
63-
StringFormat, FormattedPiece,
63+
StringFormat, FormattedPiece, RawArrayMemoryTy
6464
)
6565
from cuda.tile._datatype import (
6666
DType, is_integral, is_float, is_signed, is_boolean,
@@ -1593,6 +1593,7 @@ def getattr_impl(object: Var, name: Var) -> Var:
15931593
case ArrayTy(), "strides": return build_tuple(object.get_aggregate().strides)
15941594
case ArrayTy(), "slice": return bind_method(object, ct._m_array_slice)
15951595
case ArrayTy(), "tiled_view": return bind_method(object, ct._m_array_tiled_view)
1596+
case ArrayTy(), "get_raw_memory": return bind_method(object, ct._m_array_get_raw_memory)
15961597

15971598
case TileTy(), "dtype": return loosely_typed_const(ty.dtype)
15981599
case TileTy(), "shape": return loosely_typed_const(ty.shape)
@@ -1614,6 +1615,12 @@ def getattr_impl(object: Var, name: Var) -> Var:
16141615
case TiledViewTy(), "load": return bind_method(object, ct._m_tiled_view_load)
16151616
case TiledViewTy(), "store": return bind_method(object, ct._m_tiled_view_store)
16161617

1618+
case RawArrayMemoryTy(), "dtype": return loosely_typed_const(ty.dtype)
1619+
case RawArrayMemoryTy(), "load_offset": return bind_method(
1620+
object, ct._m_raw_array_memory_load_offset)
1621+
case RawArrayMemoryTy(), "store_offset": return bind_method(
1622+
object, ct._m_raw_array_memory_store_offset)
1623+
16171624
case ModuleTy(), _:
16181625
try:
16191626
return loosely_typed_const(getattr(ty.py_mod, attr_name))
@@ -2154,6 +2161,83 @@ def _tile_load_impl_inner(array: Var, index_items: tuple[Var, ...], shape: Seque
21542161
return reshape(result, shape)
21552162

21562163

2164+
@impl(ct._m_array_get_raw_memory)
def get_raw_memory_impl(array: Var) -> Var:
    """Implement ``Array.get_raw_memory()``.

    Builds a ``RawArrayMemoryTy`` aggregate that carries the array's dtype and
    reuses the array's base pointer as its only component.
    """
    array_ty = require_array_type(array)
    aggregate = array.get_aggregate()
    assert isinstance(aggregate, ArrayValue)
    result_ty = RawArrayMemoryTy(array_ty.dtype)
    # The aggregate flattens to exactly one variable (the base pointer), so
    # exactly one aggregate is expected back.
    (raw_memory,) = unflatten_aggregates((aggregate.base_ptr,), (result_ty,), (result_ty,))
    return raw_memory
2173+
2174+
2175+
def _raw_array_memory_pointer(raw_array_memory: Var,
                              offset: Var) -> tuple[Var, Tuple[int, ...], DType]:
    """Validate a RawArrayMemory operand and compute the pointer(s) at ``offset``.

    Shared by the ``load_offset`` and ``store_offset`` implementations so that
    both check the operand and derive the pointer in exactly the same way.

    Args:
        raw_array_memory: Variable that must have type ``RawArrayMemoryTy``.
        offset: Offset(s) to apply to the base pointer; cast to ``uint64``
            before ``pointer_offset`` is called.

    Returns:
        A ``(pointer, pointer_shape, dtype)`` triple: the result of
        ``pointer_offset``, the shape of its type, and the dtype recorded in
        the ``RawArrayMemoryTy``.

    Raises:
        The error raised by ``require_raw_array_memory_type`` if
        ``raw_array_memory`` is not a RawArrayMemory.
    """
    raw_mem_ty = require_raw_array_memory_type(raw_array_memory)
    raw_mem_val = raw_array_memory.get_aggregate()
    assert isinstance(raw_mem_val, RawArrayMemoryValue)

    # NOTE(review): the offset is cast without checking that it is integral,
    # although the stub documents it as "integer type" -- confirm whether
    # `astype` should be preceded by an explicit dtype check.
    offset = astype(offset, datatype.uint64)
    pointer = pointer_offset(raw_mem_val.base_ptr, offset)
    return pointer, pointer.get_type().shape, raw_mem_ty.dtype


@impl(ct._m_raw_array_memory_load_offset)
def raw_array_memory_load_offset_impl(raw_array_memory: Var, offset: Var, mask: Var,
                                      padding_value: Var, latency: Var) -> Var:
    """Implement ``RawArrayMemory.load_offset()``.

    Args:
        raw_array_memory: The RawArrayMemory to load from.
        offset: Offset(s) applied to the base pointer.
        mask: Custom mask, processed by ``_process_custom_mask`` against the
            pointer shape.
        padding_value: If this is the constant ``None``, no padding value is
            passed to ``load_pointer``. Otherwise it must be a tile whose shape
            is broadcastable to the pointer shape; it is implicitly cast to the
            RawArrayMemory's dtype and broadcast to the pointer shape.
        latency: Optional constant integer hint, checked by
            ``_check_load_store_hints``.

    Returns:
        The first result of ``load_pointer``; the token it also returns is
        discarded.

    Raises:
        TileTypeError: If the padding shape is not broadcastable to the
            pointer shape.
    """
    pointer, pointer_shape, array_dtype = _raw_array_memory_pointer(raw_array_memory, offset)

    final_mask = _process_custom_mask(mask, None, pointer_shape)

    if padding_value.is_constant() and padding_value.get_constant() is None:
        padding_var: Optional[Var] = None
    else:
        padding_ty = require_tile_type(padding_value)
        padding_shape = padding_ty.shape
        if not is_shape_broadcastable_to(padding_shape, pointer_shape):
            raise TileTypeError(f"Padding shape {padding_shape} is not broadcastable to the"
                                f" offset shape {pointer_shape}")
        padding_var = _implicit_cast(padding_value, array_dtype, "Invalid padding value")
        padding_var = broadcast_to(padding_var, pointer_shape)

    latency_val = require_optional_constant_int(latency)
    _check_load_store_hints(latency_val)
    result, _token = load_pointer(pointer, final_mask, padding_var, latency_val)
    return result


@impl(ct._m_raw_array_memory_store_offset)
def raw_array_memory_store_offset_impl(raw_array_memory: Var, offset: Var, value: Var,
                                       mask: Var, latency: Var) -> None:
    """Implement ``RawArrayMemory.store_offset()``.

    Args:
        raw_array_memory: The RawArrayMemory to store into.
        offset: Offset(s) applied to the base pointer.
        value: Value(s) to store; prepared by ``_get_scatter_value`` for the
            pointer shape and the RawArrayMemory's dtype.
        mask: Custom mask, processed by ``_process_custom_mask`` against the
            pointer shape.
        latency: Optional constant integer hint, checked by
            ``_check_load_store_hints``.
    """
    pointer, pointer_shape, array_dtype = _raw_array_memory_pointer(raw_array_memory, offset)

    final_mask = _process_custom_mask(mask, None, pointer_shape)
    value = _get_scatter_value(value, pointer_shape, array_dtype, "Value",
                               array_name="RawArrayMemory")

    latency_val = require_optional_constant_int(latency)
    _check_load_store_hints(latency_val)
    # store_pointer yields a single token, which is not needed here.
    [_token] = store_pointer(pointer, value, final_mask, latency_val)
2229+
2230+
2231+
def tile_load(array: Var, index: tuple[Var, ...], shape: Sequence[int], order: Sequence[int],
              padding_mode: PaddingMode, latency: Optional[int],
              allow_tma: Optional[bool]) -> tuple[Var, Var]:
    """Add a ``TileLoad`` operation and return what ``add_operation`` returns.

    The operation is given two result types: a tile type built from the
    array's dtype and ``shape``, and a ``TokenTy``. ``order`` is passed to the
    operation as a tuple.
    """
    element_dtype = array.get_type().dtype
    result_types = (make_tile_ty(element_dtype, shape), TokenTy())
    return add_operation(
        TileLoad, result_types,
        array=array, index=index, order=tuple(order),
        padding_mode=padding_mode, latency=latency, allow_tma=allow_tma,
    )
2239+
2240+
21572241
@impl(ct.load)
21582242
def tile_load_impl(array: Var, index: Var, shape: Var, order: Var,
21592243
padding_mode: Var, latency: Var, allow_tma: Var) -> Var:
@@ -2369,7 +2453,8 @@ def scatter_impl(array: Var, indices: Var, value: Var, mask: Var,
23692453

23702454

23712455
def _get_scatter_value(value: Var, pointer_shape: Tuple[int, ...], array_dtype: DType,
2372-
value_name: str, cast_dtype: bool = True) -> Var:
2456+
value_name: str, cast_dtype: bool = True,
2457+
array_name: str = "array") -> Var:
23732458
value_ty = require_tile_type(value)
23742459
value_shape = value_ty.shape
23752460

@@ -2379,7 +2464,7 @@ def _get_scatter_value(value: Var, pointer_shape: Tuple[int, ...], array_dtype:
23792464

23802465
if cast_dtype:
23812466
value = _implicit_cast(value, array_dtype,
2382-
"Stored value is incompatible with array's dtype")
2467+
f"Stored value is incompatible with {array_name}'s dtype")
23832468
return broadcast_to(value, pointer_shape)
23842469

23852470

src/cuda/tile/_ir/type.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,31 @@ def __str__(self):
465465
f"padding_mode={self.padding_mode}]")
466466

467467

468+
# ============== Raw Array Memory Type ===============
469+
470+
471+
@dataclass(frozen=True)
class RawArrayMemoryTy(Type):
    """Type of a RawArrayMemory object, which loads/stores by element offset (no index math).

    It is an aggregate type with a single item: a tile type with an empty
    shape whose element type is a pointer to ``dtype``.
    """
    dtype: "DType"

    def is_aggregate(self) -> bool:
        """Always ``True``: values of this type are aggregates."""
        return True

    def aggregate_item_types(self) -> tuple["Type", ...]:
        """Return the 1-tuple holding the type of the base pointer item."""
        return (TileTy(PointerTy(self.dtype), TupleTy(())),)

    def make_aggregate_value(self, items: tuple["Var", ...]) -> "AggregateValue":
        """Wrap the single item (the base pointer) in a ``RawArrayMemoryValue``."""
        # Imported inside the method, as in the original code.
        from .ir import RawArrayMemoryValue
        assert len(items) == 1
        [base_ptr] = items
        return RawArrayMemoryValue(base_ptr)

    def __str__(self):
        return "RawArrayMemory[{}]".format(self.dtype)
491+
492+
468493
# ============== List Type ===============
469494

470495

src/cuda/tile/_stub.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,64 @@ def tiled_view(self, tile_shape: Constant[Shape], *,
232232
"""
233233
return _m_array_tiled_view(self, tile_shape, padding_mode=padding_mode)
234234

235+
def get_raw_memory(self) -> "RawArrayMemory":
    """Returns a :py:class:`RawArrayMemory` for loading and storing by element offset.

    The returned object holds the array's base pointer. Pass an offset
    measured in **elements** to :py:meth:`RawArrayMemory.load_offset` or
    :py:meth:`RawArrayMemory.store_offset`; no shape/stride index calculation
    is performed. Useful when you already have memory offsets.

    Returns:
        RawArrayMemory:
    """
    raw_memory = _m_array_get_raw_memory(self)
    return raw_memory
247+
248+
249+
class RawArrayMemory:
    """Type stub for the objects returned by :py:meth:`Array.get_raw_memory`."""

    @property
    @function
    def dtype(self) -> "DType":
        """The data type of the elements in the |RawArrayMemory|.

        Returns:
            DType (constant):
        """

    def load_offset(self, offset: "TileOrScalar", /, *,
                    mask: Optional["Tile"] = None,
                    padding_value: "TileOrScalar" = 0,
                    latency: Optional[int] = None) -> "Tile":
        """Loads from memory at ``base_ptr + offset``, with ``offset`` counted in elements.

        Args:
            offset: Element offset(s), given as a scalar or a tile of integer type.
            mask: Optional boolean mask. Where it is False, ``padding_value``
                is used instead of loading.
            padding_value: Value used where ``mask`` is False. Defaults to 0.
            latency: Optional latency hint (1--10).

        Returns:
            Tile: The loaded tile; its shape matches broadcast(offset).
        """
        loaded = _m_raw_array_memory_load_offset(
            self, offset, mask=mask, padding_value=padding_value, latency=latency)
        return loaded

    def store_offset(self, offset: "TileOrScalar", value: "TileOrScalar", /, *,
                     mask: Optional["Tile"] = None,
                     latency: Optional[int] = None) -> None:
        """Stores to memory at ``base_ptr + offset``, with ``offset`` counted in elements.

        Args:
            offset: Element offset(s), given as a scalar or a tile of integer type.
            value: Value(s) to store; broadcast to the offset shape.
            mask: Optional boolean mask. Where it is False, nothing is stored.
            latency: Optional latency hint (1--10).
        """
        outcome = _m_raw_array_memory_store_offset(
            self, offset, value, mask=mask, latency=latency)
        return outcome
292+
235293

236294
class Tile:
237295
"""Type stub for a |tile|."""
@@ -2430,3 +2488,22 @@ def _m_tiled_view_load(tiled_view, index, *, latency, allow_tma): ...
24302488
@function
24312489
def _m_tiled_view_store(tiled_view, index, tile, *, latency, allow_tma): ...
24322490
# TiledView.store(index, tile, latency=latency, allow_tma=allow_tma)
2491+
2492+
2493+
# Stub behind Array.get_raw_memory()
@function
def _m_array_get_raw_memory(array: Array) -> RawArrayMemory:
    ...
2495+
2496+
2497+
# Stub behind RawArrayMemory.load_offset()
@function
def _m_raw_array_memory_load_offset(
        raw_array_memory: RawArrayMemory,
        offset: TileOrScalar,
        /, *,
        mask: Optional[Tile] = None,
        padding_value: TileOrScalar = 0,
        latency: Optional[int] = None) -> Tile:
    ...
2503+
2504+
2505+
# Stub behind RawArrayMemory.store_offset()
@function
def _m_raw_array_memory_store_offset(
        raw_array_memory: RawArrayMemory,
        offset: TileOrScalar,
        value: TileOrScalar,
        /, *,
        mask: Optional[Tile] = None,
        latency: Optional[int] = None) -> None:
    ...

0 commit comments

Comments
 (0)