Skip to content

Commit 4b7ddfa

Browse files
Merge pull request #631 from SKaiNET-developers/feature/630-tanh-activation
Add tanh as a first-class TensorOps activation primitive (#630)
2 parents 1c817fc + 43720df commit 4b7ddfa

9 files changed

Lines changed: 323 additions & 7 deletions

File tree

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2526,11 +2526,18 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
25262526
return newTensor(outData, tensor.dtype, tensor)
25272527
}
25282528

2529+
/**
 * Element-wise hyperbolic tangent on the CPU backend.
 *
 * Only FP32 and FP16 tensors are accepted. The dtype is validated once, before any
 * element is produced, so an unsupported dtype is rejected regardless of how many
 * elements the tensor holds and the check is not repeated per element.
 *
 * @param tensor input tensor; each element is read as a [Float]
 *   (NOTE(review): assumes FP16 data is also surfaced as Float, as the cast requires — confirm).
 * @return a new tensor with the shape and dtype of [tensor] holding `tanh` of each element.
 * @throws IllegalArgumentException if the dtype is neither FP32 nor FP16.
 */
@TensorOp()
@InProgress("cpu", owner = "team:cpu", issue = "task-ops.md#op-tanh")
override fun <T : DType, V> tanh(tensor: Tensor<T, V>): Tensor<T, V> {
    // Loop-invariant dtype validation, hoisted out of the element initializer.
    when (tensor.dtype) {
        sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> Unit
        else -> throw IllegalArgumentException("Unsupported dtype for tanh: ${tensor.dtype}")
    }
    val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
        val x = tensor.data.get(*idx) as Float
        // kotlin.math.tanh(Float) already returns Float; only the generic cast to V is needed.
        @Suppress("UNCHECKED_CAST")
        kotlin.math.tanh(x) as V
    }
    return newTensor(outData, tensor.dtype, tensor)
}

skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/sk/ainet/exec/tensor/ops/DefaultCpuOpsActivationsTest.kt

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class DefaultCpuOpsActivationsTest {
2727

2828
private fun sigmoid(x: Float): Float = 1f / (1f + kotlin.math.exp(-x))
2929
private fun silu(x: Float): Float = x * sigmoid(x)
30+
/** Reference tanh used to compute expected values in these tests. */
private fun tanh(x: Float): Float {
    return kotlin.math.tanh(x)
}
3031

3132
private fun assertAlmostEquals(expected: Float, actual: Float, eps: Float = 1e-5f, msg: String = "") {
3233
assertTrue(kotlin.math.abs(expected - actual) <= eps, msg.ifEmpty { "Expected $expected, got $actual" })
@@ -104,4 +105,50 @@ class DefaultCpuOpsActivationsTest {
104105
}
105106
assertTrue(threw, "Expected IllegalArgumentException for Int32 silu")
106107
}
108+
109+
/** Verifies tanh on a 5-element FP32 vector: shape, dtype and per-element values. */
@Test
fun tanh_fp32_basic_values() {
    val samples = floatArrayOf(-2f, -1f, 0f, 1f, 2f)
    val result = cpuOps.tanh(fTensor(Shape(5), samples))

    assertEquals(Shape(5), result.shape)
    assertEquals(FP32::class, result.dtype)

    // Compare every output element with the reference helper applied to the same sample.
    samples.forEachIndexed { i, x ->
        assertAlmostEquals(tanh(x), result.data[i] as Float, 1e-6f, "tanh at $i mismatch")
    }
}
120+
121+
/** Verifies that tanh keeps a 2x3 shape and spot-checks the first and last elements. */
@Test
fun tanh_fp32_matrix_shape_preserved() {
    val values = floatArrayOf(
        -1f, 0f, 1f,
        2f, -2f, 0.5f
    )
    val result = cpuOps.tanh(fTensor(Shape(2, 3), values))

    assertEquals(Shape(2, 3), result.shape)
    assertAlmostEquals(expected = tanh(-1f), actual = result.data[0, 0] as Float, eps = 1e-6f)
    assertAlmostEquals(expected = tanh(0.5f), actual = result.data[1, 2] as Float, eps = 1e-6f)
}
132+
133+
/** Verifies that tanh saturates to -1 and +1 for large-magnitude inputs. */
@Test
fun tanh_fp32_saturates_at_extremes() {
    // tanh(x) approaches +1 as x grows and -1 as x decreases; +/-100 is deep in saturation.
    val extremes = fTensor(Shape(2), floatArrayOf(-100f, 100f))
    val result = cpuOps.tanh(extremes)

    val lowest = result.data[0] as Float
    val highest = result.data[1] as Float
    assertAlmostEquals(-1f, lowest, 1e-6f)
    assertAlmostEquals(1f, highest, 1e-6f)
}
141+
142+
/** Verifies that tanh rejects an Int32 tensor with an "Unsupported dtype" IllegalArgumentException. */
@Test
fun tanh_unsupported_dtype_int32_throws() {
    val input = iTensor(Shape(2), intArrayOf(1, 2))

    // assertFailsWith fails the test when nothing (or a different exception type) is thrown,
    // and returns the caught exception so its message can be inspected without a mutable flag.
    val error = kotlin.test.assertFailsWith<IllegalArgumentException>(
        "Expected IllegalArgumentException for Int32 tanh"
    ) {
        cpuOps.tanh(input as sk.ainet.lang.tensor.Tensor<Int32, Int>)
    }

    assertTrue(error.message?.contains("Unsupported dtype") == true)
}
107154
}

skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ private fun stableInputName(op: Operation, index: Int, total: Int): String = whe
136136
is ReluOperation<*, *> -> "input"
137137
is SoftmaxOperation<*, *> -> "input"
138138
is SigmoidOperation<*, *> -> "input"
139+
is TanhOperation<*, *> -> "input"
139140
is SqueezeOperation<*, *> -> "input"
140141
is UnsqueezeOperation<*, *> -> "input"
141142
else -> if (total == 1) "input" else "input_$index"
@@ -418,6 +419,12 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
418419
return out
419420
}
420421

422+
/** Runs tanh on the wrapped [base] ops, records a [TanhOperation] for it, and returns the result. */
override fun <T : DType, V> tanh(tensor: Tensor<T, V>): Tensor<T, V> =
    base.tanh(tensor).also { result ->
        record(TanhOperation<T, V>(), listOf(tensor), listOf(result))
    }
427+
421428
// --- Misc ---
422429
override fun <T : DType, V> squeeze(tensor: Tensor<T, V>, dim: Int?): Tensor<T, V> {
423430
val out = base.squeeze(tensor, dim)

skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ public open class DefaultExecutionTape(
175175
"matmul" -> listOf(ops.matmul(typedInputs[0], typedInputs[1]))
176176
"relu" -> listOf(ops.relu(typedInputs[0]))
177177
"sigmoid" -> listOf(ops.sigmoid(typedInputs[0]))
178+
"tanh" -> listOf(ops.tanh(typedInputs[0]))
178179
"sum" -> listOf(ops.sum(typedInputs[0], params["dim"] as? Int))
179180
"mean" -> listOf(ops.mean(typedInputs[0], params["dim"] as? Int))
180181
"concat" -> listOf(ops.concat(typedInputs, params["dim"] as Int))
@@ -880,6 +881,13 @@ public class DefaultGradientTape(
880881
return listOf(grad)
881882
}
882883

884+
/**
 * Backward rule for tanh, computed from the forward output only.
 *
 * Uses d(tanh(x))/dx = 1 - tanh(x)^2 = 1 - output^2 and multiplies it by [upstream].
 * [inputs] and [attributes] are not read.
 *
 * @return a single-element list holding the gradient for the op's input.
 */
override fun tanhBackward(
    upstream: Tensor<DType, Any>,
    output: Tensor<DType, Any>,
    inputs: List<Tensor<DType, Any>>,
    attributes: Map<String, Any?>
): List<Tensor<DType, Any>?> {
    val outputSquared = output.ops.multiply(output, output)
    // Local derivative: 1 - output^2.
    val localDerivative = output.ops.rsubScalar(1.0, outputSquared)
    return listOf(upstream.ops.multiply(upstream, localDerivative))
}
890+
883891
override fun siluBackward(upstream: Tensor<DType, Any>, output: Tensor<DType, Any>, inputs: List<Tensor<DType, Any>>, attributes: Map<String, Any?>): List<Tensor<DType, Any>?> {
884892
// silu(x) = x * sigmoid(x)
885893
// d(silu(x))/dx = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) = sigmoid(x) + silu(x) * (1 - sigmoid(x))
@@ -1040,6 +1048,7 @@ public class DefaultGradientTape(
10401048
"squeeze" -> BackwardOp(inputs, output) { upstream -> squeezeBackward(upstream, output, inputs, trace.attributes) }
10411049
"unsqueeze" -> BackwardOp(inputs, output) { upstream -> unsqueezeBackward(upstream, output, inputs, trace.attributes) }
10421050
"sigmoid" -> BackwardOp(inputs, output) { upstream -> sigmoidBackward(upstream, output, inputs, trace.attributes) }
1051+
"tanh" -> BackwardOp(inputs, output) { upstream -> tanhBackward(upstream, output, inputs, trace.attributes) }
10431052
"silu" -> BackwardOp(inputs, output) { upstream -> siluBackward(upstream, output, inputs, trace.attributes) }
10441053
"gelu" -> BackwardOp(inputs, output) { upstream -> geluBackward(upstream, output, inputs, trace.attributes) }
10451054
"variance" -> BackwardOp(inputs, output) { upstream -> varianceBackward(upstream, output, inputs, trace.attributes) }

skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ private class TestTensorOps : TensorOps {
171171
override fun <T : DType, V> softmax(tensor: Tensor<T, V>, dim: Int): Tensor<T, V> = tensor
172172
override fun <T : DType, V> logSoftmax(tensor: Tensor<T, V>, dim: Int): Tensor<T, V> = tensor
173173
override fun <T : DType, V> sigmoid(tensor: Tensor<T, V>): Tensor<T, V> = tensor
174+
// Test stub: hands the input tensor back unchanged.
override fun <T : DType, V> tanh(tensor: Tensor<T, V>): Tensor<T, V> {
    return tensor
}
174175
override fun <T : DType, V> silu(tensor: Tensor<T, V>): Tensor<T, V> = tensor
175176
override fun <T : DType, V> gelu(tensor: Tensor<T, V>): Tensor<T, V> = tensor
176177
override fun <T : DType, V> sum(tensor: Tensor<T, V>, dim: Int?): Tensor<T, V> = tensor

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/TensorExtensions.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ public fun <T : DType, V> Tensor<T, V>.relu(): Tensor<T, V> = ops.relu(this)
7373
public fun <T : DType, V> Tensor<T, V>.leakyRelu(negativeSlope: Float = 0.01f): Tensor<T, V> = ops.leakyRelu(this, negativeSlope)
7474
public fun <T : DType, V> Tensor<T, V>.elu(alpha: Float = 1.0f): Tensor<T, V> = ops.elu(this, alpha)
7575
public fun <T : DType, V> Tensor<T, V>.sigmoid(): Tensor<T, V> = ops.sigmoid(this)
76+
/** Extension shorthand: forwards to `ops.tanh` with this tensor as the argument. */
public fun <T : DType, V> Tensor<T, V>.tanh(): Tensor<T, V> {
    return ops.tanh(this)
}
7677
public fun <T : DType, V> Tensor<T, V>.silu(): Tensor<T, V> = ops.silu(this)
7778
public fun <T : DType, V> Tensor<T, V>.gelu(): Tensor<T, V> = ops.gelu(this)
7879
public fun <T : DType, V> Tensor<T, V>.exp(): Tensor<T, V> = ops.exp(this)

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,37 @@ public class SigmoidOperation<T : DType, V>(
813813
override fun clone(newParameters: Map<String, Any>): Operation = SigmoidOperation<T, V>(newParameters)
814814
}
815815

816+
/**
 * Graph-mode descriptor for the tanh activation.
 *
 * Constructed on [BaseOperation] with the strings "tanh" and "activation". It validates
 * arity and infers the output spec, but does not compute anything itself: [execute]
 * always throws.
 *
 * @param parameters operation parameters forwarded to [BaseOperation]; empty by default.
 */
public class TanhOperation<T : DType, V>(
    parameters: Map<String, Any> = emptyMap()
) : BaseOperation("tanh", "activation", parameters) {

    /**
     * Not executable directly.
     *
     * @throws IllegalArgumentException if [inputs] does not contain exactly one tensor.
     * @throws UnsupportedOperationException otherwise, always.
     */
    override fun <T2 : DType, V2> execute(inputs: List<Tensor<T2, V2>>): List<Tensor<T2, V2>> {
        if (inputs.size != 1) {
            throw IllegalArgumentException("Tanh operation requires exactly 1 input")
        }
        throw UnsupportedOperationException("Direct execution not supported in graph mode")
    }

    /** Valid only when exactly one input spec is supplied; otherwise reports the actual count. */
    override fun validateInputs(inputs: List<TensorSpec>): ValidationResult =
        if (inputs.size == 1) {
            ValidationResult.Valid
        } else {
            ValidationResult.Invalid(listOf("Tanh operation requires exactly 1 input, got ${inputs.size}"))
        }

    /**
     * Produces one output spec named "tanh_output" that copies shape, dtype and
     * requiresGrad from the single input spec.
     *
     * @throws IllegalArgumentException if [inputs] does not contain exactly one spec.
     */
    override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
        require(inputs.size == 1) { "Tanh operation requires exactly 1 input" }
        val source = inputs[0]
        val outputSpec = TensorSpec(
            name = "tanh_output",
            shape = source.shape,
            dtype = source.dtype,
            requiresGrad = source.requiresGrad
        )
        return listOf(outputSpec)
    }

    /** Creates a new [TanhOperation] carrying [newParameters]. */
    override fun clone(newParameters: Map<String, Any>): Operation {
        return TanhOperation<T, V>(newParameters)
    }
}
846+
816847
/**
817848
* Additional shape operations
818849
*/

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ public interface TensorOps {
178178
public fun <T : DType, V> sigmoid(tensor: Tensor<T, V>): Tensor<T, V>
179179
/**
 * Hyperbolic tangent activation.
 *
 * Declared without a default body, so every implementation of this interface must provide it.
 *
 * @param tensor input tensor.
 * @return a tensor with the same `T`/`V` type parameters as [tensor].
 */
@Diff
180180
@ActivationDsl
181+
public fun <T : DType, V> tanh(tensor: Tensor<T, V>): Tensor<T, V>
182+
@Diff
183+
@ActivationDsl
181184
public fun <T : DType, V> silu(tensor: Tensor<T, V>): Tensor<T, V>
182185
@Diff
183186
@ActivationDsl
@@ -304,10 +307,6 @@ public interface TensorOps {
304307
throw NotImplementedError("cos not implemented by this TensorOps backend")
305308
}
306309

307-
public fun <T : DType, V> tanh(tensor: Tensor<T, V>): Tensor<T, V> {
308-
throw NotImplementedError("tanh not implemented by this TensorOps backend")
309-
}
310-
311310
/**
312311
* Scaled dot-product attention.
313312
*

0 commit comments

Comments
 (0)