Skip to content

Commit 276d63a

Browse files
Merge pull request #12 from sk-ai-net/feature/gguf-kmp
Feature/gguf kmp
2 parents f1b362f + f279c81 commit 276d63a

32 files changed

Lines changed: 1172 additions & 216 deletions

File tree

core/build.gradle.kts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ kotlin {
2424
iosArm64()
2525
iosSimulatorArm64()
2626
wasmJs().nodejs()
27+
macosX64 ()
28+
linuxX64 ()
29+
2730

2831
sourceSets {
2932
val commonMain by getting {
3033
dependencies {
31-
//put your multiplatform dependencies here
3234
}
3335
}
3436
val commonTest by getting {

core/src/iosMain/kotlin/sk/ai/net/performance/Measure.ios.kt renamed to core/src/appleMain/kotlin/sk/ai/net/performance/Measure.apple.kt

File renamed without changes.

core/src/commonMain/kotlin/sk/ai/net/dsl/NetworkBuilder.kt

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,28 @@ private fun getDefaultName(id: String, s: String, size: Int): String {
4444
fun createLinear(
4545
inFeatures: Int,
4646
outFeatures: Int,
47+
id: String,
4748
myInitWeights: Tensor? = null,
4849
myInitBias: Tensor? = null
4950
): Linear {
5051
return when {
5152
myInitWeights != null && myInitBias != null ->
52-
Linear(inFeatures, outFeatures, initWeights = myInitWeights, initBias = myInitBias)
53+
Linear(
54+
inFeatures = inFeatures,
55+
outFeatures = outFeatures,
56+
name = id,
57+
initWeights = myInitWeights,
58+
initBias = myInitBias
59+
)
60+
5361
myInitWeights != null ->
54-
Linear(inFeatures, outFeatures, initWeights = myInitWeights)
62+
Linear(inFeatures = inFeatures, outFeatures = outFeatures, name = id, initWeights = myInitWeights)
63+
5564
myInitBias != null ->
56-
Linear(inFeatures, outFeatures, initBias = myInitBias)
65+
Linear(inFeatures = inFeatures, outFeatures = outFeatures, name = id, initBias = myInitBias)
66+
5767
else ->
58-
Linear(inFeatures, outFeatures)
68+
Linear(inFeatures = inFeatures, outFeatures = outFeatures, name = id)
5969
}
6070
}
6171

@@ -69,8 +79,16 @@ class DenseImpl(
6979

7080
fun create(): List<Module> {
7181

82+
val linear = createLinear(
83+
inFeatures = inputDimension,
84+
outFeatures = outputDimension,
85+
id = id,
86+
myInitWeights = weightsValue,
87+
myInitBias = biasValue
88+
)
89+
7290
return listOf(
73-
createLinear(inputDimension, outputDimension, weightsValue, biasValue),
91+
linear,
7492
ActivationsWrapperModule(activation, "activation")
7593
)
7694
}

core/src/commonMain/kotlin/sk/ai/net/impl/DoublesTensor.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,3 +410,5 @@ data class DoublesTensor(override val shape: Shape, val elements: DoubleArray) :
410410
return DoublesTensor(shape, softmaxElements)
411411
}
412412
}
413+
414+
fun DoublesTensor.prod(): Double = this.elements.fold(1.0) { acc, element -> acc * element }

core/src/commonMain/kotlin/sk/ai/net/nn/Linear.kt

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ package sk.ai.net.nn
33
import sk.ai.net.Shape
44
import sk.ai.net.Tensor
55
import sk.ai.net.impl.DoublesTensor
6+
import sk.ai.net.nn.reflection.ModuleParameter
7+
import sk.ai.net.nn.reflection.ModuleParameters
8+
import sk.ai.net.nn.reflection.bias
9+
import sk.ai.net.nn.reflection.weights
610

711
/**
812
* Linear layer (a.k.a. fully connected dense layer). This layer applies a linear transformation to the input data.
@@ -19,26 +23,26 @@ class Linear(
1923
inFeatures: Int,
2024
outFeatures: Int,
2125
override val name: String = "Linear",
22-
val initWeights: Tensor = DoublesTensor(
26+
initWeights: Tensor = DoublesTensor(
2327
Shape(outFeatures, inFeatures),
2428
List(inFeatures * outFeatures) { 0.0 }.map { it }.toDoubleArray()
2529
),
26-
val initBias: Tensor = DoublesTensor(
30+
initBias: Tensor = DoublesTensor(
2731
Shape(outFeatures),
2832
List(outFeatures) { 0.0 }.map { it }.toDoubleArray()
2933
),
3034
override val params: List<ModuleParameter> = listOf(
31-
ModuleParameter("weight", initWeights),
32-
ModuleParameter("bias", initBias)
35+
ModuleParameter.WeightParameter("$name.weight", initWeights),
36+
ModuleParameter.BiasParameter("$name.bias", initBias)
3337
),
3438
) : Module(), ModuleParameters {
3539

3640
override val modules: List<Module>
3741
get() = emptyList()
3842

3943
override fun forward(input: Tensor): Tensor {
40-
val weight = initWeights
41-
val bias = initBias
44+
val weight = params.weights().value
45+
val bias = params.bias().value
4246

4347
// matrix multiplication on tensors and addition
4448
return input.matmul(weight.t()) + bias

core/src/commonMain/kotlin/sk/ai/net/nn/ModuleParameters.kt

Lines changed: 0 additions & 13 deletions
This file was deleted.
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package sk.ai.net.nn.reflection
2+
3+
import sk.ai.net.Shape
4+
import sk.ai.net.Tensor
5+
import sk.ai.net.impl.DoublesTensor
6+
import sk.ai.net.impl.prod
7+
import sk.ai.net.nn.Module
8+
import sk.ai.net.nn.reflection.table.table
9+
10+
data class NodeSummary(val name: String, val input: Shape, val output: Shape, val params: Long)
11+
12+
class Summary {
13+
14+
val nodes = mutableListOf<NodeSummary>()
15+
16+
private fun nodeSummary(index: Int, module: Module, input: Shape, output: Tensor): NodeSummary {
17+
var params = 0L
18+
19+
if (module is ModuleParameters) {
20+
21+
module.params.by("W")?.let { weight ->
22+
val dimension =
23+
DoublesTensor(weight.value.shape, weight.value.shape.dimensions.map { it.toDouble() }.toDoubleArray())
24+
params += dimension.prod().toLong()
25+
}
26+
27+
module.params.by("B")?.let { bias ->
28+
val dimension =
29+
DoublesTensor(bias.value.shape, bias.value.shape.dimensions.map { it.toDouble() }.toDoubleArray())
30+
params += dimension.prod().toLong()
31+
}
32+
}
33+
34+
35+
return NodeSummary(
36+
module.name,
37+
input,
38+
output.shape,
39+
params
40+
)
41+
}
42+
43+
44+
fun summary(model: Module, input: Shape, batch_size: Int = -1): List<NodeSummary> {
45+
var data = DoublesTensor(input, List(input.volume) { 0.0 }.toDoubleArray())
46+
var count = 1
47+
model.modules.forEach { module ->
48+
val moduleInput = data
49+
data = module.forward(moduleInput) as DoublesTensor
50+
val nodeSummary = nodeSummary(count, module, moduleInput.shape, data)
51+
if (nodeSummary.params > 0) {
52+
count++
53+
nodes.add(nodeSummary)
54+
}
55+
}
56+
return nodes
57+
}
58+
59+
fun printSummary(nodes: List<NodeSummary>) =
60+
table {
61+
cellStyle {
62+
border = true
63+
}
64+
header {
65+
row {
66+
cell("Layer (type)")
67+
cell("Output Shape")
68+
cell("Param #")
69+
}
70+
}
71+
nodes.forEach { node ->
72+
row {
73+
cell(node.name)
74+
cell(node.output.toString())
75+
cell(node.params)
76+
}
77+
}
78+
}.toString()
79+
}
80+
81+
fun Module.summary(input: Shape, batch_size: Int = -1): String {
82+
val summary = Summary()
83+
val nodes = summary.summary(this, input, batch_size)
84+
return summary.printSummary(nodes)
85+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package sk.ai.net.nn.reflection
2+
3+
import sk.ai.net.Tensor
4+
5+
sealed class ModuleParameter {
6+
abstract val name: String
7+
abstract var value: Tensor
8+
9+
data class WeightParameter(
10+
override val name: String,
11+
override var value: Tensor
12+
) : ModuleParameter()
13+
14+
data class BiasParameter(
15+
override val name: String,
16+
override var value: Tensor
17+
) : ModuleParameter()
18+
}
19+
20+
interface ModuleParameters {
21+
val params: List<ModuleParameter>
22+
}
23+
24+
public fun List<ModuleParameter>.by(name: String): ModuleParameter? =
25+
firstOrNull { namedParameter -> namedParameter.name.uppercase().contains(name.uppercase()) }
26+
27+
// Returns the first BiasParameter or throws a NoSuchElementException if none is found.
28+
fun List<ModuleParameter>.bias(): ModuleParameter.BiasParameter =
29+
this.filterIsInstance<ModuleParameter.BiasParameter>()
30+
.firstOrNull() ?: throw NoSuchElementException("No bias parameter found!")
31+
32+
// Returns the first WeightParameter or throws a NoSuchElementException if none is found.
33+
fun List<ModuleParameter>.weights(): ModuleParameter.WeightParameter =
34+
this.filterIsInstance<ModuleParameter.WeightParameter>()
35+
.firstOrNull() ?: throw NoSuchElementException("No weight parameter found!")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package sk.ai.net.nn.reflection
2+
3+
import sk.ai.net.nn.Module
4+
5+
// Extension function to generate a custom string representation for a Module
6+
fun Module.toCustomString(indent: String = ""): String {
7+
// Build the string for the current module
8+
val builder = StringBuilder()
9+
builder.append("$indent$name")
10+
11+
// If there are child modules, recursively add their representations with increased indent
12+
if (modules.isNotEmpty()) {
13+
modules.forEach { child ->
14+
builder.append("\n")
15+
builder.append(child.toCustomString("$indent "))
16+
}
17+
}
18+
return builder.toString()
19+
}
20+
21+
// Extension function for the root that prints the module tree hierarchy.
22+
fun Module.toVisualString(): String {
23+
val builder = StringBuilder()
24+
// Print the root node (without any branch symbols)
25+
builder.append(name).append("\n")
26+
// For each child, call the helper function with an empty initial prefix.
27+
modules.forEachIndexed { index, module ->
28+
val isLast = index == modules.lastIndex
29+
builder.append(module.toVisualStringHelper("", isLast))
30+
}
31+
return builder.toString()
32+
}
33+
34+
// Private helper extension function that handles the branch prefixes.
35+
private fun Module.toVisualStringHelper(prefix: String, isLast: Boolean): String {
36+
val builder = StringBuilder()
37+
// Append the current prefix and the branch symbols:
38+
// "└── " if this node is the last child, otherwise "├── "
39+
builder.append(prefix)
40+
builder.append(if (isLast) "└── " else "├── ")
41+
builder.append(name)
42+
builder.append("\n")
43+
44+
// Update the prefix for children:
45+
// If this node is the last, add spaces; otherwise add a vertical bar and spaces.
46+
val newPrefix = prefix + if (isLast) " " else ""
47+
48+
// Recursively process all child modules.
49+
modules.forEachIndexed { index, module ->
50+
val childIsLast = index == modules.lastIndex
51+
builder.append(module.toVisualStringHelper(newPrefix, childIsLast))
52+
}
53+
return builder.toString()
54+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package sk.ai.net.nn.reflection
2+
3+
import sk.ai.net.nn.Module
4+
5+
fun flattenParams(module: Module): List<ModuleParameter> {
6+
val params = mutableListOf<ModuleParameter>()
7+
for (m in module.modules) {
8+
params.addAll(flattenParams(m))
9+
}
10+
if (module is ModuleParameters) {
11+
params.addAll(module.params)
12+
}
13+
return params
14+
}

0 commit comments

Comments
 (0)