Skip to content

Commit a5a2678

Browse files
committed
Add Dropout layer implementation with DSL support and tests.
Related-To: #5
1 parent cf6849f commit a5a2678

4 files changed

Lines changed: 216 additions & 1 deletion

File tree

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Build Feature Branch Apk
1+
name: Build and test project
22

33
on: [ push, pull_request ]
44

core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sk.ai.net.dsl
33
import sk.ai.net.nn.activations.ActivationsWrapperModule
44
import sk.ai.net.Shape
55
import sk.ai.net.Tensor
6+
import sk.ai.net.nn.Dropout
67
import sk.ai.net.nn.Flatten
78
import sk.ai.net.nn.Input
89
import sk.ai.net.nn.Linear
@@ -34,6 +35,8 @@ interface NeuralNetworkDsl : NetworkDslItem {
3435

3536
fun maxPool2d(id: String = "", content: MAXPOOL2D.() -> Unit = {})
3637

38+
fun dropout(id: String = "", content: DROPOUT.() -> Unit = {})
39+
3740
fun dense(outputDimension: Int, id: String = "", content: DENSE.() -> Unit = {})
3841
}
3942

@@ -64,6 +67,12 @@ interface MAXPOOL2D : NetworkDslItem {
6467
var stride: Int
6568
}
6669

70+
@NetworkDsl
71+
interface DROPOUT : NetworkDslItem {
72+
var p: Double
73+
var inplace: Boolean
74+
}
75+
6776

6877
private fun getDefaultName(id: String, s: String, size: Int): String {
6978
if (id.isNotEmpty()) return id
@@ -178,6 +187,18 @@ class MaxPool2dImpl(
178187
)
179188
}
180189

