Skip to content

Commit 1b5811d

Browse files
skainet-devmichalharakal
authored andcommitted
Refactor conv2d DSL
1 parent 9e7003c commit 1b5811d

2 files changed

Lines changed: 42 additions & 3 deletions

File tree

core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import sk.ai.net.Tensor
66
import sk.ai.net.nn.Flatten
77
import sk.ai.net.nn.Input
88
import sk.ai.net.nn.Linear
9+
import sk.ai.net.nn.Conv2d
910
import sk.ai.net.nn.Module
1011
import sk.ai.net.nn.topology.MLP
1112

@@ -28,6 +29,8 @@ interface NeuralNetworkDsl : NetworkDslItem {
2829

2930
fun flatten(id: String = "", content: FLATTEN.() -> Unit = {})
3031

32+
fun conv2d(id: String = "", content: CONV2D.() -> Unit = {})
33+
3134
fun dense(outputDimension: Int, id: String = "", content: DENSE.() -> Unit = {})
3235
}
3336

@@ -44,6 +47,14 @@ interface FLATTEN : NetworkDslItem {
4447
var endDim: Int
4548
}
4649

50+
@NetworkDsl
51+
interface CONV2D : NetworkDslItem {
52+
var outChannels: Int
53+
var kernelSize: Int
54+
var stride: Int
55+
var padding: Int
56+
}
57+
4758

4859
private fun getDefaultName(id: String, s: String, size: Int): String {
4960
if (id.isNotEmpty()) return id
@@ -128,6 +139,24 @@ class DenseImpl(
128139
}
129140
}
130141

142+
class Conv2dImpl(
143+
private val inChannels: Int,
144+
override var outChannels: Int = 1,
145+
override var kernelSize: Int = 3,
146+
override var stride: Int = 1,
147+
override var padding: Int = 0,
148+
private val id: String
149+
) : CONV2D {
150+
fun create(): Module = Conv2d(
151+
inChannels = inChannels,
152+
outChannels = outChannels,
153+
kernelSize = kernelSize,
154+
stride = stride,
155+
padding = padding,
156+
name = id
157+
)
158+
}
159+
131160
private class NeuralNetworkDslImpl : NeuralNetworkDsl {
132161

133162
val modules = mutableListOf<Module>()
@@ -147,6 +176,16 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
147176
modules += impl.create()
148177
}
149178

179+
override fun conv2d(id: String, content: CONV2D.() -> Unit) {
180+
val impl = Conv2dImpl(
181+
inChannels = lastDimension,
182+
id = getDefaultName(id, "conv2d", modules.size)
183+
)
184+
impl.content()
185+
lastDimension = impl.outChannels
186+
modules += impl.create()
187+
}
188+
150189
override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) {
151190
val inputDimension = lastDimension
152191
lastDimension = outputDimension

core/src/commonMain/kotlin/sk/ai/net/nn/Conv2d.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ class Conv2d(
1313
val kernelSize: Int,
1414
val stride: Int = 1,
1515
val padding: Int = 0,
16-
useBias: Boolean = true
16+
useBias: Boolean = true,
17+
name: String = "Conv2d"
1718
) : Module() {
19+
override val name: String = name
1820
val weight: Tensor
1921
val bias: Tensor?
20-
override val name: String
21-
get() = "Conv2d"
2222
override val modules: List<Module>
2323
get() = emptyList()
2424

0 commit comments

Comments
 (0)