Skip to content

Commit 61c7fea

Browse files
Merge pull request #618 from SKaiNET-developers/feature/autograd-completeness
Autograd completeness: pow + log + conv/pool backward formulas (#617)
2 parents a1fc274 + 412127e commit 61c7fea

17 files changed

Lines changed: 1838 additions & 20 deletions

File tree

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ import sk.ainet.lang.tensor.data.FloatArrayTensorData
1212
import sk.ainet.lang.tensor.data.TensorDataFactory
1313
import sk.ainet.lang.tensor.ops.UpsampleMode
1414
import sk.ainet.lang.types.FP32
15+
import kotlin.math.ln
16+
import kotlin.math.log10 as kmLog10
17+
import kotlin.math.log2 as kmLog2
18+
import kotlin.math.pow
1519
import kotlin.math.sqrt
1620

1721
@Backend(id = "cpu", displayName = "CPU")
@@ -2123,6 +2127,112 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
21232127
return newTensor(outData, tensor.dtype, tensor)
21242128
}
21252129

2130+
/**
 * Element-wise power with a tensor exponent: `c[i] = a[i] ^ b[i]`.
 *
 * Every element is evaluated through [scalarPow], i.e. `kotlin.math.pow`
 * in double precision narrowed back to `Float`. Unlike [powScalar], this
 * overload has no repeated-multiply path for integer-valued exponents.
 *
 * Shape contract: shapes must match exactly. No broadcasting is done
 * here, so expanding the operands is the caller's responsibility.
 *
 * @param a base tensor; its dtype must be FP16 or FP32.
 * @param b exponent tensor; must have the same shape as [a].
 * @return the tensor built by `newTensor` from data of shape `a.shape` and dtype `a.dtype`.
 * @throws IllegalArgumentException if the dtype of [a] is neither FP16 nor FP32,
 *   or if the shapes of [a] and [b] differ.
 */
override fun <T : DType, V> pow(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
    require(
        a.dtype == sk.ainet.lang.types.FP32::class ||
            a.dtype == sk.ainet.lang.types.FP16::class
    ) { "pow supports only FP16/FP32, got ${a.dtype}" }
    require(a.shape == b.shape) { "pow requires matching shapes; got ${a.shape} and ${b.shape}" }
    val outData = dataFactory.init<T, V>(a.shape, a.dtype) { idx ->
        // Elements are read as Float; the casts fail if the backing data holds another type.
        val baseValue = a.data.get(*idx) as Float
        val exponentValue = b.data.get(*idx) as Float
        @Suppress("UNCHECKED_CAST")
        scalarPow(baseValue, exponentValue) as V
    }
    return newTensor(outData, a.dtype, a)
}
2150+
2151+
/**
 * Element-wise power with a scalar exponent: `c[i] = a[i] ^ n`.
 *
 * Integer-valued exponents in `-16..16` are evaluated by [integerPow]
 * (repeated multiply); every other exponent goes through [scalarPow],
 * i.e. `kotlin.math.pow`.
 *
 * @param a base tensor; its dtype must be FP16 or FP32.
 * @param n exponent applied to every element (converted with `toFloat()`).
 * @return the tensor built by `newTensor` from data of shape `a.shape` and dtype `a.dtype`.
 * @throws IllegalArgumentException if the dtype of [a] is neither FP16 nor FP32.
 */
