Skip to content

Commit b1f8b15

Browse files
Merge pull request #552 from SKaiNET-developers/feature/ISSUE-551-permute-axes
feat(ops): add TensorOps.permute(axes) for arbitrary-axis permutation
2 parents 380e7c5 + bc28884 commit b1f8b15

7 files changed

Lines changed: 260 additions & 0 deletions

File tree

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,74 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
484484
return newTensor(outData, tensor.dtype, tensor)
485485
}
486486

487+
/**
 * Returns a tensor whose i-th axis is the `axes[i]`-th axis of [tensor].
 *
 * The identity permutation returns [tensor] itself without copying. When the
 * backing data is a [FloatArrayTensorData], elements are copied with
 * flat-offset arithmetic on its buffer; any other backing store is read
 * element by element through `get`.
 *
 * @param tensor the tensor to permute
 * @param axes a permutation of `0 until rank`
 * @return [tensor] for the identity permutation, otherwise a new tensor with
 *   the permuted shape
 * @throws IllegalArgumentException if [axes] has the wrong length, contains an
 *   out-of-range value, or repeats a value
 */
@TensorOp()
override fun <T : DType, V> permute(tensor: Tensor<T, V>, axes: IntArray): Tensor<T, V> {
    val rank = tensor.shape.rank
    require(axes.size == rank) {
        "permute: axes length ${axes.size} must match tensor rank $rank"
    }
    val seen = BooleanArray(rank)
    for (a in axes) {
        require(a in 0 until rank) { "permute: axis $a out of range [0, $rank)" }
        require(!seen[a]) { "permute: axis $a appears more than once in ${axes.toList()}" }
        seen[a] = true
    }

    // Identity permute — no copy. Checked before any shape or stride work;
    // this also covers rank 0, so everything below can rely on rank >= 1.
    if (axes.indices.all { axes[it] == it }) return tensor

    val inDims = tensor.shape.dimensions
    val outDims = IntArray(rank) { i -> inDims[axes[i]] }
    val outShape = Shape(outDims)

    // Row-major strides: strides[k] is the flat distance between consecutive
    // indices on axis k. Only called with rank >= 1 (see identity check above).
    fun rowMajorStrides(dims: IntArray): IntArray {
        val strides = IntArray(dims.size)
        strides[dims.size - 1] = 1
        for (i in dims.size - 2 downTo 0) strides[i] = strides[i + 1] * dims[i + 1]
        return strides
    }

    val data = tensor.data

    // Fast path: source is backed by a FloatArray. Walk the output linearly,
    // decompose each flat index into its multi-index and accumulate the
    // matching source offset in the same pass.
    // NOTE(review): assumes `buffer` stores the elements in row-major order
    // starting at offset 0 — confirm this holds for every FloatArrayTensorData.
    if (data is FloatArrayTensorData<*>) {
        val srcBuf = data.buffer
        val inStrides = rowMajorStrides(inDims)
        val outStrides = rowMajorStrides(outDims)
        // Source-buffer step for one step along each OUTPUT axis; hoisted so
        // the copy loop does not re-index through `axes` for every element.
        val srcStep = IntArray(rank) { i -> inStrides[axes[i]] }
        val total = outShape.volume
        val out = FloatArray(total)
        for (flatOut in 0 until total) {
            var rem = flatOut
            var flatIn = 0
            for (i in 0 until rank) {
                val s = outStrides[i]
                val idx = rem / s
                rem -= idx * s
                flatIn += idx * srcStep[i]
            }
            out[flatOut] = srcBuf[flatIn]
        }
        @Suppress("UNCHECKED_CAST")
        val outData = dataFactory.fromFloatArray<T, Float>(outShape, tensor.dtype, out)
            as sk.ainet.lang.tensor.data.TensorData<T, V>
        return newTensor(outData, tensor.dtype, tensor)
    }

    // Generic fallback: defer to dataFactory.init with element access.
    // A fresh index array is built per element because the initializer's
    // calling pattern is not visible from here.
    val outData = dataFactory.init<T, V>(outShape, tensor.dtype) { outIdx ->
        val inIdx = IntArray(rank)
        for (i in 0 until rank) inIdx[axes[i]] = outIdx[i]
        data.get(*inIdx)
    }
    return newTensor(outData, tensor.dtype, tensor)
}
554+
487555
@TensorOp()
488556
@InProgress("cpu", owner = "team:cpu", issue = "task-ops.md#op-conv2d")
489557
override fun <T : DType, V> conv2d(
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import kotlin.test.Test
4+
import kotlin.test.assertContentEquals
5+
import kotlin.test.assertEquals
6+
import kotlin.test.assertFailsWith
7+
import kotlin.test.assertSame
8+
import sk.ainet.context.DirectCpuExecutionContext
9+
import sk.ainet.lang.tensor.Shape
10+
import sk.ainet.lang.types.FP32
11+
12+
/**
 * Behavioural tests for `permute(tensor, axes)` as reached through the `ops`
 * of a [DirectCpuExecutionContext].
 *
 * Covers the identity shortcut, element placement for rank 2–4 inputs,
 * round-tripping through the inverse permutation, agreement with `transpose`
 * on rank 2, and rejection of malformed `axes` arguments.
 */
class PermuteTest {

    private fun ctx() = DirectCpuExecutionContext()

    @Test
    fun identityPermuteReturnsSameTensor() {
        val ctx = ctx()
        val t = ctx.fromFloatArray<FP32, Float>(
            Shape(2, 3, 4), FP32::class,
            FloatArray(24) { it.toFloat() }
        )
        val out = ctx.ops.permute(t, intArrayOf(0, 1, 2))
        assertSame(t, out, "identity permute should return the input tensor")
    }

    @Test
    fun swapDim0AndDim1OnRank3() {
        val ctx = ctx()
        // Shape [A=2, B=3, C=4], elements 0..23 row-major.
        // Element (a, b, c) flat = a*12 + b*4 + c.
        val src = FloatArray(24) { it.toFloat() }
        val t = ctx.fromFloatArray<FP32, Float>(Shape(2, 3, 4), FP32::class, src)
        val out = ctx.ops.permute(t, intArrayOf(1, 0, 2))
        assertContentEquals(intArrayOf(3, 2, 4), out.shape.dimensions, "expected shape [B=3, A=2, C=4]")
        // out(b, a, c) == in(a, b, c)
        for (b in 0 until 3) {
            for (a in 0 until 2) {
                for (c in 0 until 4) {
                    val expected = (a * 12 + b * 4 + c).toFloat()
                    val actual = out.data.get(b, a, c)
                    assertEquals(expected, actual, "out[$b,$a,$c] vs in[$a,$b,$c]")
                }
            }
        }
    }

    @Test
    fun reverseAxesOnRank4() {
        val ctx = ctx()
        // Shape [2, 3, 4, 5]. Permute (3, 2, 1, 0) → reverses all axes.
        val src = FloatArray(2 * 3 * 4 * 5) { it.toFloat() }
        val t = ctx.fromFloatArray<FP32, Float>(Shape(2, 3, 4, 5), FP32::class, src)
        val out = ctx.ops.permute(t, intArrayOf(3, 2, 1, 0))
        assertContentEquals(intArrayOf(5, 4, 3, 2), out.shape.dimensions)
        for (d in 0 until 5) {
            for (c in 0 until 4) {
                for (b in 0 until 3) {
                    for (a in 0 until 2) {
                        val flatIn = a * 60 + b * 20 + c * 5 + d
                        assertEquals(
                            flatIn.toFloat(),
                            out.data.get(d, c, b, a),
                            "out[$d,$c,$b,$a] vs in[$a,$b,$c,$d]"
                        )
                    }
                }
            }
        }
    }

    @Test
    fun roundTripPermuteIsIdentity() {
        val ctx = ctx()
        val src = FloatArray(2 * 3 * 4) { it.toFloat() }
        val t = ctx.fromFloatArray<FP32, Float>(Shape(2, 3, 4), FP32::class, src)
        val axes = intArrayOf(2, 0, 1)
        // inverse[axes[i]] = i undoes the permutation.
        val inverse = IntArray(3).also { for (i in axes.indices) it[axes[i]] = i }

        val once = ctx.ops.permute(t, axes)
        val back = ctx.ops.permute(once, inverse)

        assertContentEquals(t.shape.dimensions, back.shape.dimensions)
        for (a in 0 until 2) for (b in 0 until 3) for (c in 0 until 4) {
            assertEquals(t.data.get(a, b, c), back.data.get(a, b, c), "round-trip mismatch at [$a,$b,$c]")
        }
    }

    @Test
    fun permuteEquivalentToTransposeOnRank2() {
        val ctx = ctx()
        val t = ctx.fromFloatArray<FP32, Float>(
            Shape(3, 5), FP32::class,
            FloatArray(15) { it.toFloat() }
        )
        val viaPermute = ctx.ops.permute(t, intArrayOf(1, 0))
        val viaTranspose = ctx.ops.transpose(t)
        assertContentEquals(viaTranspose.shape.dimensions, viaPermute.shape.dimensions)
        for (i in 0 until 5) for (j in 0 until 3) {
            assertEquals(viaTranspose.data.get(i, j), viaPermute.data.get(i, j))
        }
    }

    @Test
    fun rejectsWrongAxesLength() {
        val ctx = ctx()
        val t = ctx.fromFloatArray<FP32, Float>(Shape(2, 3), FP32::class, FloatArray(6))
        assertFailsWith<IllegalArgumentException> { ctx.ops.permute(t, intArrayOf(1, 0, 2)) }
    }

    @Test
    fun rejectsTooFewAxes() {
        val ctx = ctx()
        val t = ctx.fromFloatArray<FP32, Float>(Shape(2, 3), FP32::class, FloatArray(6))
        assertFailsWith<IllegalArgumentException> { ctx.ops.permute(t, intArrayOf(0)) }
    }

    @Test
    fun rejectsOutOfRangeAxis() {
        val ctx = ctx()
        val t = ctx.fromFloatArray<FP32, Float>(Shape(2, 3), FP32::class, FloatArray(6))
        assertFailsWith<IllegalArgumentException> { ctx.ops.permute(t, intArrayOf(0, 5)) }
    }

    @Test
    fun rejectsNegativeAxis() {
        val ctx = ctx()
        val t = ctx.fromFloatArray<FP32, Float>(Shape(2, 3), FP32::class, FloatArray(6))
        assertFailsWith<IllegalArgumentException> { ctx.ops.permute(t, intArrayOf(0, -1)) }
    }

    @Test
    fun rejectsDuplicateAxis() {
        val ctx = ctx()
        val t = ctx.fromFloatArray<FP32, Float>(Shape(2, 3), FP32::class, FloatArray(6))
        assertFailsWith<IllegalArgumentException> { ctx.ops.permute(t, intArrayOf(0, 0)) }
    }
}

skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,13 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
234234
return out
235235
}
236236

