Add TensorOps.permute(axes) for arbitrary-axis permutation #551

@michalharakal

Context

TensorOps.transpose only swaps the last two dimensions. There's no API today for arbitrary-axis permutation (e.g. swapping dims 0 and 1 of a rank-3 tensor).

This came up concretely in SKaiNET-transformers PR #81: MultiHeadAttention needs to convert [seqLen, nHeads, headDim] → [nHeads, seqLen, headDim] after the Q/K/V projection. With no permute available, that PR landed a manual copy-based helper (swapSeqHeadDims) that allocates a fresh FloatArray per call. On a 1024-token prefill of TinyLlama 1.1B, that's ~13M floats per forward pass (~52 MB) of transient allocation just to fix the shape.

A proper permute(axes) is generic, single-allocation per call, and likely stride-only on backends that support strided tensors.
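For reference, the copy-based variant can be sketched roughly as follows. This is a standalone illustration operating on a flat row-major FloatArray plus a shape array, not on SKaiNET's actual Tensor type; all names here are illustrative, and a real DefaultCpuOps implementation would plug into the existing tensor/backend machinery:

```kotlin
// Hypothetical sketch: copy-based permute over a flat row-major buffer.
// Result axis i is input axis axes[i], matching the proposed KDoc semantics.
fun permute(data: FloatArray, shape: IntArray, axes: IntArray): Pair<FloatArray, IntArray> {
    require(axes.size == shape.size && axes.sorted() == shape.indices.toList()) {
        "axes must be a permutation of 0..${shape.size - 1}"
    }
    val outShape = IntArray(shape.size) { shape[axes[it]] }
    // Row-major strides of the input.
    val inStrides = IntArray(shape.size)
    var s = 1
    for (i in shape.indices.reversed()) { inStrides[i] = s; s *= shape[i] }
    val out = FloatArray(s)            // single allocation per call
    val idx = IntArray(outShape.size)  // odometer-style multi-index over the output
    for (flat in out.indices) {
        // Map the output multi-index back to an input flat offset via permuted strides.
        var src = 0
        for (d in idx.indices) src += idx[d] * inStrides[axes[d]]
        out[flat] = data[src]
        // Increment idx like an odometer bounded by outShape.
        for (d in idx.indices.reversed()) {
            if (++idx[d] < outShape[d]) break else idx[d] = 0
        }
    }
    return out to outShape
}
```

A stride-only backend would skip the copy loop entirely and just permute the stride and shape arrays, which is why the follow-up below is worth keeping in mind.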

Proposal

Add to sk.ainet.lang.tensor.ops.TensorOps:

/**
 * Permute the dimensions of [tensor] according to [axes].
 *
 * `axes` is a permutation of `0..tensor.rank-1`; the i-th axis of the
 * result is the `axes[i]`-th axis of the input. For example, on a rank-3
 * tensor of shape `[A, B, C]`, `permute(t, intArrayOf(1, 0, 2))` returns
 * shape `[B, A, C]`.
 *
 * `permute(t, intArrayOf(0, 1, ..., rank-3, rank-1, rank-2))` is
 * equivalent to [transpose].
 *
 * @param tensor input tensor, any rank ≥ 1
 * @param axes a permutation of `0..tensor.rank-1` (length must equal
 *   `tensor.rank`, every value in `[0, rank)` exactly once)
 */
@Diff
public fun <T : DType, V> permute(tensor: Tensor<T, V>, axes: IntArray): Tensor<T, V>

Implementation scope

  • commonMain interface declaration + default-shape derivation.
  • VoidTensorOps shape-only stub (returns zeros at the permuted shape).
  • DefaultCpuOps JVM/native data-moving implementation.
  • Unit tests: identity permute, dim-0/dim-1 swap (rank 3), full shuffle (rank 4), round-trip equivalence: permute(permute(t, p), inverse(p)) == t.
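The inverse(p) used in the round-trip test is not part of the proposed API; a minimal sketch of such a helper (name and placement are assumptions for the test code only):

```kotlin
// Hypothetical test helper: the inverse permutation of axes.
// If axes maps input axis axes[i] to result axis i, then inv maps it back,
// so permute(permute(t, axes), inverse(axes)) restores the original layout.
fun inverse(axes: IntArray): IntArray {
    val inv = IntArray(axes.size)
    for (i in axes.indices) inv[axes[i]] = i
    return inv
}
```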

Out of scope

  • Strided/lazy permute (no data copy). Worth a follow-up once a strided tensor representation lands; for now, copy is fine and matches existing op semantics.
  • AD support (the @Diff annotation is included for symmetry with transpose, but the gradient impl can land in a follow-up if needed).

Acceptance

  • API compiles on all KMP targets.
  • Unit tests pass.
  • Downstream MultiHeadAttention.swapSeqHeadDims can be deleted and replaced with ctx.ops.permute(t, intArrayOf(1, 0, 2)).
