|
| 1 | +package sk.ai.net.io.gguf |
| 2 | + |
| 3 | +import kotlinx.io.Source |
| 4 | +import sk.ai.net.Shape |
| 5 | +import sk.ai.net.Tensor |
| 6 | +import sk.ai.net.impl.DoublesTensor |
| 7 | +import sk.ai.net.io.ParametersLoader |
| 8 | +import sk.ai.net.gguf.GGUFReader |
| 9 | +import sk.ai.net.gguf.GGMLQuantizationType |
| 10 | + |
| 11 | +/** |
| 12 | + * A parameters loader that loads tensors from a GGUF file. |
| 13 | + * |
| 14 | + * @param handleSource A function that returns a Source for the GGUF file. |
| 15 | + */ |
| 16 | +class GGUFParametersLoader(private val handleSource: () -> Source) : ParametersLoader { |
| 17 | + override suspend fun load(onTensorLoaded: (String, Tensor) -> Unit) { |
| 18 | + handleSource().use { source -> |
| 19 | + val reader = GGUFReader(source) |
| 20 | + |
| 21 | + // Process each tensor in the GGUF file |
| 22 | + reader.tensors.forEach { tensor -> |
| 23 | + // Convert the tensor data to a DoublesTensor |
| 24 | + val shapeArray = tensor.shape.map { it.toInt() }.toIntArray() |
| 25 | + val shape = Shape(*shapeArray) |
| 26 | + |
| 27 | + // Convert the tensor data to a DoubleArray |
| 28 | + val doubleValues = when (tensor.tensorType) { |
| 29 | + GGMLQuantizationType.F32 -> (tensor.data as List<*>).map { (it as Float).toDouble() }.toDoubleArray() |
| 30 | + GGMLQuantizationType.F64 -> (tensor.data as List<*>).map { it as Double }.toDoubleArray() |
| 31 | + GGMLQuantizationType.I8 -> (tensor.data as List<*>).map { (it as Byte).toDouble() }.toDoubleArray() |
| 32 | + GGMLQuantizationType.I16 -> (tensor.data as List<*>).map { (it as Short).toDouble() }.toDoubleArray() |
| 33 | + GGMLQuantizationType.I32 -> (tensor.data as List<*>).map { (it as Int).toDouble() }.toDoubleArray() |
| 34 | + GGMLQuantizationType.I64 -> (tensor.data as List<*>).map { (it as Long).toDouble() }.toDoubleArray() |
| 35 | + else -> throw IllegalArgumentException("Unsupported tensor type: ${tensor.tensorType}") |
| 36 | + } |
| 37 | + |
| 38 | + // Create a DoublesTensor with the shape and values |
| 39 | + val doublesTensor = DoublesTensor(shape, doubleValues) |
| 40 | + |
| 41 | + // Call the callback with the tensor name and the DoublesTensor |
| 42 | + onTensorLoaded(tensor.name, doublesTensor) |
| 43 | + } |
| 44 | + } |
| 45 | + } |
| 46 | +} |
0 commit comments