Skip to content

Commit 3a836a6

Browse files
committed
Add GGUF paremeters loader for fixing the local build
1 parent c5d069e commit 3a836a6

2 files changed

Lines changed: 47 additions & 0 deletions

File tree

io/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ kotlin {
3131
val commonMain by getting {
3232
dependencies {
3333
implementation(project(":core"))
34+
implementation(project(":gguf"))
3435
implementation(libs.kotlinx.io.core)
3536

3637
implementation(libs.kotlinx.serialization.json)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package sk.ai.net.io.gguf
2+
3+
import kotlinx.io.Source
4+
import sk.ai.net.Shape
5+
import sk.ai.net.Tensor
6+
import sk.ai.net.impl.DoublesTensor
7+
import sk.ai.net.io.ParametersLoader
8+
import sk.ai.net.gguf.GGUFReader
9+
import sk.ai.net.gguf.GGMLQuantizationType
10+
11+
/**
12+
* A parameters loader that loads tensors from a GGUF file.
13+
*
14+
* @param handleSource A function that returns a Source for the GGUF file.
15+
*/
16+
class GGUFParametersLoader(private val handleSource: () -> Source) : ParametersLoader {
17+
override suspend fun load(onTensorLoaded: (String, Tensor) -> Unit) {
18+
handleSource().use { source ->
19+
val reader = GGUFReader(source)
20+
21+
// Process each tensor in the GGUF file
22+
reader.tensors.forEach { tensor ->
23+
// Convert the tensor data to a DoublesTensor
24+
val shapeArray = tensor.shape.map { it.toInt() }.toIntArray()
25+
val shape = Shape(*shapeArray)
26+
27+
// Convert the tensor data to a DoubleArray
28+
val doubleValues = when (tensor.tensorType) {
29+
GGMLQuantizationType.F32 -> (tensor.data as List<*>).map { (it as Float).toDouble() }.toDoubleArray()
30+
GGMLQuantizationType.F64 -> (tensor.data as List<*>).map { it as Double }.toDoubleArray()
31+
GGMLQuantizationType.I8 -> (tensor.data as List<*>).map { (it as Byte).toDouble() }.toDoubleArray()
32+
GGMLQuantizationType.I16 -> (tensor.data as List<*>).map { (it as Short).toDouble() }.toDoubleArray()
33+
GGMLQuantizationType.I32 -> (tensor.data as List<*>).map { (it as Int).toDouble() }.toDoubleArray()
34+
GGMLQuantizationType.I64 -> (tensor.data as List<*>).map { (it as Long).toDouble() }.toDoubleArray()
35+
else -> throw IllegalArgumentException("Unsupported tensor type: ${tensor.tensorType}")
36+
}
37+
38+
// Create a DoublesTensor with the shape and values
39+
val doublesTensor = DoublesTensor(shape, doubleValues)
40+
41+
// Call the callback with the tensor name and the DoublesTensor
42+
onTensorLoaded(tensor.name, doublesTensor)
43+
}
44+
}
45+
}
46+
}

0 commit comments

Comments
 (0)