Skip to content

Commit 8b6d006

Browse files
committed
array slice
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent 361048e commit 8b6d006

File tree

7 files changed

+450
-9
lines changed

7 files changed

+450
-9
lines changed

changelog.d/array-slice.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
Add `Array.slice(axis, start, stop)` to create a view of an array sliced along a single axis.
5+
The result shares memory with the original array (no data copy).

src/cuda/tile/_ir/op_impl.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from cuda.tile._ir.ops_utils import get_dtype
1717

1818
from .ir import Var
19-
from .typing_support import get_signature
19+
from .typing_support import datatype, get_signature
2020
from .type import TupleTy, TileTy, DTypeSpec, EnumTy, StringTy, ArrayTy, SliceType, \
2121
ListTy, PointerTy, LooselyTypedScalar, RangeIterType
2222
from .. import _datatype
@@ -267,6 +267,21 @@ def require_scalar_or_0d_tile_type(var: Var) -> TileTy | DType | PointerTy:
267267
return ty
268268

269269

270+
def require_signed_integer_scalar_or_0d_tile_type(var: Var) -> TileTy | DType:
271+
ty = require_scalar_or_0d_tile_type(var)
272+
if isinstance(ty, TileTy):
273+
dtype = ty.dtype
274+
elif isinstance(ty, DType):
275+
dtype = ty
276+
else:
277+
dtype = None
278+
279+
if dtype is None or not datatype.is_integral(dtype) or not datatype.is_signed(dtype):
280+
raise _make_type_error(f"Expected a signed integer scalar or a 0D signed integer tile,"
281+
f" but got {ty}", var)
282+
return ty
283+
284+
270285
def require_bool(var: Var) -> TileTy | DType:
271286
ty = var.get_type()
272287
if not (ty == _datatype.bool_

src/cuda/tile/_ir/ops.py

Lines changed: 122 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from . import hir
3131
from .op_impl import (
3232
impl, require_constant_int, require_constant_int_tuple,
33+
require_signed_integer_scalar_or_0d_tile_type,
3334
require_tile_type, normalize_axis, require_dtype_spec,
3435
require_tile_or_scalar_type, require_constant_bool, require_optional_constant_enum,
3536
require_constant_str, require_array_type, require_tuple_type, require_constant_slice,
@@ -46,11 +47,14 @@
4647
promote_types, promote_dtypes, check_implicit_cast
4748
)
4849
from .scope import Scope, JumpInfo, ControlFlowInfo
49-
from .typing_support import typeof_pyval, dtype_registry, loose_type_of_pyval, get_constant_value
50+
from .typing_support import (
51+
BYTE_BITWIDTH, typeof_pyval, dtype_registry, loose_type_of_pyval, get_constant_value
52+
)
5053
from .type import (
5154
TupleTy, TileTy, NoneType, BoundMethodTy, SizeTy, ArrayTy,
5255
ListTy, make_tile_ty, SliceType, DTypeConstructor, RangeIterType, Type,
53-
NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType
56+
NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType,
57+
array_size_type,
5458
)
5559
from cuda.tile._datatype import (
5660
DType, is_integral, is_float, is_signed, is_boolean, is_restricted_float,
@@ -1506,6 +1510,7 @@ def getattr_impl(object: Var, name: Var) -> Var:
15061510
case ArrayTy(), "ndim": return loosely_typed_const(ty.ndim)
15071511
case ArrayTy(), "shape": return build_tuple(object.get_aggregate().shape)
15081512
case ArrayTy(), "strides": return build_tuple(object.get_aggregate().strides)
1513+
case ArrayTy(), "slice": return bind_method(object, ct._m_array_slice)
15091514

15101515
case TileTy(), "dtype": return loosely_typed_const(ty.dtype)
15111516
case TileTy(), "shape": return loosely_typed_const(ty.shape_value)
@@ -1578,11 +1583,7 @@ def range_(args: Tuple[Var, ...]) -> Var:
15781583
if not 1 <= len(args) <= 3:
15791584
raise TileTypeError(f"Invalid number of arguments: {len(args)}")
15801585
for arg in args:
1581-
arg_ty = require_scalar_or_0d_tile_type(arg)
1582-
if isinstance(arg_ty, TileTy):
1583-
arg_ty = arg_ty.dtype
1584-
if not datatype.is_integral(arg_ty) or not datatype.is_signed(arg_ty):
1585-
raise TileTypeError(f"Expected a signed integer, got {arg_ty}")
1586+
require_signed_integer_scalar_or_0d_tile_type(arg)
15861587

15871588
if len(args) == 1:
15881589
start = strictly_typed_const(0, datatype.default_int_type)
@@ -1905,6 +1906,120 @@ def num_blocks(axis: Var) -> Var:
19051906
return add_operation(TileNumBlocks, datatype.default_int_type, axis=axis)
19061907

19071908

1909+
def _infer_sliced_shape(
1910+
array_ty: ArrayTy,
1911+
axis: int,
1912+
const_start: Optional[int],
1913+
const_stop: Optional[int],
1914+
) -> Tuple[TupleTy, Tuple[Optional[int], ...]]:
1915+
has_const_bounds = const_start is not None and const_stop is not None
1916+
new_axis_size = const_stop - const_start if has_const_bounds else None
1917+
1918+
# FIXME: Enable static shape in MakeTensorView for new_axis_size if static
1919+
new_shape = TupleTy(
1920+
SizeTy(None) if i == axis else dim
1921+
for i, dim in enumerate(array_ty.shape)
1922+
)
1923+
1924+
# Preserve shape divisibility if new size is compatible
1925+
old_div_by = array_ty.shape_div_by[axis]
1926+
new_div_by = (
1927+
old_div_by
1928+
if (old_div_by is not None
1929+
and new_axis_size is not None
1930+
and new_axis_size % old_div_by == 0)
1931+
else None
1932+
)
1933+
1934+
new_shape_div_by = tuple(
1935+
new_div_by if i == axis else d
1936+
for i, d in enumerate(array_ty.shape_div_by)
1937+
)
1938+
1939+
return new_shape, new_shape_div_by
1940+
1941+
1942+
def _infer_sliced_base_ptr_alignment(
1943+
array_ty: ArrayTy,
1944+
axis: int,
1945+
const_start: Optional[int],
1946+
) -> Optional[int]:
1947+
if array_ty.base_ptr_div_by is None:
1948+
return None
1949+
1950+
# Get stride divisibility in elements or use static stride if present
1951+
stride_div_by = array_ty.stride_div_by[axis] or array_ty.strides[axis].maybe_value
1952+
if stride_div_by is None:
1953+
return None
1954+
1955+
assert array_ty.dtype.bitwidth % BYTE_BITWIDTH == 0
1956+
dtype_bytewidth = array_ty.dtype.bitwidth // BYTE_BITWIDTH
1957+
stride_div_by_bytes = stride_div_by * dtype_bytewidth
1958+
offset_div_by = (
1959+
const_start * stride_div_by_bytes if const_start is not None
1960+
else stride_div_by_bytes
1961+
)
1962+
return math.gcd(offset_div_by, array_ty.base_ptr_div_by)
1963+
1964+
1965+
@impl(ct._m_array_slice)
1966+
def array_slice_impl(array: Var, axis: Var, start: Var, stop: Var) -> Var:
1967+
array_ty = require_array_type(array)
1968+
axis = normalize_axis(require_constant_int(axis), array_ty.ndim)
1969+
require_signed_integer_scalar_or_0d_tile_type(start)
1970+
require_signed_integer_scalar_or_0d_tile_type(stop)
1971+
1972+
def maybe_const_int(v: Var):
1973+
if isinstance(v.get_type(), DType) and v.is_constant():
1974+
v_int = v.get_constant()
1975+
assert isinstance(v_int, int)
1976+
return v_int
1977+
return None
1978+
1979+
const_start = maybe_const_int(start)
1980+
const_stop = maybe_const_int(stop)
1981+
if const_start is not None and const_start < 0:
1982+
raise TileTypeError("Slice start must be non-negative")
1983+
if const_stop is not None and const_stop < 0:
1984+
raise TileTypeError("Slice stop must be non-negative")
1985+
if const_start is not None and const_stop is not None and const_stop < const_start:
1986+
raise TileTypeError("Slice stop must be greater than or equal to start")
1987+
1988+
new_shape_ty, new_shape_div_by = _infer_sliced_shape(array_ty, axis, const_start, const_stop)
1989+
new_base_ptr_div_by = _infer_sliced_base_ptr_alignment(array_ty, axis, const_start)
1990+
new_array_ty = ArrayTy(
1991+
array_ty.dtype,
1992+
shape=new_shape_ty,
1993+
strides=array_ty.strides,
1994+
elements_disjoint=array_ty.elements_disjoint,
1995+
base_ptr_div_by=new_base_ptr_div_by,
1996+
stride_div_by=array_ty.stride_div_by,
1997+
shape_div_by=new_shape_div_by,
1998+
)
1999+
2000+
array_val = array.get_aggregate()
2001+
assert isinstance(array_val, ArrayValue)
2002+
static_stride = array_ty.strides[axis].maybe_value
2003+
if static_stride == 1:
2004+
offset = start # skip multiplication for unit stride
2005+
elif static_stride is not None:
2006+
offset = binary_arithmetic("mul", start, loosely_typed_const(static_stride))
2007+
else:
2008+
offset = binary_arithmetic("mul", start, array_val.strides[axis])
2009+
2010+
new_base_ptr = pointer_offset(array_val.base_ptr, astype(offset, datatype.uint64))
2011+
axis_new_shape = astype(binary_arithmetic("sub", stop, start), array_size_type().dtype)
2012+
new_shape = tuple(
2013+
axis_new_shape if i == axis else s for i, s in enumerate(array_val.shape)
2014+
)
2015+
2016+
[ret] = unflatten_aggregates(
2017+
(new_base_ptr,) + new_shape + array_val.strides,
2018+
(new_array_ty,), (new_array_ty,)
2019+
)
2020+
return ret
2021+
2022+
19082023
class TileLoad(TypedOperation):
19092024
def __init__(
19102025
self, array: Var, index: tuple[Var, ...], order: Sequence[int],

src/cuda/tile/_ir/type.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,10 @@ def unify(self, other: "ArrayTy") -> Optional["ArrayTy"]:
351351
for s1, s2 in zip(self.strides, other.strides, strict=True)))
352352

353353
elements_disjoint = self.elements_disjoint and other.elements_disjoint
354-
base_ptr_div_by = math.gcd(self.base_ptr_div_by, other.base_ptr_div_by)
354+
base_ptr_div_by = (
355+
None if (self.base_ptr_div_by is None or other.base_ptr_div_by is None)
356+
else math.gcd(self.base_ptr_div_by, other.base_ptr_div_by)
357+
)
355358
shape_div_by = tuple(
356359
None if (d1 is None or d2 is None) else math.gcd(d1, d2)
357360
for d1, d2 in zip(self.shape_div_by, other.shape_div_by, strict=True)

src/cuda/tile/_stub.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,38 @@ def ndim(self) -> int:
181181
int (constant):
182182
"""
183183

184+
@function
185+
def slice(self, axis, start, stop) -> "Array":
186+
"""Creates a view of the |array| sliced along a single `axis`.
187+
188+
The returned array references the same underlying memory as |array|,
189+
but with a restricted range from index `start` (inclusive) to `stop` (exclusive)
190+
along the specified axis. No data is copied.
191+
192+
`axis` must be a constant integer. Negative values are supported and count
193+
from the last dimension (e.g., ``axis=-1`` refers to the last axis).
194+
195+
`start` and `stop` must be integers (scalars or 0D tiles).
196+
They must satisfy ``0 <= start < N`` and ``start <= stop <= N``, where ``N``
197+
is the size of `array` along the sliced axis.
198+
199+
For example, consider a 2-dimensional array A of shape ``(M, N)``.
200+
Slicing along axis 0 from `start` to `stop`:
201+
202+
>>> sub = A.slice(axis=0, start=start, stop=stop)
203+
204+
The result `sub` will be an array of shape ``(stop - start, N)``.
205+
Using NumPy slice notation for illustration, this is equivalent to::
206+
207+
sub = A[start:stop, :] # NumPy notation for reference only
208+
209+
The slice bounds can be dynamic (runtime values):
210+
211+
>>> # Process variable-length segments
212+
>>> segment = A.slice(axis=1, start=offset, stop=offset + length)
213+
>>> tile = ct.load(segment, (0, 0), shape=(TILE_M, TILE_N))
214+
"""
215+
184216

185217
class Tile:
186218
"""A *tile array* (or *tile*) is an immutable multidimensional collection of values that is
@@ -1891,4 +1923,8 @@ def assert_(cond, /, message=None) -> None:
18911923

18921924
# ==== Private stubs ====
18931925

1926+
1927+
def _m_array_slice(array, axis, start, stop): ... # Array.slice(axis, start, stop)
1928+
1929+
18941930
def _m_tile_item(tile): ... # Tile.item()

0 commit comments

Comments
 (0)