Skip to content

Commit eb6c089

Browse files
Merge pull request #30 from skainet-dev/feature/MNIST
maxpool2d and kdoc
2 parents 1c8b362 + 9c45ef7 commit eb6c089

10 files changed

Lines changed: 198 additions & 4 deletions

File tree

KDOC_PREPROCESSING.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# KDoc preprocessing
2+
3+
This project follows the same approach as [Kotlin/DataFrame](https://github.com/Kotlin/dataframe/blob/master/KDOC_PREPROCESSING.md) for preprocessing KDoc comments before generating documentation with Dokka.
4+
5+
The Dokka plugin is applied to all modules. Run `./gradlew dokkaHtml` to generate HTML documentation under each module's `build/dokka` directory.

build.gradle.kts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
import org.jetbrains.dokka.gradle.DokkaMultiModuleTask
2+
13
plugins {
24
alias(libs.plugins.androidLibrary) apply false
35
alias(libs.plugins.kotlinMultiplatform) apply false
46
alias(libs.plugins.jetbrainsKotlinJvm) apply false
57
alias(libs.plugins.binaryCompatibility) apply false
8+
alias(libs.plugins.dokka) apply false
69
alias(libs.plugins.modulegraph.souza) apply true
710
}
811

12+
apply(plugin = "org.jetbrains.dokka")
13+
914
allprojects {
1015
group = "sk.ai.net"
1116
version = "0.0.6-SNAPSHOT"
@@ -15,3 +20,7 @@ moduleGraphConfig {
1520
readmePath.set("./Modules.md")
1621
heading = "### Module Graph"
1722
}
23+
24+
tasks.register<org.jetbrains.dokka.gradle.DokkaMultiModuleTask>("dokkaHtmlMultiModule") {
25+
outputDirectory.set(buildDir.resolve("dokka"))
26+
}

core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ plugins {
88
alias(libs.plugins.kotlinMultiplatform)
99
alias(libs.plugins.androidLibrary)
1010
alias(libs.plugins.binaryCompatibility)
11+
alias(libs.plugins.dokka)
1112
alias(libs.plugins.vanniktech.mavenPublish)
1213
}
1314

core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import sk.ai.net.nn.Flatten
77
import sk.ai.net.nn.Input
88
import sk.ai.net.nn.Linear
99
import sk.ai.net.nn.Conv2d
10+
import sk.ai.net.nn.MaxPool2d
1011
import sk.ai.net.nn.Module
1112
import sk.ai.net.nn.topology.MLP
1213

@@ -31,6 +32,8 @@ interface NeuralNetworkDsl : NetworkDslItem {
3132

3233
fun conv2d(id: String = "", content: CONV2D.() -> Unit = {})
3334

35+
fun maxPool2d(id: String = "", content: MAXPOOL2D.() -> Unit = {})
36+
3437
fun dense(outputDimension: Int, id: String = "", content: DENSE.() -> Unit = {})
3538
}
3639

@@ -55,6 +58,12 @@ interface CONV2D : NetworkDslItem {
5558
var padding: Int
5659
}
5760

61+
@NetworkDsl
62+
interface MAXPOOL2D : NetworkDslItem {
63+
var kernelSize: Int
64+
var stride: Int
65+
}
66+
5867

5968
private fun getDefaultName(id: String, s: String, size: Int): String {
6069
if (id.isNotEmpty()) return id
@@ -157,6 +166,18 @@ class Conv2dImpl(
157166
)
158167
}
159168

169+
class MaxPool2dImpl(
170+
override var kernelSize: Int = 2,
171+
override var stride: Int = 2,
172+
private val id: String
173+
) : MAXPOOL2D {
174+
fun create(): Module = MaxPool2d(
175+
kernelSize = kernelSize,
176+
stride = stride,
177+
name = id
178+
)
179+
}
180+
160181
private class NeuralNetworkDslImpl : NeuralNetworkDsl {
161182

162183
val modules = mutableListOf<Module>()
@@ -186,6 +207,14 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
186207
modules += impl.create()
187208
}
188209

210+
override fun maxPool2d(id: String, content: MAXPOOL2D.() -> Unit) {
211+
val impl = MaxPool2dImpl(
212+
id = getDefaultName(id, "maxPool2d", modules.size)
213+
)
214+
impl.content()
215+
modules += impl.create()
216+
}
217+
189218
override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) {
190219
val inputDimension = lastDimension
191220
lastDimension = outputDimension
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package sk.ai.net.nn
2+
3+
import sk.ai.net.Shape
4+
import sk.ai.net.Tensor
5+
import sk.ai.net.impl.DoublesTensor
6+
7+
/**
8+
* 2D max pooling layer.
9+
* Works with tensors of shape (N, C, H, W) or (C, H, W).
10+
*/
11+
class MaxPool2d(
12+
val kernelSize: Int,
13+
val stride: Int = kernelSize,
14+
override val name: String = "MaxPool2d"
15+
) : Module() {
16+
override val modules: List<Module>
17+
get() = emptyList()
18+
19+
override fun forward(input: Tensor): Tensor = maxPool2d(input)
20+
21+
private fun maxPool2d(input: Tensor): Tensor {
22+
val tensor = input as DoublesTensor
23+
val shape = tensor.shape
24+
require(shape.rank == 3 || shape.rank == 4) {
25+
"MaxPool2d expected 3D or 4D input tensor, but got shape $shape"
26+
}
27+
val batchSize: Int
28+
val channels: Int
29+
val height: Int
30+
val width: Int
31+
if (shape.rank == 4) {
32+
batchSize = shape[0]
33+
channels = shape[1]
34+
height = shape[2]
35+
width = shape[3]
36+
} else {
37+
batchSize = 1
38+
channels = shape[0]
39+
height = shape[1]
40+
width = shape[2]
41+
}
42+
43+
val outH = (height - kernelSize) / stride + 1
44+
val outW = (width - kernelSize) / stride + 1
45+
val outElements = DoubleArray(batchSize * channels * outH * outW)
46+
var idx = 0
47+
for (n in 0 until batchSize) {
48+
for (c in 0 until channels) {
49+
for (i in 0 until outH) {
50+
for (j in 0 until outW) {
51+
var maxVal = Double.NEGATIVE_INFINITY
52+
for (ki in 0 until kernelSize) {
53+
for (kj in 0 until kernelSize) {
54+
val h = i * stride + ki
55+
val w = j * stride + kj
56+
val value = if (shape.rank == 4) {
57+
tensor[n, c, h, w]
58+
} else {
59+
tensor[c, h, w]
60+
}
61+
if (value > maxVal) maxVal = value
62+
}
63+
}
64+
outElements[idx++] = maxVal
65+
}
66+
}
67+
}
68+
}
69+
val outShape = Shape(batchSize, channels, outH, outW)
70+
return DoublesTensor(outShape, outElements)
71+
}
72+
}

core/src/commonTest/kotlin/sk/ai/net/ShapeTest.kt

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,43 @@ package sk.ai.net
22

33
import kotlin.test.Test
44
import kotlin.test.assertEquals
5+
import kotlin.test.assertFalse
6+
import kotlin.test.assertContentEquals
57

68
class ShapeTest {
79

810
@Test
9-
fun `test scalar`() {
10-
val shape = Shape(0)
11-
assertEquals(shape, Shape(1, 2))
11+
fun volumeIsProductOfDimensions() {
12+
val shape = Shape(2, 3, 4)
13+
assertEquals(24, shape.volume)
1214
}
13-
}
15+
16+
@Test
17+
fun rankIsNumberOfDimensions() {
18+
val shape = Shape(2, 3)
19+
assertEquals(2, shape.rank)
20+
}
21+
22+
@Test
23+
fun equalityChecksAllDimensions() {
24+
val shape1 = Shape(2, 3)
25+
val shape2 = Shape(2, 3)
26+
val shape3 = Shape(3, 2)
27+
assertEquals(shape1, shape2)
28+
assertFalse(shape1 == shape3)
29+
}
30+
31+
@Test
32+
fun constructorCopiesDimensions() {
33+
val dims = intArrayOf(2, 3)
34+
val shape = Shape(*dims)
35+
dims[0] = 5
36+
assertContentEquals(intArrayOf(2, 3), shape.dimensions)
37+
}
38+
39+
@Test
40+
fun toStringContainsSizeInformation() {
41+
val shape = Shape(2, 3)
42+
assertEquals("Shape: Dimensions = [2 x 3], Size (Volume) = 6", shape.toString())
43+
}
44+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package sk.ai.net.nn
2+
3+
import sk.ai.net.Shape
4+
import sk.ai.net.dsl.network
5+
import sk.ai.net.impl.DoublesTensor
6+
import kotlin.test.Test
7+
import kotlin.test.assertContentEquals
8+
import kotlin.test.assertEquals
9+
import kotlin.test.assertTrue
10+
11+
class MaxPool2dTest {
12+
@Test
13+
fun max_pool2d_basic() {
14+
val input = DoublesTensor(
15+
Shape(1, 1, 4, 4),
16+
doubleArrayOf(
17+
1.0, 2.0, 3.0, 4.0,
18+
5.0, 6.0, 7.0, 8.0,
19+
9.0, 10.0, 11.0, 12.0,
20+
13.0, 14.0, 15.0, 16.0
21+
)
22+
)
23+
val pool = MaxPool2d(kernelSize = 2, stride = 2)
24+
val result = pool.forward(input) as DoublesTensor
25+
assertEquals(Shape(1, 1, 2, 2), result.shape)
26+
assertContentEquals(doubleArrayOf(6.0, 8.0, 14.0, 16.0), result.elements)
27+
}
28+
29+
@Test
30+
fun dsl_support() {
31+
val module = network {
32+
input(1)
33+
maxPool2d {
34+
kernelSize = 2
35+
stride = 2
36+
}
37+
}
38+
val mlp = module as sk.ai.net.nn.topology.MLP
39+
assertTrue(mlp.modules[1] is MaxPool2d)
40+
val mp = mlp.modules[1] as MaxPool2d
41+
assertEquals(2, mp.kernelSize)
42+
assertEquals(2, mp.stride)
43+
}
44+
}

gguf/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.jetbrains.kotlin.gradle.dsl.JvmTarget
44
plugins {
55
alias(libs.plugins.kotlinMultiplatform)
66
alias(libs.plugins.androidLibrary)
7+
alias(libs.plugins.dokka)
78
alias(libs.plugins.vanniktech.mavenPublish)
89
}
910

gradle/libs.versions.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ kotlinSerialization = { id = "org.jetbrains.kotlin.plugin.serialization", versio
2626
binaryCompatibility = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibility" }
2727
modulegraph-souza = { id = "dev.iurysouza.modulegraph", version.ref = "moduleGraphSouza" }
2828
vanniktech-mavenPublish = { id = "com.vanniktech.maven.publish", version = "0.32.0" }
29+
dokka = { id = "org.jetbrains.dokka", version = "1.9.20" }

io/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ plugins {
55
alias(libs.plugins.kotlinMultiplatform)
66
alias(libs.plugins.androidLibrary)
77
alias(libs.plugins.kotlinSerialization)
8+
alias(libs.plugins.dokka)
89

910
alias(libs.plugins.vanniktech.mavenPublish)
1011
}

0 commit comments

Comments
 (0)