237+
/**
 * Forwards to the wrapped ops and returns its result unchanged; this
 * override performs no recording step of its own.
 *
 * A dedicated PermuteOperation can be introduced later if the tape consumer
 * needs to distinguish permute from raw passthrough.
 *
 * NOTE(review): `axes` is not captured anywhere on this path. If a backward
 * rule is expected to read the permutation from recorded attributes, it
 * will not find it via this override — confirm against the tape consumer.
 */
override fun <T : DType, V> permute(tensor: Tensor<T, V>, axes: IntArray): Tensor<T, V> =
    base.permute(tensor, axes)
243+
237244
// --- Conv/Pool ---
238245
override fun <T : DType, V> conv1d(
239246
input: Tensor<T, V>,

skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,16 @@ public class DefaultGradientTape(
605605
// Backward rule for transpose: the single returned gradient is the
// transpose of the upstream gradient.
override fun transposeBackward(
    upstream: Tensor<DType, Any>,
    output: Tensor<DType, Any>,
    inputs: List<Tensor<DType, Any>>,
    attributes: Map<String, Any?>
): List<Tensor<DType, Any>?> {
    val inputGradient = upstream.ops.transpose(upstream)
    return listOf(inputGradient)
}
607607

608+
/**
 * Backward rule for permute: routes [upstream] back through the inverse
 * permutation.
 *
 * If the forward op computed `permute(t, axes)`, the gradient with respect to
 * `t` is `permute(upstream, inverse)` where `inverse[axes[i]] = i`.
 *
 * @param upstream gradient flowing in from the op's output
 * @param attributes must carry the forward permutation under the key
 *   `"axes"`, either as an [IntArray] or as a list of integral numbers
 * @return a single-element list holding the gradient for the permuted input
 * @throws IllegalStateException if `"axes"` is absent, has an unsupported
 *   type, or is not a permutation of `0 until axes.size`
 */
override fun permuteBackward(
    upstream: Tensor<DType, Any>,
    output: Tensor<DType, Any>,
    inputs: List<Tensor<DType, Any>>,
    attributes: Map<String, Any?>
): List<Tensor<DType, Any>?> {
    // Tell "absent" apart from "present but of an unexpected type" so the
    // failure message points at the actual problem.
    val axes: IntArray = when (val raw = attributes["axes"]) {
        null -> error("permuteBackward: missing 'axes' attribute")
        is IntArray -> raw
        is List<*> -> IntArray(raw.size) { i ->
            (raw[i] as? Number)?.toInt()
                ?: error("permuteBackward: 'axes' attribute must contain integers, got ${raw[i]}")
        }
        else -> error("permuteBackward: 'axes' attribute has unsupported type ${raw::class.simpleName}")
    }

    // Build inverse[axes[i]] = i. The -1 sentinel exposes duplicates, which
    // would otherwise overwrite a slot and yield a wrong (but valid-looking)
    // inverse permutation.
    val inverse = IntArray(axes.size) { -1 }
    for (i in axes.indices) {
        val axis = axes[i]
        check(axis in axes.indices && inverse[axis] == -1) {
            "permuteBackward: 'axes' ${axes.toList()} is not a permutation of 0 until ${axes.size}"
        }
        inverse[axis] = i
    }
    return listOf(upstream.ops.permute(upstream, inverse))
}
617+
608618
// Backward rule for relu: delegates to reluGrad with the upstream gradient,
// the op's first input and its output, and wraps the result in a list.
override fun reluBackward(
    upstream: Tensor<DType, Any>,
    output: Tensor<DType, Any>,
    inputs: List<Tensor<DType, Any>>,
    attributes: Map<String, Any?>
): List<Tensor<DType, Any>?> {
    val forwardInput = inputs[0]
    return listOf(reluGrad(upstream, forwardInput, output))
}
610620

skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ private class TestTensorOps : TensorOps {
164164
override fun <T : DType, V> rdivScalar(a: Number, b: Tensor<T, V>): Tensor<T, V> = b
165165
override fun <T : DType, V> matmul(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> = a
166166
override fun <T : DType, V> transpose(tensor: Tensor<T, V>): Tensor<T, V> = tensor
167+
// Test stub: returns the input tensor unchanged; `axes` is neither validated nor applied.
override fun <T : DType, V> permute(tensor: Tensor<T, V>, axes: IntArray): Tensor<T, V> = tensor
167168
override fun <T : DType, V> relu(tensor: Tensor<T, V>): Tensor<T, V> = tensor
168169
override fun <T : DType, V> leakyRelu(tensor: Tensor<T, V>, negativeSlope: Float): Tensor<T, V> = tensor
169170
override fun <T : DType, V> elu(tensor: Tensor<T, V>, alpha: Float): Tensor<T, V> = tensor

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,24 @@ public interface TensorOps {
5353
@Diff
5454
public fun <T : DType, V> transpose(tensor: Tensor<T, V>): Tensor<T, V>
5555

56+
/**
 * Permute the dimensions of [tensor] according to [axes].
 *
 * `axes` is a permutation of `0..tensor.rank-1`; the i-th axis of the
 * result is the `axes[i]`-th axis of the input. On a rank-3 tensor of
 * shape `[A, B, C]`, `permute(t, intArrayOf(1, 0, 2))` returns shape
 * `[B, A, C]`.
 *
 * Element-wise, the value at output index `(o_0, ..., o_{rank-1})` is the
 * input value at the index whose `axes[i]`-th coordinate is `o_i`.
 *
 * `permute(t, intArrayOf(0, 1, ..., rank-3, rank-1, rank-2))` is
 * equivalent to [transpose].
 *
 * NOTE(review): this interface does not itself enforce the constraints on
 * [axes]; whether and how an invalid argument is rejected depends on the
 * implementation — confirm the intended contract (e.g. always
 * `IllegalArgumentException`) and document it here.
 *
 * @param tensor input tensor, any rank ≥ 1
 * @param axes a permutation of `0..tensor.rank-1` (length must equal
 *   `tensor.rank`, every value in `[0, rank)` exactly once)
 * @return a tensor whose dimension `i` has the size of dimension `axes[i]`
 *   of [tensor]
 */
@Diff
public fun <T : DType, V> permute(tensor: Tensor<T, V>, axes: IntArray): Tensor<T, V>
73+
5674
// Convolutional operations
5775
@Diff
5876
public fun <T : DType, V> conv1d(

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ public class VoidTensorOps : TensorOps {
151151
return VoidOpsTensor(resultData, tensor.dtype)
152152
}
153153

154+
/**
 * Shape-only permute: checks [axes] against the input shape, then returns a
 * [VoidOpsTensor] of the permuted shape backed by `dataFactory.zeros`.
 * No element of [tensor] is read.
 */
override fun <T : DType, V> permute(tensor: Tensor<T, V>, axes: IntArray): Tensor<T, V> {
    val sourceShape = tensor.shape
    validatePermuteAxes(sourceShape, axes)
    val permutedShape = calculatePermuteShape(sourceShape, axes)
    return VoidOpsTensor(dataFactory.zeros<T, V>(permutedShape, tensor.dtype), tensor.dtype)
}
160+
154161
override fun <T : DType, V> conv1d(
155162
input: Tensor<T, V>,
156163
weight: Tensor<T, V>,
@@ -598,6 +605,31 @@ public class VoidTensorOps : TensorOps {
598605
* For 2D tensors: (m, n) -> (n, m)
599606
* For higher dimensions: swaps the last two dimensions
600607
*/
608+
// NOTE(review): this helper sits directly after a KDoc block describing
// transpose shapes, separating that KDoc from calculateTransposeShape below.
// Consider relocating the permute helpers so that KDoc stays attached.
/**
 * Validate that [axes] is a valid permutation of `0..shape.rank-1`.
 *
 * @param shape shape of the tensor being permuted; only its rank is used
 * @param axes candidate permutation
 * @throws IllegalArgumentException if [axes] does not have exactly
 *   `shape.rank` entries, contains a value outside `[0, shape.rank)`, or
 *   contains the same value twice
 */
internal fun validatePermuteAxes(shape: Shape, axes: IntArray) {
    val rank = shape.rank
    require(axes.size == rank) {
        "permute: axes length ${axes.size} must match tensor rank $rank"
    }
    val seen = BooleanArray(rank)
    for (a in axes) {
        require(a in 0 until rank) {
            "permute: axis $a out of range [0, $rank)"
        }
        // Interpolating an IntArray directly prints its identity string
        // (e.g. "[I@1b2c3d"), so render the contents explicitly.
        require(!seen[a]) { "permute: axis $a appears more than once in ${axes.toList()}" }
        seen[a] = true
    }
}
624+
625+
/**
 * Computes the shape obtained by permuting [shape] with [axes]: result
 * dimension `i` is `shape.dimensions[axes[i]]`.
 *
 * Performs no validation of [axes].
 */
internal fun calculatePermuteShape(shape: Shape, axes: IntArray): Shape {
    val sourceDims = shape.dimensions
    val permuted = IntArray(shape.rank)
    for (position in permuted.indices) {
        permuted[position] = sourceDims[axes[position]]
    }
    return Shape(permuted)
}
632+
601633
private fun calculateTransposeShape(shape: Shape): Shape {
602634
if (shape.rank < 2) {
603635
throw IllegalArgumentException("Transpose requires tensors with at least 2 dimensions")

0 commit comments

Comments
 (0)