Skip to content

Commit 982b697

Browse files
authored
Merge pull request #6 from skainet-dev/codex/add-maxpool2d-support
Add MaxPool2d layer and DSL
2 parents f128ed1 + 4228e25 commit 982b697

3 files changed

Lines changed: 145 additions & 0 deletions

File tree

core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import sk.ai.net.nn.Flatten
77
import sk.ai.net.nn.Input
88
import sk.ai.net.nn.Linear
99
import sk.ai.net.nn.Conv2d
10+
import sk.ai.net.nn.MaxPool2d
1011
import sk.ai.net.nn.Module
1112
import sk.ai.net.nn.topology.MLP
1213

@@ -31,6 +32,8 @@ interface NeuralNetworkDsl : NetworkDslItem {
3132

3233
fun conv2d(id: String = "", content: CONV2D.() -> Unit = {})
3334

35+
fun maxPool2d(id: String = "", content: MAXPOOL2D.() -> Unit = {})
36+
3437
fun dense(outputDimension: Int, id: String = "", content: DENSE.() -> Unit = {})
3538
}
3639

@@ -55,6 +58,12 @@ interface CONV2D : NetworkDslItem {
5558
var padding: Int
5659
}
5760

61+
@NetworkDsl
62+
interface MAXPOOL2D : NetworkDslItem {
63+
var kernelSize: Int
64+
var stride: Int
65+
}
66+
5867

5968
private fun getDefaultName(id: String, s: String, size: Int): String {
6069
if (id.isNotEmpty()) return id
@@ -157,6 +166,18 @@ class Conv2dImpl(
157166
)
158167
}
159168

169+
class MaxPool2dImpl(
170+
override var kernelSize: Int = 2,
171+
override var stride: Int = 2,
172+
private val id: String
173+
) : MAXPOOL2D {
174+
fun create(): Module = MaxPool2d(
175+
kernelSize = kernelSize,
176+
stride = stride,
177+
name = id
178+
)
179+
}
180+
160181
private class NeuralNetworkDslImpl : NeuralNetworkDsl {
161182

162183
val modules = mutableListOf<Module>()
@@ -186,6 +207,14 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
186207
modules += impl.create()
187208
}
188209

210+
override fun maxPool2d(id: String, content: MAXPOOL2D.() -> Unit) {
211+
val impl = MaxPool2dImpl(
212+
id = getDefaultName(id, "maxPool2d", modules.size)
213+
)
214+
impl.content()
215+
modules += impl.create()
216+
}
217+
189218
override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) {
190219
val inputDimension = lastDimension
191220
lastDimension = outputDimension
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package sk.ai.net.nn
2+
3+
import sk.ai.net.Shape
4+
import sk.ai.net.Tensor
5+
import sk.ai.net.impl.DoublesTensor
6+
7+
/**
8+
* 2D max pooling layer.
9+
* Works with tensors of shape (N, C, H, W) or (C, H, W).
10+
*/
11+
class MaxPool2d(
12+
val kernelSize: Int,
13+
val stride: Int = kernelSize,
14+
override val name: String = "MaxPool2d"
15+
) : Module() {
16+
override val modules: List<Module>
17+
get() = emptyList()
18+
19+
override fun forward(input: Tensor): Tensor = maxPool2d(input)
20+
21+
private fun maxPool2d(input: Tensor): Tensor {
22+
val tensor = input as DoublesTensor
23+
val shape = tensor.shape
24+
require(shape.rank == 3 || shape.rank == 4) {
25+
"MaxPool2d expected 3D or 4D input tensor, but got shape $shape"
26+
}
27+
val batchSize: Int
28+
val channels: Int
29+
val height: Int
30+
val width: Int
31+
if (shape.rank == 4) {
32+
batchSize = shape[0]
33+
channels = shape[1]
34+
height = shape[2]
35+
width = shape[3]
36+
} else {
37+
batchSize = 1
38+
channels = shape[0]
39+
height = shape[1]
40+
width = shape[2]
41+
}
42+
43+
val outH = (height - kernelSize) / stride + 1
44+
val outW = (width - kernelSize) / stride + 1
45+
val outElements = DoubleArray(batchSize * channels * outH * outW)
46+
var idx = 0
47+
for (n in 0 until batchSize) {
48+
for (c in 0 until channels) {
49+
for (i in 0 until outH) {
50+
for (j in 0 until outW) {
51+
var maxVal = Double.NEGATIVE_INFINITY
52+
for (ki in 0 until kernelSize) {
53+
for (kj in 0 until kernelSize) {
54+
val h = i * stride + ki
55+
val w = j * stride + kj
56+
val value = if (shape.rank == 4) {
57+
tensor[n, c, h, w]
58+
} else {
59+
tensor[c, h, w]
60+
}
61+
if (value > maxVal) maxVal = value
62+
}
63+
}
64+
outElements[idx++] = maxVal
65+
}
66+
}
67+
}
68+
}
69+
val outShape = Shape(batchSize, channels, outH, outW)
70+
return DoublesTensor(outShape, outElements)
71+
}
72+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package sk.ai.net.nn
2+
3+
import sk.ai.net.Shape
4+
import sk.ai.net.dsl.network
5+
import sk.ai.net.impl.DoublesTensor
6+
import kotlin.test.Test
7+
import kotlin.test.assertContentEquals
8+
import kotlin.test.assertEquals
9+
import kotlin.test.assertTrue
10+
11+
class MaxPool2dTest {
12+
@Test
13+
fun max_pool2d_basic() {
14+
val input = DoublesTensor(
15+
Shape(1, 1, 4, 4),
16+
doubleArrayOf(
17+
1.0, 2.0, 3.0, 4.0,
18+
5.0, 6.0, 7.0, 8.0,
19+
9.0, 10.0, 11.0, 12.0,
20+
13.0, 14.0, 15.0, 16.0
21+
)
22+
)
23+
val pool = MaxPool2d(kernelSize = 2, stride = 2)
24+
val result = pool.forward(input) as DoublesTensor
25+
assertEquals(Shape(1, 1, 2, 2), result.shape)
26+
assertContentEquals(doubleArrayOf(6.0, 8.0, 14.0, 16.0), result.elements)
27+
}
28+
29+
@Test
30+
fun dsl_support() {
31+
val module = network {
32+
input(1)
33+
maxPool2d {
34+
kernelSize = 2
35+
stride = 2
36+
}
37+
}
38+
val mlp = module as sk.ai.net.nn.topology.MLP
39+
assertTrue(mlp.modules[1] is MaxPool2d)
40+
val mp = mlp.modules[1] as MaxPool2d
41+
assertEquals(2, mp.kernelSize)
42+
assertEquals(2, mp.stride)
43+
}
44+
}

0 commit comments

Comments
 (0)