diff --git a/src/language/types.jl b/src/language/types.jl index f246e54d..d552c465 100644 --- a/src/language/types.jl +++ b/src/language/types.jl @@ -151,7 +151,7 @@ specializations (e.g., aligned vs unaligned) produce different cubins. - `sizes::NTuple{N, Int32}`: Size in each dimension - `strides::NTuple{N, Int32}`: Stride in each dimension (in elements) """ -struct TileArray{T, N, S} +struct TileArray{T, N, S} <: AbstractArray{T, N} ptr::Ptr{T} sizes::NTuple{N, Int32} strides::NTuple{N, Int32} @@ -168,6 +168,24 @@ function Base.size(arr::TileArray{<:Any, N}, d::Integer) where N return d > N ? Int32(1) : arr.sizes[d] end Base.length(arr::TileArray) = prod(size(arr)) + +function Base.getindex(arr::TileArray, I...) + error( + "TileArray does not support host-side indexing. " * + "Use ct.load inside a cuTile kernel to access tile data." + ) +end + +function Base.show(io::IO, arr::TileArray{T,N}) where {T,N} + print(io, "TileArray{$T,$N}($(join(size(arr), "×")))") +end + +function Base.show(io::IO, ::MIME"text/plain", arr::TileArray{T,N}) where {T,N} + print(io, join(size(arr), "×"), " TileArray{$T,$N} on GPU") +end + +Base.pointer(arr::TileArray{T}) where {T} = arr.ptr + # Return the ArraySpec value (third type parameter) if present function array_spec(@nospecialize(T::Type{<:TileArray})) T isa DataType || return nothing # UnionAll types don't have full parameters