@@ -7,6 +7,7 @@ import sk.ai.net.nn.Flatten
77import sk.ai.net.nn.Input
88import sk.ai.net.nn.Linear
99import sk.ai.net.nn.Conv2d
10+ import sk.ai.net.nn.MaxPool2d
1011import sk.ai.net.nn.Module
1112import sk.ai.net.nn.topology.MLP
1213
@@ -31,6 +32,8 @@ interface NeuralNetworkDsl : NetworkDslItem {
3132
3233 fun conv2d (id : String = "", content : CONV2D .() -> Unit = {})
3334
35+ fun maxPool2d (id : String = "", content : MAXPOOL2D .() -> Unit = {})
36+
3437 fun dense (outputDimension : Int , id : String = "", content : DENSE .() -> Unit = {})
3538}
3639
@@ -55,6 +58,12 @@ interface CONV2D : NetworkDslItem {
5558 var padding: Int
5659}
5760
61+ @NetworkDsl
62+ interface MAXPOOL2D : NetworkDslItem {
63+ var kernelSize: Int
64+ var stride: Int
65+ }
66+
5867
5968private fun getDefaultName (id : String , s : String , size : Int ): String {
6069 if (id.isNotEmpty()) return id
@@ -157,6 +166,18 @@ class Conv2dImpl(
157166 )
158167}
159168
169+ class MaxPool2dImpl (
170+ override var kernelSize : Int = 2 ,
171+ override var stride : Int = 2 ,
172+ private val id : String
173+ ) : MAXPOOL2D {
174+ fun create (): Module = MaxPool2d (
175+ kernelSize = kernelSize,
176+ stride = stride,
177+ name = id
178+ )
179+ }
180+
160181private class NeuralNetworkDslImpl : NeuralNetworkDsl {
161182
162183 val modules = mutableListOf<Module >()
@@ -186,6 +207,14 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
186207 modules + = impl.create()
187208 }
188209
210+ override fun maxPool2d (id : String , content : MAXPOOL2D .() -> Unit ) {
211+ val impl = MaxPool2dImpl (
212+ id = getDefaultName(id, " maxPool2d" , modules.size)
213+ )
214+ impl.content()
215+ modules + = impl.create()
216+ }
217+
189218 override fun dense (outputDimension : Int , id : String , content : DENSE .() -> Unit ) {
190219 val inputDimension = lastDimension
191220 lastDimension = outputDimension
0 commit comments