@@ -3,6 +3,7 @@ package sk.ai.net.dsl
33import sk.ai.net.nn.activations.ActivationsWrapperModule
44import sk.ai.net.Shape
55import sk.ai.net.Tensor
6+ import sk.ai.net.nn.Flatten
67import sk.ai.net.nn.Input
78import sk.ai.net.nn.Linear
89import sk.ai.net.nn.Module
@@ -25,6 +26,8 @@ interface NetworkDslItem
2526interface NeuralNetworkDsl : NetworkDslItem {
2627 fun input (inputSize : Int , id : String = "")
2728
29+ fun flatten (id : String = "", content : FLATTEN .() -> Unit = {})
30+
2831 fun dense (outputDimension : Int , id : String = "", content : DENSE .() -> Unit = {})
2932}
3033
@@ -35,6 +38,13 @@ interface DENSE : NetworkDslItem {
3538 fun bias (initBlock : (Shape ) -> Tensor )
3639}
3740
41+ @NetworkDsl
42+ interface FLATTEN : NetworkDslItem {
43+ var startDim: Int
44+ var endDim: Int
45+ }
46+
47+
3848private fun getDefaultName (id : String , s : String , size : Int ): String {
3949 if (id.isNotEmpty()) return id
4050 return " $s -$size "
@@ -69,6 +79,16 @@ fun createLinear(
6979 }
7080}
7181
82+ class FlattenImpl (
83+ override var startDim : Int = 1 ,
84+ override var endDim : Int = -1 ,
85+ private val id : String
86+ ) : FLATTEN {
87+ fun create (): Module {
88+ return Flatten (startDim, endDim, id)
89+ }
90+ }
91+
7292class DenseImpl (
7393 private val inputDimension : Int , private val outputDimension : Int , private val id : String
7494) : DENSE {
@@ -119,6 +139,14 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
119139 modules.add(Input (Shape (inputSize), name = getDefaultName(id, " Input" , modules.size)))
120140 }
121141
142+ override fun flatten (id : String , content : FLATTEN .() -> Unit ) {
143+ val impl = FlattenImpl (
144+ id = getDefaultName(id, " flatten" , modules.size)
145+ )
146+ impl.content()
147+ modules + = impl.create()
148+ }
149+
122150 override fun dense (outputDimension : Int , id : String , content : DENSE .() -> Unit ) {
123151 val inputDimension = lastDimension
124152 lastDimension = outputDimension
0 commit comments