Skip to content

Commit 9e7003c

Browse files
committed
Add flatten to DSL
Related-To: #23
1 parent db7ff75 commit 9e7003c

3 files changed

Lines changed: 29 additions & 3 deletions

File tree

build.gradle.kts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ plugins {
44
alias(libs.plugins.jetbrainsKotlinJvm) apply false
55
alias(libs.plugins.binaryCompatibility) apply false
66
alias(libs.plugins.modulegraph.souza) apply true
7-
8-
97
}
108

119
allprojects {

core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sk.ai.net.dsl
33
import sk.ai.net.nn.activations.ActivationsWrapperModule
44
import sk.ai.net.Shape
55
import sk.ai.net.Tensor
6+
import sk.ai.net.nn.Flatten
67
import sk.ai.net.nn.Input
78
import sk.ai.net.nn.Linear
89
import sk.ai.net.nn.Module
@@ -25,6 +26,8 @@ interface NetworkDslItem
2526
interface NeuralNetworkDsl : NetworkDslItem {
2627
fun input(inputSize: Int, id: String = "")
2728

29+
fun flatten(id: String = "", content: FLATTEN.() -> Unit = {})
30+
2831
fun dense(outputDimension: Int, id: String = "", content: DENSE.() -> Unit = {})
2932
}
3033

@@ -35,6 +38,13 @@ interface DENSE : NetworkDslItem {
3538
fun bias(initBlock: (Shape) -> Tensor)
3639
}
3740

41+
@NetworkDsl
42+
interface FLATTEN : NetworkDslItem {
43+
var startDim: Int
44+
var endDim: Int
45+
}
46+
47+
3848
private fun getDefaultName(id: String, s: String, size: Int): String {
3949
if (id.isNotEmpty()) return id
4050
return "$s-$size"
@@ -69,6 +79,16 @@ fun createLinear(
6979
}
7080
}
7181

82+
class FlattenImpl(
83+
override var startDim: Int = 1,
84+
override var endDim: Int = -1,
85+
private val id: String
86+
) : FLATTEN {
87+
fun create(): Module {
88+
return Flatten(startDim, endDim, id)
89+
}
90+
}
91+
7292
class DenseImpl(
7393
private val inputDimension: Int, private val outputDimension: Int, private val id: String
7494
) : DENSE {
@@ -119,6 +139,14 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
119139
modules.add(Input(Shape(inputSize), name = getDefaultName(id, "Input", modules.size)))
120140
}
121141

142+
override fun flatten(id: String, content: FLATTEN.() -> Unit) {
143+
val impl = FlattenImpl(
144+
id = getDefaultName(id, "flatten", modules.size)
145+
)
146+
impl.content()
147+
modules += impl.create()
148+
}
149+
122150
override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) {
123151
val inputDimension = lastDimension
124152
lastDimension = outputDimension

settings.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ dependencyResolutionManagement {
1616
rootProject.name = "skainet"
1717
include(":core")
1818
include(":io")
19-
include(":gguf")
19+
include(":gguf")

0 commit comments

Comments
 (0)