@@ -6,6 +6,7 @@ import sk.ai.net.Tensor
66import sk.ai.net.nn.Flatten
77import sk.ai.net.nn.Input
88import sk.ai.net.nn.Linear
9+ import sk.ai.net.nn.Conv2d
910import sk.ai.net.nn.Module
1011import sk.ai.net.nn.topology.MLP
1112
@@ -28,6 +29,8 @@ interface NeuralNetworkDsl : NetworkDslItem {
2829
2930 fun flatten (id : String = "", content : FLATTEN .() -> Unit = {})
3031
32+ fun conv2d (id : String = "", content : CONV2D .() -> Unit = {})
33+
3134 fun dense (outputDimension : Int , id : String = "", content : DENSE .() -> Unit = {})
3235}
3336
@@ -44,6 +47,14 @@ interface FLATTEN : NetworkDslItem {
4447 var endDim: Int
4548}
4649
50+ @NetworkDsl
51+ interface CONV2D : NetworkDslItem {
52+ var outChannels: Int
53+ var kernelSize: Int
54+ var stride: Int
55+ var padding: Int
56+ }
57+
4758
4859private fun getDefaultName (id : String , s : String , size : Int ): String {
4960 if (id.isNotEmpty()) return id
@@ -128,6 +139,24 @@ class DenseImpl(
128139 }
129140}
130141
142+ class Conv2dImpl (
143+ private val inChannels : Int ,
144+ override var outChannels : Int = 1 ,
145+ override var kernelSize : Int = 3 ,
146+ override var stride : Int = 1 ,
147+ override var padding : Int = 0 ,
148+ private val id : String
149+ ) : CONV2D {
150+ fun create (): Module = Conv2d (
151+ inChannels = inChannels,
152+ outChannels = outChannels,
153+ kernelSize = kernelSize,
154+ stride = stride,
155+ padding = padding,
156+ name = id
157+ )
158+ }
159+
131160private class NeuralNetworkDslImpl : NeuralNetworkDsl {
132161
133162 val modules = mutableListOf<Module >()
@@ -147,6 +176,16 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
147176 modules + = impl.create()
148177 }
149178
179+ override fun conv2d (id : String , content : CONV2D .() -> Unit ) {
180+ val impl = Conv2dImpl (
181+ inChannels = lastDimension,
182+ id = getDefaultName(id, " conv2d" , modules.size)
183+ )
184+ impl.content()
185+ lastDimension = impl.outChannels
186+ modules + = impl.create()
187+ }
188+
150189 override fun dense (outputDimension : Int , id : String , content : DENSE .() -> Unit ) {
151190 val inputDimension = lastDimension
152191 lastDimension = outputDimension
0 commit comments