Skip to content

Commit 5e4f377

Browse files
committed
Init weight in Linear layer with zeros. Also add initial implementation of CsvParametersLoader using BufferedSource.
1 parent 8415003 commit 5e4f377

8 files changed

Lines changed: 71 additions & 11 deletions

File tree

build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@ plugins {
22
alias(libs.plugins.androidLibrary) apply false
33
alias(libs.plugins.kotlinMultiplatform) apply false
44
alias(libs.plugins.jetbrainsKotlinJvm) apply false
5+
alias(libs.plugins.binaryCompatibility) apply false
6+
57
}

core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.jetbrains.kotlin.gradle.dsl.JvmTarget
44
plugins {
55
alias(libs.plugins.kotlinMultiplatform)
66
alias(libs.plugins.androidLibrary)
7+
alias(libs.plugins.binaryCompatibility)
78
id("module.publication")
89
}
910

core/src/commonMain/kotlin/sk/ai/net/Shape.kt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package sk.ai.net
22

3+
import sk.ai.net.impl.zipFold
4+
35
class Shape(vararg dimensions: Int) {
46
val dimensions: IntArray = dimensions.copyOf()
57

@@ -8,4 +10,29 @@ class Shape(vararg dimensions: Int) {
810

911
val rank: Int
1012
get() = dimensions.size
13+
14+
override fun equals(other: Any?): Boolean {
15+
if (other !is Shape) {
16+
return false
17+
}
18+
19+
return dimensions.size == other.dimensions.size && zipFold(dimensions, other.dimensions, true) { result, a, b ->
20+
if (!result) {
21+
return false
22+
}
23+
a == b
24+
}
25+
}
26+
27+
override fun hashCode(): Int {
28+
return dimensions.hashCode()
29+
}
30+
31+
override fun toString(): String {
32+
// Create a string representation of the dimensions array
33+
val dimensionsString = dimensions.joinToString(separator = " x ", prefix = "[", postfix = "]")
34+
// Return the formatted string including dimensions and volume
35+
return "Shape: Dimensions = $dimensionsString, Size (Volume) = $volume"
36+
}
37+
1138
}

core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,36 @@ private fun getDefaultName(id: String, s: String, size: Int): String {
4141
}
4242

4343

44+
fun createLinear(
45+
inFeatures: Int,
46+
outFeatures: Int,
47+
myInitWeights: Tensor? = null,
48+
myInitBias: Tensor? = null
49+
): Linear {
50+
return when {
51+
myInitWeights != null && myInitBias != null ->
52+
Linear(inFeatures, outFeatures, initWeights = myInitWeights, initBias = myInitBias)
53+
myInitWeights != null ->
54+
Linear(inFeatures, outFeatures, initWeights = myInitWeights)
55+
myInitBias != null ->
56+
Linear(inFeatures, outFeatures, initBias = myInitBias)
57+
else ->
58+
Linear(inFeatures, outFeatures)
59+
}
60+
}
61+
4462
class DenseImpl(
4563
private val inputDimension: Int, private val outputDimension: Int, private val id: String
4664
) : DENSE {
4765

4866
private var weightsValue: Tensor? = null
67+
private var biasValue: Tensor? = null
4968
private var _activation: (Tensor) -> Tensor = { tensor -> tensor }
5069

5170
fun create(): List<Module> {
71+
5272
return listOf(
53-
Linear(inputDimension, outputDimension, id, weightsValue!!, weightsValue!!),
73+
createLinear(inputDimension, outputDimension, weightsValue, biasValue),
5474
ActivationsWrapperModule(activation, "activation")
5575
)
5676
}
@@ -66,7 +86,7 @@ class DenseImpl(
6686
}
6787

6888
override fun bias(initBlock: (Shape) -> Tensor) {
69-
weightsValue = initBlock(Shape(outputDimension))
89+
biasValue = initBlock(Shape(outputDimension))
7090
}
7191
}
7292

core/src/commonMain/kotlin/sk/ai/net/nn/Linear.kt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package sk.ai.net.nn
22

3+
import sk.ai.net.Shape
34
import sk.ai.net.Tensor
5+
import sk.ai.net.impl.DoublesTensor
46

57
/**
68
* Linear layer (a.k.a. fully connected dense layer). This layer applies a linear transformation to the input data.
@@ -17,8 +19,14 @@ class Linear(
1719
inFeatures: Int,
1820
outFeatures: Int,
1921
override val name: String = "Linear",
20-
val initWeights: Tensor,
21-
val initBias: Tensor,
22+
val initWeights: Tensor = DoublesTensor(
23+
Shape(outFeatures, inFeatures),
24+
List(inFeatures * outFeatures) { 0.0 }.map { it }.toDoubleArray()
25+
),
26+
val initBias: Tensor = DoublesTensor(
27+
Shape(outFeatures),
28+
List(outFeatures) { 0.0 }.map { it }.toDoubleArray()
29+
),
2230
override val params: List<ModuleParameter> = listOf(
2331
ModuleParameter("weight", initWeights),
2432
ModuleParameter("bias", initBias)

gradle/libs.versions.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ android-compileSdk = "35"
66
kotlinxSerializationJson = "1.8.0"
77
nexus-publish = "2.0.0"
88
jetbrainsKotlinJvm = "1.9.22"
9-
testng = "6.9.6"
9+
testng = "7.10.2"
1010
okio = "3.9.1"
1111
okioNodefilesystem = "3.9.0"
12+
binaryCompatibility = "0.17.0"
1213

1314
[libraries]
1415
kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
@@ -22,4 +23,5 @@ okio-nodefilesystem = { module = "com.squareup.okio:okio-nodefilesystem", versio
2223
androidLibrary = { id = "com.android.library", version.ref = "agp" }
2324
kotlinMultiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
2425
jetbrainsKotlinJvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "jetbrainsKotlinJvm" }
25-
kotlinSerialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
26+
kotlinSerialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
27+
binaryCompatibility = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibility" }

io/src/commonMain/kotlin/sk/ai/net/io/csv/CsvParameterLoader.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
package sk.ai.net.io.csv
22

33
import kotlinx.serialization.json.Json
4-
import okio.Path
4+
import okio.BufferedSource
5+
import okio.use
56
import sk.ai.net.Shape
67

78
import sk.ai.net.Tensor
89
import sk.ai.net.impl.DoublesTensor
910
import sk.ai.net.io.ParametersLoader
1011

11-
class CsvParametersLoader(private val fileSystem: okio.FileSystem, private val modelWeightPath: Path) :
12+
class CsvParametersLoader(private val handleSource: () -> BufferedSource) :
1213
ParametersLoader {
1314
override suspend fun load(onTensorLoaded: (String, Tensor) -> Unit) {
14-
fileSystem.read(modelWeightPath) {
15+
handleSource().use { source ->
1516
// Initialize Json object
1617
val json = Json { ignoreUnknownKeys = true }
17-
1818
// Deserialize JSON to Kotlin objects
19-
json.decodeFromString<List<ArrayValues>>(this.readUtf8()).also { values ->
19+
json.decodeFromString<List<ArrayValues>>(source.readUtf8()).also { values ->
2020
values.forEach { (name, array) ->
2121
val tensor = DoublesTensor(Shape(*array.shape.toIntArray()), array.values.toDoubleArray())
2222
onTensorLoaded(name, tensor)

io/src/jvmTest/resources/sinus-approximator.json

Whitespace-only changes.

0 commit comments

Comments
 (0)