190+
class DropoutImpl(
191+
override var p: Double = 0.5,
192+
override var inplace: Boolean = false,
193+
private val id: String
194+
) : DROPOUT {
195+
fun create(): Module = Dropout(
196+
p = p,
197+
inplace = inplace,
198+
name = id
199+
)
200+
}
201+
181202
private class NeuralNetworkDslImpl : NeuralNetworkDsl {
182203

183204
val modules = mutableListOf<Module>()
@@ -215,6 +236,14 @@ private class NeuralNetworkDslImpl : NeuralNetworkDsl {
215236
modules += impl.create()
216237
}
217238

239+
override fun dropout(id: String, content: DROPOUT.() -> Unit) {
240+
val impl = DropoutImpl(
241+
id = getDefaultName(id, "dropout", modules.size)
242+
)
243+
impl.content()
244+
modules += impl.create()
245+
}
246+
218247
override fun dense(outputDimension: Int, id: String, content: DENSE.() -> Unit) {
219248
val inputDimension = lastDimension
220249
lastDimension = outputDimension
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package sk.ai.net.nn
2+
3+
import sk.ai.net.Shape
4+
import sk.ai.net.Tensor
5+
import sk.ai.net.impl.DoublesTensor
6+
import kotlin.random.Random
7+
8+
/**
9+
* Dropout layer that randomly zeroes some of the elements of the input tensor with probability p.
10+
*
11+
* During training, each element is zeroed with probability p, and the remaining elements are scaled by 1/(1-p)
12+
* to maintain the same expected sum. During evaluation, this module simply returns the input.
13+
*
14+
* @param p Probability of an element to be zeroed. Default: 0.5
15+
* @param inplace Whether to do the operation in-place. Default: false
16+
* @param name Name of the module. Default: "Dropout"
17+
*/
18+
class Dropout(
19+
val p: Double = 0.5,
20+
val inplace: Boolean = false,
21+
override val name: String = "Dropout"
22+
) : Module() {
23+
init {
24+
require(p in 0.0..1.0) { "Dropout probability has to be between 0 and 1, but got $p" }
25+
}
26+
27+
private var training: Boolean = true
28+
29+
/**
30+
* Sets the module in training mode.
31+
*/
32+
fun train() {
33+
training = true
34+
}
35+
36+
/**
37+
* Sets the module in evaluation mode.
38+
*/
39+
fun eval() {
40+
training = false
41+
}
42+
43+
override val modules: List<Module>
44+
get() = emptyList()
45+
46+
override fun forward(input: Tensor): Tensor {
47+
if (!training || p == 0.0) {
48+
return input
49+
}
50+
51+
if (p == 1.0) {
52+
// If p is 1, drop everything
53+
val zeros = DoubleArray(input.size) { 0.0 }
54+
return DoublesTensor(input.shape, zeros)
55+
}
56+
57+
// Create a binary mask with the same shape as the input
58+
val maskElements = DoubleArray(input.size) {
59+
if (Random.nextDouble() < p) 0.0 else 1.0 / (1.0 - p)
60+
}
61+
val mask = DoublesTensor(input.shape, maskElements)
62+
63+
// Apply the mask to the input
64+
return when (input) {
65+
is DoublesTensor -> input.times(mask as DoublesTensor)
66+
else -> {
67+
// Convert input to DoublesTensor if it's not already
68+
val inputAsDoubles = input as? DoublesTensor
69+
?: throw IllegalArgumentException("Input tensor must be a DoublesTensor")
70+
inputAsDoubles.times(mask as DoublesTensor)
71+
}
72+
}
73+
}
74+
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package sk.ai.net.nn
2+
3+
import sk.ai.net.Shape
4+
import sk.ai.net.impl.DoublesTensor
5+
import kotlin.test.Test
6+
import kotlin.test.assertEquals
7+
import kotlin.test.assertContentEquals
8+
import kotlin.test.assertNotEquals
9+
10+
class DropoutTest {
11+
@Test
12+
fun `dropout in training mode with p=0_5`() {
13+
val tensor = DoublesTensor(
14+
Shape(2, 3),
15+
doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
16+
)
17+
val dropout = Dropout(p = 0.5)
18+
dropout.train() // Ensure training mode
19+
20+
val result = dropout.forward(tensor) as DoublesTensor
21+
22+
// Check that shape is preserved
23+
assertEquals(Shape(2, 3), result.shape)
24+
25+
// Check that some elements are zeroed out (this is probabilistic, but with p=0.5 it's very likely)
26+
// and that non-zero elements are scaled by 1/(1-p) = 2
27+
var hasZeros = false
28+
var hasScaledValues = false
29+
30+
for (i in 0 until tensor.size) {
31+
if (result.elements[i] == 0.0) {
32+
hasZeros = true
33+
} else if (result.elements[i] == tensor.elements[i] * 2.0) {
34+
hasScaledValues = true
35+
}
36+
}
37+
38+
// Assert that we have both zeros and scaled values
39+
// Note: This is a probabilistic test, so there's a very small chance it could fail
40+
// even if the implementation is correct
41+
kotlin.test.assertTrue(hasZeros, "Dropout should zero out some elements")
42+
kotlin.test.assertTrue(hasScaledValues, "Dropout should scale non-zero elements by 1/(1-p)")
43+
}
44+
45+
@Test
46+
fun `dropout in evaluation mode`() {
47+
val tensor = DoublesTensor(
48+
Shape(2, 3),
49+
doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
50+
)
51+
val dropout = Dropout(p = 0.5)
52+
dropout.eval() // Set to evaluation mode
53+
54+
val result = dropout.forward(tensor) as DoublesTensor
55+
56+
// In evaluation mode, dropout should return the input unchanged
57+
assertEquals(Shape(2, 3), result.shape)
58+
assertContentEquals(tensor.elements, result.elements)
59+
}
60+
61+
@Test
62+
fun `dropout with p=0`() {
63+
val tensor = DoublesTensor(
64+
Shape(2, 3),
65+
doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
66+
)
67+
val dropout = Dropout(p = 0.0)
68+
dropout.train() // Ensure training mode
69+
70+
val result = dropout.forward(tensor) as DoublesTensor
71+
72+
// With p=0, dropout should return the input unchanged
73+
assertEquals(Shape(2, 3), result.shape)
74+
assertContentEquals(tensor.elements, result.elements)
75+
}
76+
77+
@Test
78+
fun `dropout with p=1`() {
79+
val tensor = DoublesTensor(
80+
Shape(2, 3),
81+
doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
82+
)
83+
val dropout = Dropout(p = 1.0)
84+
dropout.train() // Ensure training mode
85+
86+
val result = dropout.forward(tensor) as DoublesTensor
87+
88+
// With p=1, dropout should zero out all elements
89+
assertEquals(Shape(2, 3), result.shape)
90+
assertContentEquals(DoubleArray(tensor.size) { 0.0 }, result.elements)
91+
}
92+
93+
@Test
94+
fun `dropout preserves tensor shape`() {
95+
val shapes = listOf(
96+
Shape(1, 10),
97+
Shape(5, 5),
98+
Shape(2, 3, 4),
99+
Shape(1, 2, 3, 4)
100+
)
101+
102+
for (shape in shapes) {
103+
val tensor = DoublesTensor(shape, DoubleArray(shape.volume) { 1.0 })
104+
val dropout = Dropout(p = 0.5)
105+
dropout.train()
106+
107+
val result = dropout.forward(tensor) as DoublesTensor
108+
109+
assertEquals(shape, result.shape, "Dropout should preserve tensor shape")
110+
}
111+
}
112+
}

0 commit comments

Comments
 (0)