override fun <T : DType, V> powScalar(a: Tensor<T, V>, n: Number): Tensor<T, V> {
    require(
        a.dtype == sk.ainet.lang.types.FP32::class ||
            a.dtype == sk.ainet.lang.types.FP16::class
    ) { "powScalar supports only FP16/FP32, got ${a.dtype}" }
    val nFloat = n.toFloat()
    val nInt = n.toInt()
    // Use a range check rather than abs(): abs(Int.MIN_VALUE) is still Int.MIN_VALUE, so the
    // former `abs(nInt) <= 16` test accepted n == -2^31 and integerPow then recursed forever.
    val isSmallInt = nFloat == nInt.toFloat() && nInt in -16..16
    val outData = dataFactory.init<T, V>(a.shape, a.dtype) { idx ->
        val av = a.data.get(*idx) as Float
        @Suppress("UNCHECKED_CAST")
        (if (isSmallInt) integerPow(av, nInt) else scalarPow(av, nFloat)) as V
    }
    return newTensor(outData, a.dtype, a)
}
2171+
2172+
/**
 * Raises [base] to the integer power [n] by square-and-multiply.
 *
 * The running product is kept in `Double` and narrowed to `Float` once at
 * the end, so an intermediate square cannot overflow or underflow the
 * `Float` range before the final value is formed (for example `(1e5f)^-8`
 * yields a subnormal rather than `0`). Negative exponents return the
 * reciprocal of the positive power. The exponent magnitude is widened to
 * `Long` so that negating `Int.MIN_VALUE` cannot overflow.
 *
 * @param base value to exponentiate.
 * @param n integer exponent; `0` yields `1f` for every base.
 * @return `base^n` rounded to `Float`.
 */
private fun integerPow(base: Float, n: Int): Float {
    if (n == 0) return 1f
    var remaining = if (n < 0) -n.toLong() else n.toLong()
    var factor = base.toDouble()
    var product = 1.0
    while (remaining > 0L) {
        if ((remaining and 1L) == 1L) product *= factor
        remaining = remaining ushr 1
        // Square only while bits remain; the last squaring would never be used.
        if (remaining > 0L) factor *= factor
    }
    return (if (n < 0) 1.0 / product else product).toFloat()
}
2186+
2187+
/** Computes `base ^ exp` with `kotlin.math.pow` on `Double` operands and narrows the result to `Float`. */
private fun scalarPow(base: Float, exp: Float): Float {
    val wideBase = base.toDouble()
    val wideExp = exp.toDouble()
    return wideBase.pow(wideExp).toFloat()
}
2189+
2190+
/**
 * Element-wise natural log: `c[i] = ln(a[i])`. Negative or zero
 * inputs follow `kotlin.math.ln` semantics (negative → NaN, zero
 * → -Infinity). Mirror of `stablehlo.log`.
 *
 * @throws IllegalArgumentException if the tensor dtype is neither FP16 nor FP32.
 */
override fun <T : DType, V> log(tensor: Tensor<T, V>): Tensor<T, V> =
    elementwiseLog("log", tensor) { ln(it) }

/**
 * Element-wise base-2 log: `c[i] = log2(a[i])`.
 *
 * @throws IllegalArgumentException if the tensor dtype is neither FP16 nor FP32.
 */
override fun <T : DType, V> log2(tensor: Tensor<T, V>): Tensor<T, V> =
    elementwiseLog("log2", tensor) { kmLog2(it) }

/**
 * Element-wise base-10 log: `c[i] = log10(a[i])`.
 *
 * @throws IllegalArgumentException if the tensor dtype is neither FP16 nor FP32.
 */
override fun <T : DType, V> log10(tensor: Tensor<T, V>): Tensor<T, V> =
    elementwiseLog("log10", tensor) { kmLog10(it) }

/**
 * Shared body of [log], [log2] and [log10]: checks that the dtype is FP16
 * or FP32, then applies [transform] to every element read as `Float`.
 *
 * @param opName operation name used as the prefix of the dtype error message.
 * @param tensor input tensor.
 * @param transform per-element function; `crossinline` because it is invoked
 *   from inside the initializer lambda handed to `dataFactory.init`.
 */
private inline fun <T : DType, V> elementwiseLog(
    opName: String,
    tensor: Tensor<T, V>,
    crossinline transform: (Float) -> Float
): Tensor<T, V> {
    require(
        tensor.dtype == sk.ainet.lang.types.FP32::class ||
            tensor.dtype == sk.ainet.lang.types.FP16::class
    ) { "$opName supports only FP16/FP32, got ${tensor.dtype}" }
    val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
        val v = tensor.data.get(*idx) as Float
        @Suppress("UNCHECKED_CAST")
        transform(v) as V
    }
    return newTensor(outData, tensor.dtype, tensor)
}
2235+
21262236
// ---- TinyFoA ops: abs, sign, clamp, lt, ge ----
21272237

