Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions KDOC_PREPROCESSING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# KDoc preprocessing

This project follows the same approach as [Kotlin/DataFrame](https://github.com/Kotlin/dataframe/blob/master/KDOC_PREPROCESSING.md) for preprocessing KDoc comments before generating documentation with Dokka.

The Dokka plugin is applied to all modules. Run `./gradlew dokkaHtml` to generate HTML documentation under each module's `build/dokka` directory.
9 changes: 9 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import org.jetbrains.dokka.gradle.DokkaMultiModuleTask

plugins {
alias(libs.plugins.androidLibrary) apply false
alias(libs.plugins.kotlinMultiplatform) apply false
alias(libs.plugins.jetbrainsKotlinJvm) apply false
alias(libs.plugins.binaryCompatibility) apply false
alias(libs.plugins.dokka) apply false
alias(libs.plugins.modulegraph.souza) apply true
}

apply(plugin = "org.jetbrains.dokka")

allprojects {
group = "sk.ai.net"
version = "0.0.6-SNAPSHOT"
Expand All @@ -15,3 +20,7 @@ moduleGraphConfig {
readmePath.set("./Modules.md")
heading = "### Module Graph"
}

tasks.register<org.jetbrains.dokka.gradle.DokkaMultiModuleTask>("dokkaHtmlMultiModule") {
outputDirectory.set(buildDir.resolve("dokka"))
}
1 change: 1 addition & 0 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ plugins {
alias(libs.plugins.kotlinMultiplatform)
alias(libs.plugins.androidLibrary)
alias(libs.plugins.binaryCompatibility)
alias(libs.plugins.dokka)
alias(libs.plugins.vanniktech.mavenPublish)
}

Expand Down
29 changes: 29 additions & 0 deletions core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import sk.ai.net.nn.Flatten
import sk.ai.net.nn.Input
import sk.ai.net.nn.Linear
import sk.ai.net.nn.Conv2d
import sk.ai.net.nn.MaxPool2d
import sk.ai.net.nn.Module
import sk.ai.net.nn.topology.MLP

Expand All @@ -31,6 +32,8 @@ interface NeuralNetworkDsl : NetworkDslItem {

fun conv2d(id: String = "", content: CONV2D.() -> Unit = {})

fun maxPool2d(id: String = "", content: MAXPOOL2D.() -> Unit = {})

fun dense(outputDimension: Int, id: String = "", content: DENSE.() -> Unit = {})
}

Expand All @@ -55,6 +58,12 @@ interface CONV2D : NetworkDslItem {
var padding: Int
}

@NetworkDsl
interface MAXPOOL2D : NetworkDslItem {
var kernelSize: Int
var stride: Int
}


private fun getDefaultName(id: String, s: String, size: Int): String {
if (id.isNotEmpty()) return id
Expand Down Expand Up @@ -157,6 +166,18 @@ class Conv2dImpl(
)
}

class MaxPool2dImpl(
override var kernelSize: Int = 2,
override var stride: Int = 2,
private val id: String
) : MAXPOOL2D {
fun create(): Module = MaxPool2d(
kernelSize = kernelSize,
stride = stride,
name = id
)
}

private class NeuralNetworkDslImpl : NeuralNetworkDsl {

val modules = mutableListOf<Module>()
Expand Down Expand Up @@ -186,6 +207,14 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
modules += impl.create()
}

override fun maxPool2d(id: String, content: MAXPOOL2D.() -> Unit) {
val impl = MaxPool2dImpl(
id = getDefaultName(id, "maxPool2d", modules.size)
)
impl.content()
modules += impl.create()
}

override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) {
val inputDimension = lastDimension
lastDimension = outputDimension
Expand Down
72 changes: 72 additions & 0 deletions core/src/commonMain/kotlin/sk/ai/net/nn/MaxPool2d.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package sk.ai.net.nn

import sk.ai.net.Shape
import sk.ai.net.Tensor
import sk.ai.net.impl.DoublesTensor

/**
* 2D max pooling layer.
* Works with tensors of shape (N, C, H, W) or (C, H, W).
*/
class MaxPool2d(
val kernelSize: Int,
val stride: Int = kernelSize,
override val name: String = "MaxPool2d"
) : Module() {
override val modules: List<Module>
get() = emptyList()

override fun forward(input: Tensor): Tensor = maxPool2d(input)

private fun maxPool2d(input: Tensor): Tensor {
val tensor = input as DoublesTensor
val shape = tensor.shape
require(shape.rank == 3 || shape.rank == 4) {
"MaxPool2d expected 3D or 4D input tensor, but got shape $shape"
}
val batchSize: Int
val channels: Int
val height: Int
val width: Int
if (shape.rank == 4) {
batchSize = shape[0]
channels = shape[1]
height = shape[2]
width = shape[3]
} else {
batchSize = 1
channels = shape[0]
height = shape[1]
width = shape[2]
}

val outH = (height - kernelSize) / stride + 1
val outW = (width - kernelSize) / stride + 1
val outElements = DoubleArray(batchSize * channels * outH * outW)
var idx = 0
for (n in 0 until batchSize) {
for (c in 0 until channels) {
for (i in 0 until outH) {
for (j in 0 until outW) {
var maxVal = Double.NEGATIVE_INFINITY
for (ki in 0 until kernelSize) {
for (kj in 0 until kernelSize) {
val h = i * stride + ki
val w = j * stride + kj
val value = if (shape.rank == 4) {
tensor[n, c, h, w]
} else {
tensor[c, h, w]
}
if (value > maxVal) maxVal = value
}
}
outElements[idx++] = maxVal
}
}
}
}
val outShape = Shape(batchSize, channels, outH, outW)
return DoublesTensor(outShape, outElements)
}
}
39 changes: 35 additions & 4 deletions core/src/commonTest/kotlin/sk/ai/net/ShapeTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,43 @@ package sk.ai.net

import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertContentEquals

class ShapeTest {

@Test
fun `test scalar`() {
val shape = Shape(0)
assertEquals(shape, Shape(1, 2))
fun volumeIsProductOfDimensions() {
val shape = Shape(2, 3, 4)
assertEquals(24, shape.volume)
}
}

@Test
fun rankIsNumberOfDimensions() {
val shape = Shape(2, 3)
assertEquals(2, shape.rank)
}

@Test
fun equalityChecksAllDimensions() {
val shape1 = Shape(2, 3)
val shape2 = Shape(2, 3)
val shape3 = Shape(3, 2)
assertEquals(shape1, shape2)
assertFalse(shape1 == shape3)
}

@Test
fun constructorCopiesDimensions() {
val dims = intArrayOf(2, 3)
val shape = Shape(*dims)
dims[0] = 5
assertContentEquals(intArrayOf(2, 3), shape.dimensions)
}

@Test
fun toStringContainsSizeInformation() {
val shape = Shape(2, 3)
assertEquals("Shape: Dimensions = [2 x 3], Size (Volume) = 6", shape.toString())
}
}
44 changes: 44 additions & 0 deletions core/src/commonTest/kotlin/sk/ai/net/nn/MaxPool2dTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package sk.ai.net.nn

import sk.ai.net.Shape
import sk.ai.net.dsl.network
import sk.ai.net.impl.DoublesTensor
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertTrue

class MaxPool2dTest {
@Test
fun max_pool2d_basic() {
val input = DoublesTensor(
Shape(1, 1, 4, 4),
doubleArrayOf(
1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0
)
)
val pool = MaxPool2d(kernelSize = 2, stride = 2)
val result = pool.forward(input) as DoublesTensor
assertEquals(Shape(1, 1, 2, 2), result.shape)
assertContentEquals(doubleArrayOf(6.0, 8.0, 14.0, 16.0), result.elements)
}

@Test
fun dsl_support() {
val module = network {
input(1)
maxPool2d {
kernelSize = 2
stride = 2
}
}
val mlp = module as sk.ai.net.nn.topology.MLP
assertTrue(mlp.modules[1] is MaxPool2d)
val mp = mlp.modules[1] as MaxPool2d
assertEquals(2, mp.kernelSize)
assertEquals(2, mp.stride)
}
}
1 change: 1 addition & 0 deletions gguf/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.jetbrains.kotlin.gradle.dsl.JvmTarget
plugins {
alias(libs.plugins.kotlinMultiplatform)
alias(libs.plugins.androidLibrary)
alias(libs.plugins.dokka)
alias(libs.plugins.vanniktech.mavenPublish)
}

Expand Down
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ kotlinSerialization = { id = "org.jetbrains.kotlin.plugin.serialization", versio
binaryCompatibility = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibility" }
modulegraph-souza = { id = "dev.iurysouza.modulegraph", version.ref = "moduleGraphSouza" }
vanniktech-mavenPublish = { id = "com.vanniktech.maven.publish", version = "0.32.0" }
dokka = { id = "org.jetbrains.dokka", version = "1.9.20" }
1 change: 1 addition & 0 deletions io/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ plugins {
alias(libs.plugins.kotlinMultiplatform)
alias(libs.plugins.androidLibrary)
alias(libs.plugins.kotlinSerialization)
alias(libs.plugins.dokka)

alias(libs.plugins.vanniktech.mavenPublish)
}
Expand Down
Loading