Skip to content

Commit 92a9b24

Browse files
committed
Add missing file to repo
Relate-To: #94
1 parent 08f89ab commit 92a9b24

1 file changed

Lines changed: 113 additions & 0 deletions

File tree

  • skainet-core/skainet-tensors-api/src/commonMain/kotlin/sk/ainet/core/tensor
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package sk.ainet.core.tensor
2+
3+
/**
4+
* Dense tensor data implementation for contiguous memory layout.
5+
* This is the most common tensor data representation where all elements
6+
* are stored in a contiguous array with standard stride patterns.
7+
*/
8+
public class DenseTensorData<T : DType, V>(
9+
override val shape: Shape,
10+
private val data: Array<V>,
11+
override val strides: IntArray = shape.computeStrides(),
12+
override val offset: Int = 0
13+
) : TensorData<T, V> {
14+
15+
override val isContiguous: Boolean = true
16+
17+
override operator fun get(vararg indices: Int): V {
18+
require(indices.size == shape.dimensions.size) {
19+
"Number of indices (${indices.size}) must match tensor dimensions (${shape.dimensions.size})"
20+
}
21+
22+
var flatIndex = offset
23+
for (i in indices.indices) {
24+
require(indices[i] >= 0 && indices[i] < shape.dimensions[i]) {
25+
"Index ${indices[i]} out of bounds for dimension $i with size ${shape.dimensions[i]}"
26+
}
27+
flatIndex += indices[i] * strides[i]
28+
}
29+
30+
return data[flatIndex]
31+
}
32+
33+
override fun copyTo(dest: Array<V>, destOffset: Int) {
34+
if (isContiguous && offset == 0) {
35+
// Fast path for contiguous data
36+
data.copyInto(dest, destOffset, 0, shape.volume)
37+
} else {
38+
// Stride-based copy for non-contiguous or offset data
39+
var destIndex = destOffset
40+
iterateAll { flatIndex ->
41+
dest[destIndex++] = data[flatIndex]
42+
}
43+
}
44+
}
45+
46+
override fun slice(ranges: IntArray): TensorData<T, V> {
47+
require(ranges.size == shape.dimensions.size * 2) {
48+
"Ranges array must contain start,end pairs for each dimension. Expected ${shape.dimensions.size * 2}, got ${ranges.size}"
49+
}
50+
51+
// Calculate new shape and strides for the slice
52+
val newDimensions = mutableListOf<Int>()
53+
val newStrides = mutableListOf<Int>()
54+
var newOffset = offset
55+
56+
for (i in shape.dimensions.indices) {
57+
val start = ranges[i * 2]
58+
val end = ranges[i * 2 + 1]
59+
60+
require(start >= 0 && start < shape.dimensions[i] && end > start && end <= shape.dimensions[i]) {
61+
"Invalid range [$start, $end) for dimension $i with size ${shape.dimensions[i]}"
62+
}
63+
64+
newDimensions.add(end - start)
65+
newStrides.add(strides[i])
66+
newOffset += start * strides[i]
67+
}
68+
69+
val newShape = Shape(newDimensions.toIntArray())
70+
return ViewTensorData(data, newShape, newStrides.toIntArray(), newOffset, shape)
71+
}
72+
73+
override fun materialize(): TensorData<T, V> {
74+
// Already materialized (contiguous)
75+
return this
76+
}
77+
78+
/**
79+
* Iterates over all elements in the tensor using stride-based indexing.
80+
*/
81+
private fun iterateAll(action: (flatIndex: Int) -> Unit) {
82+
val indices = IntArray(shape.dimensions.size)
83+
iterateRecursive(0, offset, indices, action)
84+
}
85+
86+
private fun iterateRecursive(dim: Int, currentOffset: Int, indices: IntArray, action: (flatIndex: Int) -> Unit) {
87+
if (dim == shape.dimensions.size) {
88+
action(currentOffset)
89+
return
90+
}
91+
92+
for (i in 0 until shape.dimensions[dim]) {
93+
indices[dim] = i
94+
iterateRecursive(dim + 1, currentOffset + i * strides[dim], indices, action)
95+
}
96+
}
97+
}
98+
99+
/**
100+
* Computes standard row-major strides for the given shape.
101+
*/
102+
private fun Shape.computeStrides(): IntArray {
103+
if (dimensions.isEmpty()) return intArrayOf()
104+
105+
val strides = IntArray(dimensions.size)
106+
strides[dimensions.size - 1] = 1
107+
108+
for (i in dimensions.size - 2 downTo 0) {
109+
strides[i] = strides[i + 1] * dimensions[i + 1]
110+
}
111+
112+
return strides
113+
}

0 commit comments

Comments
 (0)