21282238
@TensorOp()
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import kotlin.math.abs
4+
import kotlin.math.ln
5+
import kotlin.math.log10 as kmLog10
6+
import kotlin.math.log2 as kmLog2
7+
import kotlin.test.Test
8+
import kotlin.test.assertEquals
9+
import kotlin.test.assertFailsWith
10+
import kotlin.test.assertTrue
11+
import sk.ainet.lang.tensor.Shape
12+
import sk.ainet.lang.tensor.VoidOpsTensor
13+
import sk.ainet.lang.tensor.data.DenseTensorDataFactory
14+
import sk.ainet.lang.tensor.data.FloatArrayTensorData
15+
import sk.ainet.lang.types.FP32
16+
import sk.ainet.lang.types.Int32
17+
18+
/**
 * Forward checks for the CPU `log`, `log2` and `log10` ops (Tier B of
 * #617): element values for a handful of inputs, NaN / -Infinity edge
 * cases, agreement between the three bases, and rejection of a
 * non-float dtype.
 */
class DefaultCpuOpsLogTest {
    private val dataFactory = DenseTensorDataFactory()
    private val ops = DefaultCpuOps(dataFactory)

    /** Wraps [values] in an FP32 tensor of the given [shape]. */
    private fun floatTensor(shape: Shape, values: FloatArray) =
        VoidOpsTensor(dataFactory.fromFloatArray<FP32, Float>(shape, FP32::class, values), FP32::class)

    /** Asserts equal length and a per-element absolute difference of at most [tol]. */
    private fun assertCloseTo(expected: FloatArray, actual: FloatArray, tol: Float = 1e-5f) {
        assertEquals(expected.size, actual.size, "length mismatch")
        expected.forEachIndexed { i, want ->
            val got = actual[i]
            val diff = abs(want - got)
            assertTrue(diff <= tol, "[$i] expected=$want actual=$got diff=$diff tol=$tol")
        }
    }

