1+ package sk.ainet.core.tensor
2+
3+ /* *
4+ * Dense tensor data implementation for contiguous memory layout.
5+ * This is the most common tensor data representation where all elements
6+ * are stored in a contiguous array with standard stride patterns.
7+ */
8+ public class DenseTensorData <T : DType , V >(
9+ override val shape : Shape ,
10+ private val data : Array <V >,
11+ override val strides : IntArray = shape.computeStrides(),
12+ override val offset : Int = 0
13+ ) : TensorData<T, V> {
14+
15+ override val isContiguous: Boolean = true
16+
17+ override operator fun get (vararg indices : Int ): V {
18+ require(indices.size == shape.dimensions.size) {
19+ " Number of indices (${indices.size} ) must match tensor dimensions (${shape.dimensions.size} )"
20+ }
21+
22+ var flatIndex = offset
23+ for (i in indices.indices) {
24+ require(indices[i] >= 0 && indices[i] < shape.dimensions[i]) {
25+ " Index ${indices[i]} out of bounds for dimension $i with size ${shape.dimensions[i]} "
26+ }
27+ flatIndex + = indices[i] * strides[i]
28+ }
29+
30+ return data[flatIndex]
31+ }
32+
33+ override fun copyTo (dest : Array <V >, destOffset : Int ) {
34+ if (isContiguous && offset == 0 ) {
35+ // Fast path for contiguous data
36+ data.copyInto(dest, destOffset, 0 , shape.volume)
37+ } else {
38+ // Stride-based copy for non-contiguous or offset data
39+ var destIndex = destOffset
40+ iterateAll { flatIndex ->
41+ dest[destIndex++ ] = data[flatIndex]
42+ }
43+ }
44+ }
45+
46+ override fun slice (ranges : IntArray ): TensorData <T , V > {
47+ require(ranges.size == shape.dimensions.size * 2 ) {
48+ " Ranges array must contain start,end pairs for each dimension. Expected ${shape.dimensions.size * 2 } , got ${ranges.size} "
49+ }
50+
51+ // Calculate new shape and strides for the slice
52+ val newDimensions = mutableListOf<Int >()
53+ val newStrides = mutableListOf<Int >()
54+ var newOffset = offset
55+
56+ for (i in shape.dimensions.indices) {
57+ val start = ranges[i * 2 ]
58+ val end = ranges[i * 2 + 1 ]
59+
60+ require(start >= 0 && start < shape.dimensions[i] && end > start && end <= shape.dimensions[i]) {
61+ " Invalid range [$start , $end ) for dimension $i with size ${shape.dimensions[i]} "
62+ }
63+
64+ newDimensions.add(end - start)
65+ newStrides.add(strides[i])
66+ newOffset + = start * strides[i]
67+ }
68+
69+ val newShape = Shape (newDimensions.toIntArray())
70+ return ViewTensorData (data, newShape, newStrides.toIntArray(), newOffset, shape)
71+ }
72+
73+ override fun materialize (): TensorData <T , V > {
74+ // Already materialized (contiguous)
75+ return this
76+ }
77+
78+ /* *
79+ * Iterates over all elements in the tensor using stride-based indexing.
80+ */
81+ private fun iterateAll (action : (flatIndex: Int ) -> Unit ) {
82+ val indices = IntArray (shape.dimensions.size)
83+ iterateRecursive(0 , offset, indices, action)
84+ }
85+
86+ private fun iterateRecursive (dim : Int , currentOffset : Int , indices : IntArray , action : (flatIndex: Int ) -> Unit ) {
87+ if (dim == shape.dimensions.size) {
88+ action(currentOffset)
89+ return
90+ }
91+
92+ for (i in 0 until shape.dimensions[dim]) {
93+ indices[dim] = i
94+ iterateRecursive(dim + 1 , currentOffset + i * strides[dim], indices, action)
95+ }
96+ }
97+ }
98+
99+ /* *
100+ * Computes standard row-major strides for the given shape.
101+ */
102+ private fun Shape.computeStrides (): IntArray {
103+ if (dimensions.isEmpty()) return intArrayOf()
104+
105+ val strides = IntArray (dimensions.size)
106+ strides[dimensions.size - 1 ] = 1
107+
108+ for (i in dimensions.size - 2 downTo 0 ) {
109+ strides[i] = strides[i + 1 ] * dimensions[i + 1 ]
110+ }
111+
112+ return strides
113+ }
0 commit comments