Skip to content

Commit 4186c4f

Browse files
committed
Simplify Module forward signature from an extension function on Tensor to a plain member function.
Relates-To: #117, #104
1 parent 92a9b24 commit 4186c4f

12 files changed

Lines changed: 49 additions & 45 deletions

File tree

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/Flatten.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ public class Flatten<T : DType, V>(
1616
override val modules: List<Module<T, V>>
1717
get() = emptyList()
1818

19-
override fun Tensor<T, V>.forward(input: Tensor<T, V>): Tensor<T, V> {
20-
return input.flatten(startDim, endDim)
19+
override fun forward(input: Tensor<T, V>): Tensor<T, V> {
20+
return with(input) {
21+
flatten(startDim, endDim)
22+
}
2123
}
2224
}

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/Input.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class Input<T : DType, V>(private val inputShape: Shape, override val nam
1212
get() = emptyList()
1313

1414

15-
override fun Tensor<T, V>.forward(input: Tensor<T, V>): Tensor<T, V> {
15+
override fun forward(input: Tensor<T, V>): Tensor<T, V> {
1616
return input
1717
}
1818
}

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/Linear.kt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ public class Linear<T : DType, V>(
3737
override val modules: List<Module<T, V>>
3838
get() = emptyList()
3939

40-
override fun Tensor<T, V>.forward(input: Tensor<T, V>): Tensor<T, V> {
40+
override fun forward(input: Tensor<T, V>): Tensor<T, V> {
4141
val weight = params.weights().value
4242
val bias = params.bias().value
4343

4444
// Use TensorOps context operations
45-
val weightTransposed = weight.t()
46-
val matmulResult = matmul(input, weightTransposed)
47-
return matmulResult + bias
45+
with(input) {
46+
val weightTransposed = weight.t()
47+
val matmulResult = matmul(input, weightTransposed)
48+
return matmulResult + bias
49+
}
4850
}
4951
}

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/Module.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ public abstract class Module<T : DType, V> {
1111

1212
public abstract val modules: List<Module<T, V>>
1313

14-
public abstract fun Tensor<T, V>.forward(input: Tensor<T, V>): Tensor<T, V>
14+
public abstract fun forward(input: Tensor<T, V>): Tensor<T, V>
1515

16-
public operator fun Tensor<T, V>.invoke(input: Tensor<T, V>): Tensor<T, V> {
16+
public operator fun invoke(input: Tensor<T, V>): Tensor<T, V> {
1717
return forward(input)
1818
}
1919
}

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/activations/ActivationsWrapperModule.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class ActivationsWrapperModule<T : DType, V>(
1414
override val modules: List<Module<T, V>>
1515
get() = emptyList()
1616

17-
override fun Tensor<T, V>.forward(input: Tensor<T, V>): Tensor<T, V> {
17+
override fun forward(input: Tensor<T, V>): Tensor<T, V> {
1818
return activationHandler(input)
1919
}
2020
}

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/activations/Softmax.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ public class Softmax<T : DType, V>(private val dimension: Int, override val name
99
override val modules: List<Module<T, V>>
1010
get() = emptyList()
1111

12-
override fun Tensor<T, V>.forward(input: Tensor<T, V>): Tensor<T, V> {
13-
return input.softmax(dimension)
12+
override fun forward(input: Tensor<T, V>): Tensor<T, V> {
13+
return with(input) {
14+
softmax(dimension)
15+
}
1416
}
1517
}
1618

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/activations/relu.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ public class ReLU<T : DType, V>(override val name: String = "ReLU") : Module<T,
99
override val modules: List<Module<T, V>>
1010
get() = emptyList()
1111

12-
override fun Tensor<T, V>.forward(input: Tensor<T, V>): Tensor<T, V> = input.relu()
12+
override fun forward(input: Tensor<T, V>): Tensor<T, V> = with(input) { relu() }
1313
}
1414

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/topology/MLP.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ public class MLP<T : DType, V>(vararg modules: Module<T, V>, override val name:
1212
override val modules: List<Module<T, V>>
1313
get() = modulesList
1414

15-
override fun Tensor<T, V>.forward(input: Tensor<T, V>): Tensor<T, V> {
15+
override fun forward(input: Tensor<T, V>): Tensor<T, V> {
1616
var tmp = input
1717
modulesList.forEach { module ->
18-
tmp = with(module) { this@forward.forward(tmp) }
18+
tmp = module.forward(tmp)
1919
}
2020
return tmp
2121
}

skainet-nn/skainet-nn-api/src/commonTest/kotlin/sk/ainet/nn/dsl/NetworkBuilderGenericTest.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import sk.ainet.core.tensor.Int8
99
import sk.ainet.core.tensor.Int32
1010
import sk.ainet.core.tensor.backend.CpuBackend
1111
import sk.ainet.core.tensor.backend.CpuBackendInt8
12+
import sk.ainet.nn.Linear
13+
import sk.ainet.nn.Module
1214
import kotlin.test.*
1315

1416
/**
@@ -20,7 +22,7 @@ class NetworkBuilderGenericTest {
2022
@Test
2123
fun testGenericNetworkFP32() {
2224
// Test FP32/Float combination - basic functionality
23-
val network = network<FP32, Float> {
25+
val network: Module<FP32, Float> = network<FP32, Float> {
2426
input(2)
2527
dense(3) {
2628
weights { shape -> CpuTensorFP32.ones(shape) }
@@ -181,12 +183,12 @@ class NetworkBuilderGenericTest {
181183
// Test NetworkBuilder class directly with generic types
182184
val builder = NetworkBuilder<FP32, Float>()
183185

184-
val linear1 = sk.ainet.nn.Linear(2, 4, "layer1",
186+
val linear1 = Linear(2, 4, "layer1",
185187
CpuTensorFP32.ones(Shape(4, 2)),
186188
CpuTensorFP32.zeros(Shape(4))
187189
)
188190

189-
val linear2 = sk.ainet.nn.Linear(4, 1, "layer2",
191+
val linear2 = Linear(4, 1, "layer2",
190192
CpuTensorFP32.ones(Shape(1, 4)),
191193
CpuTensorFP32.zeros(Shape(1))
192194
)

skainet-nn/skainet-nn-api/src/commonTest/kotlin/sk/ainet/nn/dsl/NetworkBuilderTest.kt

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class NetworkBuilderTest {
3838

3939
// Test forward pass
4040
val input = CpuTensorFP32.fromArray(Shape(1, 2), floatArrayOf(1.0f, 2.0f))
41-
val output = with(network) { input.forward(input) }
41+
val output = network.forward(input)
4242
assertEquals(Shape(1, 1), output.shape)
4343
}
4444

@@ -63,7 +63,7 @@ class NetworkBuilderTest {
6363

6464
// Test forward pass
6565
val input = CpuTensorInt8.fromArray(Shape(1, 2), byteArrayOf(1, 2))
66-
val output = with(network) { input.forward(input) }
66+
val output = network.forward(input)
6767
assertEquals(Shape(1, 1), output.shape)
6868
}
6969

@@ -82,7 +82,7 @@ class NetworkBuilderTest {
8282

8383
// Test forward pass
8484
val input = CpuTensorInt32.fromArray(Shape(1, 3), intArrayOf(1, 2, 3))
85-
val output = with(network) { input.forward(input) }
85+
val output = network.forward(input)
8686
assertEquals(Shape(1, 2), output.shape)
8787
}
8888

@@ -105,7 +105,7 @@ class NetworkBuilderTest {
105105

106106
// Verify it's using FP32/Float types
107107
val input = CpuTensorFP32.fromArray(Shape(1, 2), floatArrayOf(1.0f, 2.0f))
108-
val output = with(network) { input.forward(input) }
108+
val output = network.forward(input)
109109
assertEquals(Shape(1, 1), output.shape)
110110
}
111111

@@ -122,7 +122,7 @@ class NetworkBuilderTest {
122122

123123
assertNotNull(network)
124124
val input = CpuTensorFP32.fromArray(Shape(1, 2), floatArrayOf(1.0f, 2.0f))
125-
val output = with(network) { input.forward(input) }
125+
val output = network.forward(input)
126126
assertEquals(Shape(1, 1), output.shape)
127127
}
128128

@@ -174,7 +174,7 @@ class NetworkBuilderTest {
174174

175175
assertNotNull(network)
176176
val input = CpuTensorFP32.fromArray(Shape(1, 2), floatArrayOf(1.0f, 2.0f))
177-
val output = with(network) { input.forward(input) }
177+
val output = network.forward(input)
178178
assertEquals(Shape(1, 1), output.shape)
179179
}
180180

@@ -203,7 +203,7 @@ class NetworkBuilderTest {
203203

204204
assertNotNull(network)
205205
val input = CpuTensorFP32.fromArray(Shape(1, 3), floatArrayOf(1.0f, 2.0f, 3.0f))
206-
val output = with(network) { input.forward(input) }
206+
val output = network.forward(input)
207207
assertEquals(Shape(1, 2), output.shape)
208208
}
209209

@@ -225,7 +225,7 @@ class NetworkBuilderTest {
225225

226226
assertNotNull(network)
227227
val input = CpuTensorFP32.fromArray(Shape(1, 2), floatArrayOf(1.0f, -1.0f))
228-
val output = with(network) { input.forward(input) }
228+
val output = network.forward(input)
229229
assertEquals(Shape(1, 1), output.shape)
230230
}
231231

@@ -250,7 +250,7 @@ class NetworkBuilderTest {
250250
assertEquals("MLP", network.name)
251251

252252
val input = CpuTensorFP32.fromArray(Shape(1, 2), floatArrayOf(1.0f, 2.0f))
253-
val output = with(network) { input.forward(input) }
253+
val output = network.forward(input)
254254
assertEquals(Shape(1, 1), output.shape)
255255
}
256256

@@ -274,7 +274,7 @@ class NetworkBuilderTest {
274274

275275
// Test forward pass
276276
val input = CpuTensorFP32.fromArray(Shape(1, 2), floatArrayOf(1.0f, 2.0f))
277-
val output = with(network) { input.forward(input) }
277+
val output = network.forward(input)
278278
assertEquals(Shape(1, 1), output.shape)
279279
}
280280

@@ -297,7 +297,7 @@ class NetworkBuilderTest {
297297

298298
// Test forward pass
299299
val input = CpuTensorInt32.fromArray(Shape(1, 3), intArrayOf(1, 2, 3))
300-
val output = with(network) { input.forward(input) }
300+
val output = network.forward(input)
301301
assertEquals(Shape(1, 1), output.shape)
302302
}
303303

@@ -321,7 +321,7 @@ class NetworkBuilderTest {
321321
assertNotNull(network)
322322

323323
val input = CpuTensorFP32.fromArray(Shape(1, 4), floatArrayOf(1.0f, 2.0f, 3.0f, 4.0f))
324-
val output = with(network) { input.forward(input) }
324+
val output = network.forward(input)
325325
assertEquals(Shape(1, 1), output.shape)
326326
}
327327

@@ -358,8 +358,8 @@ class NetworkBuilderTest {
358358

359359
// Test that both networks produce the same output for the same input (deterministic)
360360
val input = CpuTensorFP32.fromArray(Shape(1, 3), floatArrayOf(1.0f, 2.0f, 3.0f))
361-
val output1 = with(network1) { input.forward(input) }
362-
val output2 = with(network2) { input.forward(input) }
361+
val output1 = network1.forward(input)
362+
val output2 = network2.forward(input)
363363

364364
assertEquals(output1.shape, output2.shape)
365365
// Note: We can't easily test exact equality due to potential floating point differences
@@ -392,7 +392,7 @@ class NetworkBuilderTest {
392392

393393
// Test forward pass
394394
val input = CpuTensorFP32.fromArray(Shape(1, 5), floatArrayOf(1.0f, 2.0f, 3.0f, 4.0f, 5.0f))
395-
val output = with(network) { input.forward(input) }
395+
val output = network.forward(input)
396396
assertEquals(Shape(1, 1), output.shape)
397397
}
398398

@@ -416,7 +416,7 @@ class NetworkBuilderTest {
416416

417417
// Test forward pass
418418
val input = CpuTensorFP32.fromArray(Shape(1, 3), floatArrayOf(1.0f, 2.0f, 3.0f))
419-
val output = with(network) { input.forward(input) }
419+
val output = network.forward(input)
420420
assertEquals(Shape(1, 2), output.shape)
421421
}
422422

@@ -439,7 +439,7 @@ class NetworkBuilderTest {
439439

440440
// Test forward pass
441441
val input = CpuTensorInt32.fromArray(Shape(1, 2), intArrayOf(10, 20))
442-
val output = with(network) { input.forward(input) }
442+
val output = network.forward(input)
443443
assertEquals(Shape(1, 1), output.shape)
444444
}
445445
}

0 commit comments

Comments
 (0)