Skip to content

Commit 0b7c418

Browse files
committed
Fix erro in loading GGUF fully connected model
1 parent 67d3613 commit 0b7c418

10 files changed

Lines changed: 98 additions & 13 deletions

File tree

MNISTDemo/.java-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
21.0

MNISTDemo/composeApp/src/commonMain/kotlin/sk/ai/net/samples/kmp/mnist/demo/DrawingScreenViewModel.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class DrawingScreenViewModel(handleSource: () -> Source) : ViewModel() {
2525
@Suppress("UnusedPrivateMember")
2626
private val handleSourceFn = handleSource
2727

28-
// Selected model (default to CNN). Exposed to UI for selection.
29-
var selectedModelId by mutableStateOf(ModelId.CNN_MNIST)
28+
// Selected model (default to MLP). Exposed to UI for selection.
29+
var selectedModelId by mutableStateOf(ModelId.MLP_MNIST)
3030
private set
3131

3232
// Current model status

MNISTDemo/composeApp/src/commonMain/kotlin/sk/ai/net/samples/kmp/mnist/demo/settings/AppSettings.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ enum class ModelStatus {
1818
* Minimal solution to support model selection from Settings.
1919
*/
2020
object AppSettings {
21-
private val _selectedModelId = MutableStateFlow(ModelId.CNN_MNIST)
21+
private val _selectedModelId = MutableStateFlow(ModelId.MLP_MNIST)
2222
val selectedModelId: StateFlow<ModelId> = _selectedModelId.asStateFlow()
2323

2424
// Status for each model
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package sk.ainet.lang.model
2+
3+
import kotlinx.io.Buffer
4+
import kotlinx.io.Source
5+
import sk.ainet.io.gguf.GGUFReader
6+
import sk.ainet.lang.nn.Module
7+
import sk.ainet.lang.types.FP32
8+
9+
/**
10+
* Android implementation using reflection-based GGUF loader.
11+
* Same approach as JVM since Android supports reflection.
12+
*/
13+
actual fun loadWeightsFromBytes(module: Module<FP32, Float>, bytes: ByteArray) {
14+
val source: Source = Buffer().apply { write(bytes) }
15+
val reader = GGUFReader(source)
16+
val tensorMap = reader.tensors.associateBy { it.name }
17+
18+
module.trainableParameters().forEach { param ->
19+
val readerTensor = tensorMap[param.name]
20+
if (readerTensor != null) {
21+
val tensorData = param.value.data
22+
// Use reflection to access the underlying buffer array
23+
val bufferField = tensorData::class.java.declaredFields.firstOrNull {
24+
it.type.isArray && it.type.componentType == Float::class.javaPrimitiveType
25+
}
26+
if (bufferField != null) {
27+
bufferField.isAccessible = true
28+
val array = bufferField.get(tensorData) as FloatArray
29+
readerTensor.data.forEachIndexed { idx, value ->
30+
array[idx] = (value as Number).toFloat()
31+
}
32+
}
33+
}
34+
}
35+
}

MNISTDemo/shared/src/commonMain/kotlin/sk/ainet/clean/framework/inference/CnnInferenceModuleAdapter.kt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ package sk.ainet.clean.framework.inference
33
import sk.ainet.clean.data.image.GrayScale28To28Image
44
import sk.ainet.lang.model.createMNISTCNN
55
import sk.ainet.lang.model.classifyImage
6-
import sk.ainet.lang.model.loader.loadModelWeights
7-
import kotlinx.io.Buffer
8-
import kotlinx.io.Source
6+
import sk.ainet.lang.model.loadWeightsFromBytes
97
import sk.ainet.clean.domain.port.InferenceModule
108
import sk.ainet.lang.nn.Module
119
import sk.ainet.lang.types.FP32
@@ -30,8 +28,7 @@ class CnnInferenceModuleAdapter(
3028
fun fromModule(module: Module<FP32, Float>): CnnInferenceModuleAdapter {
3129
return CnnInferenceModuleAdapter(
3230
loadFn = { bytes ->
33-
val src: Source = Buffer().apply { write(bytes) }
34-
loadModelWeights(module, src)
31+
loadWeightsFromBytes(module, bytes)
3532
},
3633
inferFn = { image -> classifyImage(module, image) }
3734
)

MNISTDemo/shared/src/commonMain/kotlin/sk/ainet/clean/framework/inference/MlpInferenceModuleAdapter.kt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ package sk.ainet.clean.framework.inference
33
import sk.ainet.clean.data.image.GrayScale28To28Image
44
import sk.ainet.lang.model.createMNISTMLP
55
import sk.ainet.lang.model.classifyImage
6-
import sk.ainet.lang.model.loader.loadModelWeights
7-
import kotlinx.io.Buffer
8-
import kotlinx.io.Source
6+
import sk.ainet.lang.model.loadWeightsFromBytes
97
import sk.ainet.clean.domain.port.InferenceModule
108
import sk.ainet.lang.nn.Module
119

@@ -29,8 +27,7 @@ class MlpInferenceModuleAdapter(
2927
fun fromModule(module: Module<FP32, Float>): MlpInferenceModuleAdapter {
3028
return MlpInferenceModuleAdapter(
3129
loadFn = { bytes ->
32-
val src: Source = Buffer().apply { write(bytes) }
33-
loadModelWeights(module, src)
30+
loadWeightsFromBytes(module, bytes)
3431
},
3532
inferFn = { image -> classifyImage(module, image) }
3633
)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package sk.ainet.lang.model
2+
3+
import sk.ainet.lang.nn.Module
4+
import sk.ainet.lang.types.FP32
5+
6+
/**
7+
* Platform-specific weight loader for GGUF files.
8+
* On JVM, uses reflection-based loader that correctly handles PyTorch-exported GGUF.
9+
* On other platforms, falls back to library loader.
10+
*/
11+
expect fun loadWeightsFromBytes(module: Module<FP32, Float>, bytes: ByteArray)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package sk.ainet.lang.model
2+
3+
import kotlinx.io.Buffer
4+
import kotlinx.io.Source
5+
import sk.ainet.lang.model.loader.loadModelWeights
6+
import sk.ainet.lang.nn.Module
7+
import sk.ainet.lang.types.FP32
8+
9+
/**
10+
* JS implementation using library loader.
11+
* Note: This may not work correctly with PyTorch-exported GGUF files.
12+
*/
13+
actual fun loadWeightsFromBytes(module: Module<FP32, Float>, bytes: ByteArray) {
14+
val source: Source = Buffer().apply { write(bytes) }
15+
loadModelWeights(module, source)
16+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package sk.ainet.lang.model
2+
3+
import sk.ainet.lang.nn.Module
4+
import sk.ainet.lang.types.FP32
5+
6+
/**
7+
* JVM implementation using reflection-based GGUF loader.
8+
* This correctly loads PyTorch-exported GGUF weights.
9+
*/
10+
actual fun loadWeightsFromBytes(module: Module<FP32, Float>, bytes: ByteArray) {
11+
loadGgufWeights(module, bytes)
12+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package sk.ainet.lang.model
2+
3+
import kotlinx.io.Buffer
4+
import kotlinx.io.Source
5+
import sk.ainet.lang.model.loader.loadModelWeights
6+
import sk.ainet.lang.nn.Module
7+
import sk.ainet.lang.types.FP32
8+
9+
/**
10+
* WasmJS implementation using library loader.
11+
* Note: This may not work correctly with PyTorch-exported GGUF files.
12+
*/
13+
actual fun loadWeightsFromBytes(module: Module<FP32, Float>, bytes: ByteArray) {
14+
val source: Source = Buffer().apply { write(bytes) }
15+
loadModelWeights(module, source)
16+
}

0 commit comments

Comments
 (0)