Skip to content

Commit a5de82f

Browse files
committed
Add initial implementation of Conv2d. Not working yet
Related-To: #4
1 parent 09fa0ef commit a5de82f

8 files changed

Lines changed: 190 additions & 17 deletions

File tree

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ plugins {
1010

1111
allprojects {
1212
group = "sk.ai.net"
13-
version = "0.0.5"
13+
version = "0.0.6-SNAPSHOT"
1414
}
1515

1616
moduleGraphConfig {

core/src/commonMain/kotlin/sk/ai/net/Shape.kt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package sk.ai.net
22

3+
import sk.ai.net.impl.assert
34
import sk.ai.net.impl.zipFold
45

56
class Shape(vararg dimensions: Int) {
@@ -11,6 +12,20 @@ class Shape(vararg dimensions: Int) {
1112
val rank: Int
1213
get() = dimensions.size
1314

15+
internal fun index(indices: IntArray): Int {
16+
assert(
17+
{ indices.size == dimensions.size },
18+
{ "`indices.size` must be ${dimensions.size}: ${indices.size}" })
19+
return dimensions.zip(indices).fold(0) { a, x ->
20+
assert({ 0 <= x.second && x.second < x.first }, { "Illegal index: indices = ${indices}, shape = $shape" })
21+
a * x.first + x.second
22+
}
23+
}
24+
25+
operator fun get(vararg indices: Int): Int {
26+
return dimensions[index(indices)]
27+
}
28+
1429
override fun equals(other: Any?): Boolean {
1530
if (other !is Shape) {
1631
return false

core/src/commonMain/kotlin/sk/ai/net/Tensor.kt

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,30 +63,29 @@ interface Tensor {
6363
fun cos(): Tensor
6464

6565
fun tan(): Tensor
66-
66+
6767
fun asin(): Tensor
6868

6969
fun acos(): Tensor
70-
71-
fun atan(): Tensor
7270

73-
fun sinh():Tensor
71+
fun atan(): Tensor
7472

75-
fun cosh():Tensor
73+
fun sinh(): Tensor
7674

77-
fun tanh():Tensor
75+
fun cosh(): Tensor
7876

79-
fun exp():Tensor
77+
fun tanh(): Tensor
8078

81-
fun log():Tensor
79+
fun exp(): Tensor
8280

83-
fun sqrt():Tensor
81+
fun log(): Tensor
8482

85-
fun cbrt():Tensor
83+
fun sqrt(): Tensor
8684

87-
fun sigmoid():Tensor
85+
fun cbrt(): Tensor
8886

89-
fun ln():Tensor
87+
fun sigmoid(): Tensor
9088

89+
fun ln(): Tensor
9190
}
9291

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
package sk.ai.net
22

3+
import sk.ai.net.impl.BuiltInDoubleDataDescriptor
4+
import sk.ai.net.impl.DoublesTensor
5+
import kotlin.random.Random
6+
37
interface TensorFactory {
48
fun createTensor(shape: Shape, dataDescriptor: DataDescriptor, elements: DoubleArray): Tensor
5-
}
9+
}
10+
11+
fun rand(shape: Shape, dataDescriptor: DataDescriptor = BuiltInDoubleDataDescriptor()): Tensor {
12+
val random: Random = Random.Default
13+
14+
return DoublesTensor(shape, DoubleArray(shape.volume) { random.nextFloat().toDouble() })
15+
}

core/src/commonMain/kotlin/sk/ai/net/impl/DoublesTensor.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import sk.ai.net.Tensor
66
import kotlin.collections.map
77
import kotlin.math.exp
88
import kotlin.math.pow
9+
import kotlin.random.Random
910

1011
data class DoublesTensor(override val shape: Shape, val elements: DoubleArray) : TypedTensor<Double> {
1112
constructor(shape: Shape, element: Double = 0.0) : this(
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package sk.ai.net.nn
2+
3+
import sk.ai.net.Shape
4+
import sk.ai.net.Tensor
5+
import sk.ai.net.impl.DoublesTensor
6+
import sk.ai.net.rand
7+
import kotlin.math.sqrt
8+
9+
class Conv2d(
10+
val inChannels: Int,
11+
val outChannels: Int,
12+
val kernelSize: Int,
13+
val stride: Int = 1,
14+
val padding: Int = 0,
15+
useBias: Boolean = true
16+
) {
17+
val weight: Tensor
18+
val bias: Tensor?
19+
20+
init {
21+
// Initialize weights and bias
22+
val fanIn = inChannels * kernelSize * kernelSize
23+
val bound = 1f / sqrt(fanIn.toDouble()).toFloat() // 1/sqrt(fanIn)
24+
// Weight: uniform in [-bound, bound]
25+
weight = (((rand(
26+
Shape(
27+
outChannels,
28+
inChannels,
29+
kernelSize,
30+
kernelSize
31+
)
32+
) as DoublesTensor) * (2f * bound).toDouble()) as DoublesTensor) - bound.toDouble()
33+
// Bias: uniform in [-bound, bound] if enabled
34+
bias = if (useBias) {
35+
((rand(Shape(outChannels)) as DoublesTensor) * (2f * bound).toDouble()) - bound.toDouble()
36+
} else {
37+
null
38+
}
39+
}
40+
41+
operator fun invoke(input: Tensor): Tensor {
42+
// Ensure input has 3D or 4D shape
43+
val shape = input.shape // assume shape is a list or array of dimensions
44+
require(shape.rank == 3 || shape.rank == 4) {
45+
"Conv2d expected 3D or 4D input tensor, but got shape ${shape}."
46+
}
47+
// Determine batch size and input dims
48+
val batchSize: Int
49+
val inC: Int
50+
val inH: Int
51+
val inW: Int
52+
if (shape.rank == 4) {
53+
batchSize = shape.dimensions[0]
54+
inC = shape[1]
55+
inH = shape[2]
56+
inW = shape[3]
57+
} else {
58+
// if 3D (C, H, W), treat as batch of size 1
59+
batchSize = 1
60+
inC = shape[0]
61+
inH = shape[1]
62+
inW = shape[2]
63+
}
64+
require(inC == inChannels) {
65+
"Conv2d expected input channel count $inChannels, but got $inC."
66+
}
67+
68+
// Compute output spatial dimensions
69+
val outH = (inH + 2 * padding - kernelSize) / stride + 1
70+
val outW = (inW + 2 * padding - kernelSize) / stride + 1
71+
require(outH > 0 && outW > 0) {
72+
"Conv2d output size is invalid (outH=$outH, outW=$outW). Check input dimensions and padding."
73+
}
74+
75+
// Apply padding if needed
76+
val paddedInput: Tensor = if (padding > 0) {
77+
val paddedH = inH + 2 * padding
78+
val paddedW = inW + 2 * padding
79+
val temp = Tensor.zeros(batchSize, inC, paddedH, paddedW)
80+
for (n in 0 until batchSize) {
81+
for (c in 0 until inC) {
82+
for (i in 0 until inH) {
83+
for (j in 0 until inW) {
84+
temp[n, c, i + padding, j + padding] = input[n, c, i, j]
85+
}
86+
}
87+
}
88+
}
89+
temp
90+
} else {
91+
input // no padding needed
92+
}
93+
94+
// Prepare output tensor
95+
val output = Tensor.zeros(batchSize, outChannels, outH, outW)
96+
97+
// Convolution: iterate over batch, out channels, and output spatial positions
98+
for (n in 0 until batchSize) {
99+
for (oc in 0 until outChannels) {
100+
val biasVal = if (bias != null) bias[oc] else 0f
101+
for (i in 0 until outH) {
102+
for (j in 0 until outW) {
103+
var sum = 0f
104+
// Sum over all input channels and kernel elements
105+
for (c in 0 until inChannels) {
106+
for (ki in 0 until kernelSize) {
107+
for (kj in 0 until kernelSize) {
108+
sum += paddedInput[n, c, i * stride + ki, j * stride + kj] *
109+
weight[oc, c, ki, kj]
110+
}
111+
}
112+
}
113+
// Add bias and assign to output
114+
output[n, oc, i, j] = sum + biasVal
115+
}
116+
}
117+
}
118+
}
119+
return output
120+
}
121+
}

gguf/src/commonMain/kotlin/sk/ai/net/gguf/GGUFReader.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ class GGUFReader(source: Source) {
9292
}
9393

9494
/** Retrieve a metadata field as a list of Strings (for array-of-string fields) */
95-
fun getStringList(key: String): List<String>? {
96-
val field = this.fields[key] ?: return null
95+
fun getStringList(key: String): List<String> {
96+
val field = this.fields[key] ?: return emptyList()
9797
// Expect an array of strings: types[0] == ARRAY and types[1] == STRING (per format)
9898
if (field.types.size >= 2 &&
9999
field.types[0] == GGUFValueType.ARRAY && field.types[1] == GGUFValueType.STRING
@@ -104,7 +104,7 @@ class GGUFReader(source: Source) {
104104
byteList.toUByteArray().toByteArray().decodeToString()
105105
}
106106
}
107-
return null // Not an array-of-strings field
107+
return emptyList() // Not an array-of-strings field
108108
}
109109

110110

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package sk.ai.net.gguf
2+
3+
import junit.framework.Assert.assertEquals
4+
import kotlinx.io.asSource
5+
import kotlinx.io.buffered
6+
import org.junit.Test
7+
8+
9+
class GGUFReaderTest {
10+
11+
@Test
12+
fun testReadMetadataFields() {
13+
javaClass.getResourceAsStream("/skainet-small.gguf").use { inputStream ->
14+
15+
val reader = GGUFReader(inputStream.asSource().buffered())
16+
17+
// Verify the 'model_name' metadata is correct
18+
val modelName = reader.getString("model_name")
19+
assertEquals("model_name should match", "skainet-small", modelName)
20+
21+
// Verify the 'authors' metadata list is correct
22+
val authorsList = reader.getStringList("authors")
23+
assertEquals("authors list should match", 2, authorsList.size)
24+
//assertEquals (listOf("Alice", "Bob"), authorsList, "authors list should match")
25+
}
26+
}
27+
}

0 commit comments

Comments
 (0)