    @Test
    fun log_matches_kotlin_math_ln() {
        val input = floatTensor(Shape(5), floatArrayOf(1f, 2f, kotlin.math.E.toFloat(), 10f, 100f))
        val result = ops.log(input)
        assertCloseTo(
            floatArrayOf(0f, ln(2f), 1f, ln(10f), ln(100f)),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun log2_matches_kotlin_math_log2() {
        val input = floatTensor(Shape(5), floatArrayOf(1f, 2f, 4f, 8f, 1024f))
        val result = ops.log2(input)
        assertCloseTo(
            floatArrayOf(0f, 1f, 2f, 3f, 10f),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun log10_matches_kotlin_math_log10() {
        val input = floatTensor(Shape(4), floatArrayOf(1f, 10f, 100f, 1000f))
        val result = ops.log10(input)
        assertCloseTo(
            floatArrayOf(0f, 1f, 2f, 3f),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun log_of_negative_returns_nan() {
        val input = floatTensor(Shape(2), floatArrayOf(-1f, -2f))
        val values = (ops.log(input).data as FloatArrayTensorData<*>).buffer
        values.forEach { v ->
            assertTrue(v.isNaN(), "log of negative must be NaN, got $v")
        }
    }

    @Test
    fun log_of_zero_returns_negative_infinity() {
        val input = floatTensor(Shape(1), floatArrayOf(0f))
        val result = (ops.log(input).data as FloatArrayTensorData<*>).buffer[0]
        assertEquals(Float.NEGATIVE_INFINITY, result, "log(0) must be -Inf, got $result")
    }

    @Test
    fun log_log2_log10_consistent_with_each_other() {
        // Change of base: log_b(x) = ln(x) / ln(b), so all three ops must agree.
        val input = floatTensor(Shape(3), floatArrayOf(2f, 10f, 100f))
        val natural = (ops.log(input).data as FloatArrayTensorData<*>).buffer
        val base2 = (ops.log2(input).data as FloatArrayTensorData<*>).buffer
        val base10 = (ops.log10(input).data as FloatArrayTensorData<*>).buffer
        repeat(3) { i ->
            assertEquals(base2[i], natural[i] / ln(2f), 1e-5f, "log2 consistency at $i")
            assertEquals(base10[i], natural[i] / ln(10f), 1e-5f, "log10 consistency at $i")
        }
    }

    @Test
    fun log_rejects_non_float_dtype() {
        val intTensor = VoidOpsTensor(
            dataFactory.fromIntArray<Int32, Int>(Shape(2), Int32::class, intArrayOf(1, 2)),
            Int32::class
        )
        assertFailsWith<IllegalArgumentException> { ops.log(intTensor) }
    }
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import kotlin.math.abs
4+
import kotlin.test.Test
5+
import kotlin.test.assertEquals
6+
import kotlin.test.assertFailsWith
7+
import kotlin.test.assertTrue
8+
import sk.ainet.lang.tensor.Shape
9+
import sk.ainet.lang.tensor.VoidOpsTensor
10+
import sk.ainet.lang.tensor.data.DenseTensorDataFactory
11+
import sk.ainet.lang.tensor.data.FloatArrayTensorData
12+
import sk.ainet.lang.types.FP32
13+
14+
/**
 * Forward checks for the CPU `pow` and `powScalar` ops (Tier A of #617):
 * the scalar form with integer and real exponents, the binary form with
 * a tensor exponent, and rejection of mismatched shapes.
 */
class DefaultCpuOpsPowTest {
    private val dataFactory = DenseTensorDataFactory()
    private val ops = DefaultCpuOps(dataFactory)

    /** Wraps [values] in an FP32 tensor of the given [shape]. */
    private fun floatTensor(shape: Shape, values: FloatArray) =
        VoidOpsTensor(dataFactory.fromFloatArray<FP32, Float>(shape, FP32::class, values), FP32::class)

    /** Asserts equal length and a per-element absolute difference of at most [tol]. */
    private fun assertCloseTo(expected: FloatArray, actual: FloatArray, tol: Float = 1e-4f) {
        assertEquals(expected.size, actual.size, "length mismatch")
        expected.forEachIndexed { i, want ->
            val got = actual[i]
            val diff = abs(want - got)
            assertTrue(diff <= tol, "[$i] expected=$want actual=$got diff=$diff tol=$tol")
        }
    }

    @Test
    fun powScalar_integer_2_matches_x_times_x() {
        val input = floatTensor(Shape(5), floatArrayOf(0.5f, 1f, 2f, 3f, -2f))
        val result = ops.powScalar(input, 2)
        assertCloseTo(
            floatArrayOf(0.25f, 1f, 4f, 9f, 4f),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun powScalar_integer_3_matches_x_cubed() {
        val input = floatTensor(Shape(4), floatArrayOf(1f, 2f, 3f, -2f))
        val result = ops.powScalar(input, 3)
        assertCloseTo(
            floatArrayOf(1f, 8f, 27f, -8f),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun powScalar_negative_integer_minus_1_is_reciprocal() {
        val input = floatTensor(Shape(3), floatArrayOf(2f, 4f, 0.5f))
        val result = ops.powScalar(input, -1)
        assertCloseTo(
            floatArrayOf(0.5f, 0.25f, 2f),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun powScalar_real_half_is_sqrt() {
        val input = floatTensor(Shape(4), floatArrayOf(0f, 1f, 4f, 9f))
        val result = ops.powScalar(input, 0.5f)
        assertCloseTo(
            floatArrayOf(0f, 1f, 2f, 3f),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun powScalar_real_1_5_matches_kotlin_math_pow() {
        val input = floatTensor(Shape(3), floatArrayOf(1f, 2f, 4f))
        val result = ops.powScalar(input, 1.5f)
        assertCloseTo(
            floatArrayOf(1f, 2.828427f, 8f),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun pow_binary_element_wise() {
        val bases = floatTensor(Shape(4), floatArrayOf(2f, 3f, 4f, 5f))
        val exponents = floatTensor(Shape(4), floatArrayOf(2f, 3f, 0.5f, 1f))
        val result = ops.pow(bases, exponents)
        assertCloseTo(
            floatArrayOf(4f, 27f, 2f, 5f),
            (result.data as FloatArrayTensorData<*>).buffer
        )
    }

    @Test
    fun pow_binary_rejects_shape_mismatch() {
        val threeElements = floatTensor(Shape(3), floatArrayOf(1f, 2f, 3f))
        val fourElements = floatTensor(Shape(4), floatArrayOf(1f, 2f, 3f, 4f))
        assertFailsWith<IllegalArgumentException> { ops.pow(threeElements, fourElements) }
    }
}

skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,21 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
184184
return out
185185
}
186186

187+
// --- Power ops ---
188+
/** Delegates to the wrapped ops, then records a [PowOperation] with inputs `[a, b]` and the result as its only output. */
override fun <T : DType, V> pow(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> =
    base.pow(a, b).also { result ->
        record(PowOperation<T, V>(), listOf(a, b), listOf(result))
    }
193+
194+
/**
 * Delegates to the wrapped ops, then records a [PowOperation] whose only
 * tensor input is [a]. The exponent [n] is stored under the
 * `"scalar_exponent"` parameter key, presumably so a backward formula can
 * read it back (d/da a^n = n * a^(n-1)).
 */
override fun <T : DType, V> powScalar(a: Tensor<T, V>, n: Number): Tensor<T, V> =
    base.powScalar(a, n).also { result ->
        val operation = PowOperation<T, V>(parameters = mapOf("scalar_exponent" to n))
        record(operation, listOf(a), listOf(result))
    }
201+
187202
// --- Scalar ops ---
188203
override fun <T : DType, V> addScalar(a: Tensor<T, V>, b: Number): Tensor<T, V> {
189204
val out = base.addScalar(a, b)
@@ -426,6 +441,9 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
426441
// Pass-through ops: each call is forwarded to the wrapped ops and its result returned as-is.
// NOTE(review): unlike pow/powScalar in this class, none of these call record(...).
// Confirm that leaving log/log2/log10 unrecorded is intended.
override fun <T : DType, V> mean(tensor: Tensor<T, V>, dim: Int?): Tensor<T, V> {
    return base.mean(tensor, dim)
}

override fun <T : DType, V> variance(tensor: Tensor<T, V>, dim: Int?): Tensor<T, V> {
    return base.variance(tensor, dim)
}

override fun <T : DType, V> sqrt(tensor: Tensor<T, V>): Tensor<T, V> {
    return base.sqrt(tensor)
}

override fun <T : DType, V> log(tensor: Tensor<T, V>): Tensor<T, V> {
    return base.log(tensor)
}

override fun <T : DType, V> log2(tensor: Tensor<T, V>): Tensor<T, V> {
    return base.log2(tensor)
}

override fun <T : DType, V> log10(tensor: Tensor<T, V>): Tensor<T, V> {
    return base.log10(tensor)
}

override fun <T : DType, V> abs(tensor: Tensor<T, V>): Tensor<T, V> {
    return base.abs(tensor)
}

override fun <T : DType, V> sign(tensor: Tensor<T, V>): Tensor<T, V> {
    return base.sign(tensor)
}

override fun <T : DType, V> clamp(tensor: Tensor<T, V>, minVal: Float, maxVal: Float): Tensor<T, V> {
    return base.clamp(tensor, minVal, maxVal)
}

0 commit comments

Comments
 (0)