|
30 | 30 | from . import hir |
31 | 31 | from .op_impl import ( |
32 | 32 | impl, require_constant_int, require_constant_int_tuple, |
| 33 | + require_signed_integer_scalar_or_0d_tile_type, |
33 | 34 | require_tile_type, normalize_axis, require_dtype_spec, |
34 | 35 | require_tile_or_scalar_type, require_constant_bool, require_optional_constant_enum, |
35 | 36 | require_constant_str, require_array_type, require_tuple_type, require_constant_slice, |
|
46 | 47 | promote_types, promote_dtypes, check_implicit_cast |
47 | 48 | ) |
48 | 49 | from .scope import Scope, JumpInfo, ControlFlowInfo |
49 | | -from .typing_support import typeof_pyval, dtype_registry, loose_type_of_pyval, get_constant_value |
| 50 | +from .typing_support import ( |
| 51 | + BYTE_BITWIDTH, typeof_pyval, dtype_registry, loose_type_of_pyval, get_constant_value |
| 52 | +) |
50 | 53 | from .type import ( |
51 | 54 | TupleTy, TileTy, NoneType, BoundMethodTy, SizeTy, ArrayTy, |
52 | 55 | ListTy, make_tile_ty, SliceType, DTypeConstructor, RangeIterType, Type, |
53 | | - NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType |
| 56 | + NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType, |
| 57 | + array_size_type, |
54 | 58 | ) |
55 | 59 | from cuda.tile._datatype import ( |
56 | 60 | DType, is_integral, is_float, is_signed, is_boolean, is_restricted_float, |
@@ -1506,6 +1510,7 @@ def getattr_impl(object: Var, name: Var) -> Var: |
1506 | 1510 | case ArrayTy(), "ndim": return loosely_typed_const(ty.ndim) |
1507 | 1511 | case ArrayTy(), "shape": return build_tuple(object.get_aggregate().shape) |
1508 | 1512 | case ArrayTy(), "strides": return build_tuple(object.get_aggregate().strides) |
| 1513 | + case ArrayTy(), "slice": return bind_method(object, ct._m_array_slice) |
1509 | 1514 |
|
1510 | 1515 | case TileTy(), "dtype": return loosely_typed_const(ty.dtype) |
1511 | 1516 | case TileTy(), "shape": return loosely_typed_const(ty.shape_value) |
@@ -1578,11 +1583,7 @@ def range_(args: Tuple[Var, ...]) -> Var: |
1578 | 1583 | if not 1 <= len(args) <= 3: |
1579 | 1584 | raise TileTypeError(f"Invalid number of arguments: {len(args)}") |
1580 | 1585 | for arg in args: |
1581 | | - arg_ty = require_scalar_or_0d_tile_type(arg) |
1582 | | - if isinstance(arg_ty, TileTy): |
1583 | | - arg_ty = arg_ty.dtype |
1584 | | - if not datatype.is_integral(arg_ty) or not datatype.is_signed(arg_ty): |
1585 | | - raise TileTypeError(f"Expected a signed integer, got {arg_ty}") |
| 1586 | + require_signed_integer_scalar_or_0d_tile_type(arg) |
1586 | 1587 |
|
1587 | 1588 | if len(args) == 1: |
1588 | 1589 | start = strictly_typed_const(0, datatype.default_int_type) |
@@ -1905,6 +1906,120 @@ def num_blocks(axis: Var) -> Var: |
1905 | 1906 | return add_operation(TileNumBlocks, datatype.default_int_type, axis=axis) |
1906 | 1907 |
|
1907 | 1908 |
|
| 1909 | +def _infer_sliced_shape( |
| 1910 | + array_ty: ArrayTy, |
| 1911 | + axis: int, |
| 1912 | + const_start: Optional[int], |
| 1913 | + const_stop: Optional[int], |
| 1914 | +) -> Tuple[TupleTy, Tuple[Optional[int], ...]]: |
| 1915 | + has_const_bounds = const_start is not None and const_stop is not None |
| 1916 | + new_axis_size = const_stop - const_start if has_const_bounds else None |
| 1917 | + |
| 1918 | + # FIXME: Enable static shape in MakeTensorView for new_axis_size if static |
| 1919 | + new_shape = TupleTy( |
| 1920 | + SizeTy(None) if i == axis else dim |
| 1921 | + for i, dim in enumerate(array_ty.shape) |
| 1922 | + ) |
| 1923 | + |
| 1924 | + # Preserve shape divisibility if new size is compatible |
| 1925 | + old_div_by = array_ty.shape_div_by[axis] |
| 1926 | + new_div_by = ( |
| 1927 | + old_div_by |
| 1928 | + if (old_div_by is not None |
| 1929 | + and new_axis_size is not None |
| 1930 | + and new_axis_size % old_div_by == 0) |
| 1931 | + else None |
| 1932 | + ) |
| 1933 | + |
| 1934 | + new_shape_div_by = tuple( |
| 1935 | + new_div_by if i == axis else d |
| 1936 | + for i, d in enumerate(array_ty.shape_div_by) |
| 1937 | + ) |
| 1938 | + |
| 1939 | + return new_shape, new_shape_div_by |
| 1940 | + |
| 1941 | + |
| 1942 | +def _infer_sliced_base_ptr_alignment( |
| 1943 | + array_ty: ArrayTy, |
| 1944 | + axis: int, |
| 1945 | + const_start: Optional[int], |
| 1946 | +) -> Optional[int]: |
| 1947 | + if array_ty.base_ptr_div_by is None: |
| 1948 | + return None |
| 1949 | + |
| 1950 | + # Get stride divisibility in elements or use static stride if present |
| 1951 | + stride_div_by = array_ty.stride_div_by[axis] or array_ty.strides[axis].maybe_value |
| 1952 | + if stride_div_by is None: |
| 1953 | + return None |
| 1954 | + |
| 1955 | + assert array_ty.dtype.bitwidth % BYTE_BITWIDTH == 0 |
| 1956 | + dtype_bytewidth = array_ty.dtype.bitwidth // BYTE_BITWIDTH |
| 1957 | + stride_div_by_bytes = stride_div_by * dtype_bytewidth |
| 1958 | + offset_div_by = ( |
| 1959 | + const_start * stride_div_by_bytes if const_start is not None |
| 1960 | + else stride_div_by_bytes |
| 1961 | + ) |
| 1962 | + return math.gcd(offset_div_by, array_ty.base_ptr_div_by) |
| 1963 | + |
| 1964 | + |
| 1965 | +@impl(ct._m_array_slice) |
| 1966 | +def array_slice_impl(array: Var, axis: Var, start: Var, stop: Var) -> Var: |
| 1967 | + array_ty = require_array_type(array) |
| 1968 | + axis = normalize_axis(require_constant_int(axis), array_ty.ndim) |
| 1969 | + require_signed_integer_scalar_or_0d_tile_type(start) |
| 1970 | + require_signed_integer_scalar_or_0d_tile_type(stop) |
| 1971 | + |
| 1972 | + def maybe_const_int(v: Var): |
| 1973 | + if isinstance(v.get_type(), DType) and v.is_constant(): |
| 1974 | + v_int = v.get_constant() |
| 1975 | + assert isinstance(v_int, int) |
| 1976 | + return v_int |
| 1977 | + return None |
| 1978 | + |
| 1979 | + const_start = maybe_const_int(start) |
| 1980 | + const_stop = maybe_const_int(stop) |
| 1981 | + if const_start is not None and const_start < 0: |
| 1982 | + raise TileTypeError("Slice start must be non-negative") |
| 1983 | + if const_stop is not None and const_stop < 0: |
| 1984 | + raise TileTypeError("Slice stop must be non-negative") |
| 1985 | + if const_start is not None and const_stop is not None and const_stop < const_start: |
| 1986 | + raise TileTypeError("Slice stop must be greater than or equal to start") |
| 1987 | + |
| 1988 | + new_shape_ty, new_shape_div_by = _infer_sliced_shape(array_ty, axis, const_start, const_stop) |
| 1989 | + new_base_ptr_div_by = _infer_sliced_base_ptr_alignment(array_ty, axis, const_start) |
| 1990 | + new_array_ty = ArrayTy( |
| 1991 | + array_ty.dtype, |
| 1992 | + shape=new_shape_ty, |
| 1993 | + strides=array_ty.strides, |
| 1994 | + elements_disjoint=array_ty.elements_disjoint, |
| 1995 | + base_ptr_div_by=new_base_ptr_div_by, |
| 1996 | + stride_div_by=array_ty.stride_div_by, |
| 1997 | + shape_div_by=new_shape_div_by, |
| 1998 | + ) |
| 1999 | + |
| 2000 | + array_val = array.get_aggregate() |
| 2001 | + assert isinstance(array_val, ArrayValue) |
| 2002 | + static_stride = array_ty.strides[axis].maybe_value |
| 2003 | + if static_stride == 1: |
| 2004 | + offset = start # skip multiplication for unit stride |
| 2005 | + elif static_stride is not None: |
| 2006 | + offset = binary_arithmetic("mul", start, loosely_typed_const(static_stride)) |
| 2007 | + else: |
| 2008 | + offset = binary_arithmetic("mul", start, array_val.strides[axis]) |
| 2009 | + |
| 2010 | + new_base_ptr = pointer_offset(array_val.base_ptr, astype(offset, datatype.uint64)) |
| 2011 | + axis_new_shape = astype(binary_arithmetic("sub", stop, start), array_size_type().dtype) |
| 2012 | + new_shape = tuple( |
| 2013 | + axis_new_shape if i == axis else s for i, s in enumerate(array_val.shape) |
| 2014 | + ) |
| 2015 | + |
| 2016 | + [ret] = unflatten_aggregates( |
| 2017 | + (new_base_ptr,) + new_shape + array_val.strides, |
| 2018 | + (new_array_ty,), (new_array_ty,) |
| 2019 | + ) |
| 2020 | + return ret |
| 2021 | + |
| 2022 | + |
1908 | 2023 | class TileLoad(TypedOperation): |
1909 | 2024 | def __init__( |
1910 | 2025 | self, array: Var, index: tuple[Var, ...], order: Sequence[int], |
|
0 commit comments