Commit 6fe526a

Implement missing ops for the TinyFoA (AAAI 2025) training pipeline for memory-efficient on-device learning in SKaiNET.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implements #359
Related-To #358

1 parent b676d15 commit 6fe526a
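
The ops added here (abs, sign, clamp, lt, ge, narrow, pad2d, unfold) are generic tensor primitives rather than anything TinyFoA-specific. As a rough, hypothetical sketch of why a memory-efficient training scheme wants abs and sign in particular — compressing a tensor to sign(x) times a per-tensor scale mean(|x|) is a standard 1-bit trick — the following stands alone on plain FloatArray and deliberately avoids the SKaiNET API:

import kotlin.math.abs

// Hypothetical illustration, not SKaiNET code: store only the signs of x
// plus one Float scale, approximating x with ~1 bit per element.
fun signOf(v: Float): Float = if (v > 0f) 1f else if (v < 0f) -1f else 0f

fun binarize(x: FloatArray): Pair<FloatArray, Float> {
    var sumAbs = 0f
    for (v in x) sumAbs += abs(v)                        // same semantics as the new abs op
    val scale = sumAbs / x.size                          // per-tensor scale = mean(|x|)
    val signs = FloatArray(x.size) { i -> signOf(x[i]) } // same semantics as the new sign op
    return signs to scale
}

fun main() {
    val (signs, scale) = binarize(floatArrayOf(-2f, 0f, 4f))
    println(signs.toList()) // [-1.0, 0.0, 1.0]
    println(scale)          // 2.0
}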

7 files changed

Lines changed: 437 additions & 2 deletions

gradle.properties

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 GROUP=sk.ainet.core
-VERSION_NAME=0.10.1
+VERSION_NAME=0.11.0

 POM_DESCRIPTION=SKaiNET

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt

Lines changed: 188 additions & 0 deletions
@@ -1960,6 +1960,194 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
         return newTensor(outData, tensor.dtype, tensor)
     }

+    // ---- TinyFoA ops: abs, sign, clamp, lt, ge ----
+
+    @TensorOp()
+    @InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-abs")
+    override fun <T : DType, V> abs(tensor: Tensor<T, V>): Tensor<T, V> {
+        val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
+            when (tensor.dtype) {
+                sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
+                    val v = tensor.data.get(*idx) as Float
+                    @Suppress("UNCHECKED_CAST")
+                    kotlin.math.abs(v) as V
+                }
+                sk.ainet.lang.types.Int32::class -> {
+                    val v = tensor.data.get(*idx) as Int
+                    @Suppress("UNCHECKED_CAST")
+                    kotlin.math.abs(v) as V
+                }
+                else -> throw IllegalArgumentException("Unsupported dtype for abs: ${tensor.dtype}")
+            }
+        }
+        return newTensor(outData, tensor.dtype, tensor)
+    }
+
+    @TensorOp()
+    @InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-sign")
+    override fun <T : DType, V> sign(tensor: Tensor<T, V>): Tensor<T, V> {
+        val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
+            when (tensor.dtype) {
+                sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
+                    val v = tensor.data.get(*idx) as Float
+                    @Suppress("UNCHECKED_CAST")
+                    (if (v > 0f) 1f else if (v < 0f) -1f else 0f) as V
+                }
+                sk.ainet.lang.types.Int32::class -> {
+                    val v = tensor.data.get(*idx) as Int
+                    @Suppress("UNCHECKED_CAST")
+                    (if (v > 0) 1 else if (v < 0) -1 else 0) as V
+                }
+                else -> throw IllegalArgumentException("Unsupported dtype for sign: ${tensor.dtype}")
+            }
+        }
+        return newTensor(outData, tensor.dtype, tensor)
+    }
+
+    @TensorOp()
+    @InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-clamp")
+    override fun <T : DType, V> clamp(tensor: Tensor<T, V>, minVal: Float, maxVal: Float): Tensor<T, V> {
+        require(minVal <= maxVal) { "clamp: minVal ($minVal) must be <= maxVal ($maxVal)" }
+        val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
+            when (tensor.dtype) {
+                sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
+                    val v = tensor.data.get(*idx) as Float
+                    @Suppress("UNCHECKED_CAST")
+                    v.coerceIn(minVal, maxVal) as V
+                }
+                sk.ainet.lang.types.Int32::class -> {
+                    val v = tensor.data.get(*idx) as Int
+                    @Suppress("UNCHECKED_CAST")
+                    v.coerceIn(minVal.toInt(), maxVal.toInt()) as V
+                }
+                else -> throw IllegalArgumentException("Unsupported dtype for clamp: ${tensor.dtype}")
+            }
+        }
+        return newTensor(outData, tensor.dtype, tensor)
+    }
+
+    @TensorOp()
+    @InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-lt")
+    override fun <T : DType, V> lt(tensor: Tensor<T, V>, value: Float): Tensor<T, V> {
+        val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
+            when (tensor.dtype) {
+                sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
+                    val v = tensor.data.get(*idx) as Float
+                    @Suppress("UNCHECKED_CAST")
+                    (if (v < value) 1f else 0f) as V
+                }
+                sk.ainet.lang.types.Int32::class -> {
+                    val v = tensor.data.get(*idx) as Int
+                    @Suppress("UNCHECKED_CAST")
+                    (if (v < value.toInt()) 1 else 0) as V
+                }
+                else -> throw IllegalArgumentException("Unsupported dtype for lt: ${tensor.dtype}")
+            }
+        }
+        return newTensor(outData, tensor.dtype, tensor)
+    }
+
+    @TensorOp()
+    @InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-ge")
+    override fun <T : DType, V> ge(tensor: Tensor<T, V>, value: Float): Tensor<T, V> {
+        val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
+            when (tensor.dtype) {
+                sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
+                    val v = tensor.data.get(*idx) as Float
+                    @Suppress("UNCHECKED_CAST")
+                    (if (v >= value) 1f else 0f) as V
+                }
+                sk.ainet.lang.types.Int32::class -> {
+                    val v = tensor.data.get(*idx) as Int
+                    @Suppress("UNCHECKED_CAST")
+                    (if (v >= value.toInt()) 1 else 0) as V
+                }
+                else -> throw IllegalArgumentException("Unsupported dtype for ge: ${tensor.dtype}")
+            }
+        }
+        return newTensor(outData, tensor.dtype, tensor)
+    }
+
+    // ---- narrow, pad2d, unfold ----
+
+    @TensorOp()
+    @InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-narrow")
+    override fun <T : DType, V> narrow(tensor: Tensor<T, V>, dim: Int, start: Int, length: Int): Tensor<T, V> {
+        val actualDim = if (dim < 0) tensor.shape.rank + dim else dim
+        require(actualDim in 0 until tensor.shape.rank) { "narrow dim $dim out of bounds for rank ${tensor.shape.rank}" }
+        require(start >= 0 && start + length <= tensor.shape.dimensions[actualDim]) {
+            "narrow: start=$start length=$length exceeds dim size ${tensor.shape.dimensions[actualDim]}"
+        }
+        val resultDims = tensor.shape.dimensions.copyOf()
+        resultDims[actualDim] = length
+        val outShape = Shape(resultDims)
+        val outData = dataFactory.init<T, V>(outShape, tensor.dtype) { idx ->
+            val srcIdx = idx.copyOf()
+            srcIdx[actualDim] = srcIdx[actualDim] + start
+            tensor.data.get(*srcIdx)
+        }
+        return newTensor(outData, tensor.dtype, tensor)
+    }
+
+    @TensorOp()
+    @InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-pad2d")
+    override fun <T : DType, V> pad2d(tensor: Tensor<T, V>, padLeft: Int, padRight: Int, padTop: Int, padBottom: Int): Tensor<T, V> {
+        require(tensor.shape.rank == 4) { "pad2d requires 4D tensor [N,C,H,W], got rank ${tensor.shape.rank}" }
+        val (n, c, h, w) = tensor.shape.dimensions.toList()
+        val newH = h + padTop + padBottom
+        val newW = w + padLeft + padRight
+        val outShape = Shape(n, c, newH, newW)
+        val outData = dataFactory.init<T, V>(outShape, tensor.dtype) { idx ->
+            val srcRow = idx[2] - padTop
+            val srcCol = idx[3] - padLeft
+            if (srcRow in 0 until h && srcCol in 0 until w) {
+                tensor.data.get(idx[0], idx[1], srcRow, srcCol)
+            } else {
+                // Zero padding
+                when (tensor.dtype) {
+                    sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
+                        @Suppress("UNCHECKED_CAST") (0f as V)
+                    }
+                    sk.ainet.lang.types.Int32::class -> {
+                        @Suppress("UNCHECKED_CAST") (0 as V)
+                    }
+                    else -> throw IllegalArgumentException("Unsupported dtype for pad2d: ${tensor.dtype}")
+                }
+            }
+        }
+        return newTensor(outData, tensor.dtype, tensor)
+    }
+
+    @TensorOp()
+    @InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-unfold")
+    override fun <T : DType, V> unfold(tensor: Tensor<T, V>, dim: Int, size: Int, step: Int): Tensor<T, V> {
+        val actualDim = if (dim < 0) tensor.shape.rank + dim else dim
+        require(actualDim in 0 until tensor.shape.rank) { "unfold dim $dim out of bounds for rank ${tensor.shape.rank}" }
+        val dimSize = tensor.shape.dimensions[actualDim]
+        require(size <= dimSize) { "unfold size $size > dim size $dimSize" }
+        val numWindows = (dimSize - size) / step + 1
+        val resultDims = IntArray(tensor.shape.rank + 1)
+        for (i in 0 until tensor.shape.rank) {
+            resultDims[i] = if (i == actualDim) numWindows else tensor.shape.dimensions[i]
+        }
+        resultDims[tensor.shape.rank] = size
+        val outShape = Shape(resultDims)
+        val outData = dataFactory.init<T, V>(outShape, tensor.dtype) { idx ->
+            // idx has rank+1 dimensions. Last dimension is the window element index.
+            val windowIdx = idx[tensor.shape.rank]
+            val srcIdx = IntArray(tensor.shape.rank)
+            for (i in 0 until tensor.shape.rank) {
+                srcIdx[i] = if (i == actualDim) {
+                    idx[i] * step + windowIdx
+                } else {
+                    idx[i]
+                }
+            }
+            tensor.data.get(*srcIdx)
+        }
+        return newTensor(outData, tensor.dtype, tensor)
+    }
+
     @TensorOp()
     @InProgress("cpu", owner = "team:cpu", issue = "task-ops.md#op-convert")
     override fun <TFrom : DType, TTo : DType, V> convert(
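
Note on the unfold implementation above: the output index (..., w, ..., k), where w indexes the window along actualDim and k is the position inside the window, is read from source index w * step + k, giving numWindows = (dimSize - size) / step + 1 windows of length size. A standalone check of that arithmetic on a 1-D array, in plain Kotlin with a hypothetical unfold1d helper (no SKaiNET types):

// Same window arithmetic as the unfold op above, specialized to one dimension.
fun unfold1d(data: FloatArray, size: Int, step: Int): Array<FloatArray> {
    require(size <= data.size) { "unfold size $size > dim size ${data.size}" }
    val numWindows = (data.size - size) / step + 1
    return Array(numWindows) { w -> FloatArray(size) { k -> data[w * step + k] } }
}

fun main() {
    val windows = unfold1d(floatArrayOf(0f, 1f, 2f, 3f, 4f, 5f), size = 3, step = 2)
    // (6 - 3) / 2 + 1 = 2 windows: [0.0, 1.0, 2.0] and [2.0, 3.0, 4.0].
    windows.forEach { println(it.toList()) }
}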

skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt

Lines changed: 8 additions & 0 deletions
@@ -389,6 +389,14 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
     override fun <T : DType, V> mean(tensor: Tensor<T, V>, dim: Int?): Tensor<T, V> = base.mean(tensor, dim)
     override fun <T : DType, V> variance(tensor: Tensor<T, V>, dim: Int?): Tensor<T, V> = base.variance(tensor, dim)
     override fun <T : DType, V> sqrt(tensor: Tensor<T, V>): Tensor<T, V> = base.sqrt(tensor)
+    override fun <T : DType, V> abs(tensor: Tensor<T, V>): Tensor<T, V> = base.abs(tensor)
+    override fun <T : DType, V> sign(tensor: Tensor<T, V>): Tensor<T, V> = base.sign(tensor)
+    override fun <T : DType, V> clamp(tensor: Tensor<T, V>, minVal: Float, maxVal: Float): Tensor<T, V> = base.clamp(tensor, minVal, maxVal)
+    override fun <T : DType, V> lt(tensor: Tensor<T, V>, value: Float): Tensor<T, V> = base.lt(tensor, value)
+    override fun <T : DType, V> ge(tensor: Tensor<T, V>, value: Float): Tensor<T, V> = base.ge(tensor, value)
+    override fun <T : DType, V> narrow(tensor: Tensor<T, V>, dim: Int, start: Int, length: Int): Tensor<T, V> = base.narrow(tensor, dim, start, length)
+    override fun <T : DType, V> pad2d(tensor: Tensor<T, V>, padLeft: Int, padRight: Int, padTop: Int, padBottom: Int): Tensor<T, V> = base.pad2d(tensor, padLeft, padRight, padTop, padBottom)
+    override fun <T : DType, V> unfold(tensor: Tensor<T, V>, dim: Int, size: Int, step: Int): Tensor<T, V> = base.unfold(tensor, dim, size, step)
     override fun <T : DType, TTo : DType, V> convert(tensor: Tensor<T, V>, targetType: TTo): Tensor<TTo, V> = base.convert(tensor, targetType)
     override fun <T : DType, V> tril(tensor: Tensor<T, V>, k: Int): Tensor<T, V> = base.tril(tensor, k)
 }
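
Widening the TensorOps interface forces every decorator to forward the new methods or stop compiling, so the overrides above delegate straight to the wrapped base implementation. A hypothetical reduction of the pattern (any actual recording logic lives elsewhere in RecordingExecution.kt and is not visible in this hunk):

// Sketch of the decorator shape, not SKaiNET code: wrap an implementation,
// optionally record each call, then delegate to the wrapped instance.
interface Ops { fun abs(x: Float): Float }

class RecordingOps(private val base: Ops, private val log: MutableList<String>) : Ops {
    override fun abs(x: Float): Float {
        log += "abs"       // record the invocation
        return base.abs(x) // delegate to the wrapped implementation
    }
}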
