Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import sk.ai.net.Tensor
import sk.ai.net.nn.Flatten
import sk.ai.net.nn.Input
import sk.ai.net.nn.Linear
import sk.ai.net.nn.Conv2d
import sk.ai.net.nn.Module
import sk.ai.net.nn.topology.MLP

Expand All @@ -28,6 +29,8 @@ interface NeuralNetworkDsl : NetworkDslItem {

fun flatten(id: String = "", content: FLATTEN.() -> Unit = {})

fun conv2d(id: String = "", content: CONV2D.() -> Unit = {})

fun dense(outputDimension: Int, id: String = "", content: DENSE.() -> Unit = {})
}

Expand All @@ -44,6 +47,14 @@ interface FLATTEN : NetworkDslItem {
var endDim: Int
}

@NetworkDsl
interface CONV2D : NetworkDslItem {
var outChannels: Int
var kernelSize: Int
var stride: Int
var padding: Int
}


private fun getDefaultName(id: String, s: String, size: Int): String {
if (id.isNotEmpty()) return id
Expand Down Expand Up @@ -128,6 +139,24 @@ class DenseImpl(
}
}

class Conv2dImpl(
private val inChannels: Int,
override var outChannels: Int = 1,
override var kernelSize: Int = 3,
override var stride: Int = 1,
override var padding: Int = 0,
private val id: String
) : CONV2D {
fun create(): Module = Conv2d(
inChannels = inChannels,
outChannels = outChannels,
kernelSize = kernelSize,
stride = stride,
padding = padding,
name = id
)
}

private class NeuralNetworkDslImpl : NeuralNetworkDsl {

val modules = mutableListOf<Module>()
Expand All @@ -147,6 +176,16 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
modules += impl.create()
}

override fun conv2d(id: String, content: CONV2D.() -> Unit) {
val impl = Conv2dImpl(
inChannels = lastDimension,
id = getDefaultName(id, "conv2d", modules.size)
)
impl.content()
lastDimension = impl.outChannels
modules += impl.create()
}

override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) {
val inputDimension = lastDimension
lastDimension = outputDimension
Expand Down
6 changes: 3 additions & 3 deletions core/src/commonMain/kotlin/sk/ai/net/nn/Conv2d.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ class Conv2d(
val kernelSize: Int,
val stride: Int = 1,
val padding: Int = 0,
useBias: Boolean = true
useBias: Boolean = true,
name: String = "Conv2d"
) : Module() {
override val name: String = name
val weight: Tensor
val bias: Tensor?
override val name: String
get() = "Conv2d"
override val modules: List<Module>
get() = emptyList()

Expand Down
5 changes: 5 additions & 0 deletions core/src/commonMain/kotlin/sk/ai/net/nn/Flatten.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ package sk.ai.net.nn

import sk.ai.net.Tensor

/**
* A simple layer that flattens an input tensor into a 1D tensor.
* This layer has no parameters and simply reshapes the input.
*/

class Flatten(
private val startDim: Int = 1,
private val endDim: Int = -1,
Expand Down
26 changes: 26 additions & 0 deletions core/src/commonTest/kotlin/sk/ai/net/nn/FlattenTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,36 @@ package sk.ai.net.nn
import sk.ai.net.Shape
import sk.ai.net.impl.DoublesTensor
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

class FlattenTest {
@Test
fun `flatten 2d tensor`() {
val tensor = DoublesTensor(
Shape(2, 3),
doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
)
val flatten = Flatten()
val result = flatten.forward(tensor) as DoublesTensor
assertEquals(Shape(6), result.shape)
assertContentEquals(doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), result.elements)
}

@Test
fun `flatten 3d tensor`() {
val tensor = DoublesTensor(
Shape(2, 2, 2),
doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)
)
val flatten = Flatten()
val result = flatten.forward(tensor) as DoublesTensor
assertEquals(Shape(8), result.shape)
assertContentEquals(
doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0),
result.elements
)
@Test
fun flatten_basic() {
val flatten = Flatten()
val input = DoublesTensor(Shape(2,1,28,28), DoubleArray(2*1*28*28))
Expand Down
Loading