Skip to content

Commit fe20517

Browse files
committed
pretrained gguf FC model from python works
1 parent fb60f06 commit fe20517

16 files changed

Lines changed: 1404 additions & 4 deletions

File tree

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"permissions": {
3+
"allow": [
4+
"Bash(./gradlew :cli:compileKotlin:*)",
5+
"Bash(./gradlew :cli:test:*)",
6+
"Bash(./gradlew :cli:run:*)",
7+
"Bash(python3:*)",
8+
"Bash(./gradlew :shared:jvmTest:*)",
9+
"Bash(uv venv:*)",
10+
"Bash(uv pip install:*)",
11+
"Bash(source:*)",
12+
"Bash(ls:*)",
13+
"Bash(./gradlew:*)"
14+
]
15+
}
16+
}

MNISTDemo/cli/build.gradle.kts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
plugins {
2+
alias(libs.plugins.kotlinJvm)
3+
application
4+
}
5+
6+
kotlin {
7+
jvmToolchain(21)
8+
}
9+
10+
dependencies {
11+
implementation(project(":shared"))
12+
implementation(libs.kotlinx.coroutines)
13+
implementation(libs.kotlinx.io.core)
14+
15+
// JVM-optimized SKaiNET backend for inference
16+
implementation(libs.skainet.backend.cpu.jvm)
17+
18+
// Testing
19+
testImplementation(kotlin("test"))
20+
testImplementation(libs.kotlinx.coroutines.test)
21+
22+
// SKaiNET dependencies for tests that need direct access to Module
23+
testImplementation(libs.skainet.lang.core)
24+
testImplementation(libs.skainet.io.gguf)
25+
}
26+
27+
application {
28+
mainClass.set("sk.ainet.cli.MainKt")
29+
}
30+
31+
tasks.named<JavaExec>("run") {
32+
standardInput = System.`in`
33+
}
34+
35+
tasks.withType<Test> {
36+
maxHeapSize = "4g"
37+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package sk.ainet.cli
2+
3+
/**
4+
* CLI argument parser for MNIST classifier.
5+
*/
6+
object ArgsParser {
7+
8+
/**
9+
* Parses command line arguments into a CliConfig.
10+
*/
11+
fun parse(args: Array<String>): CliConfig {
12+
var model = "mlp"
13+
var weightsDir: String? = null
14+
var modelFile: String? = null
15+
var imagePath: String? = null
16+
var invert = false
17+
var debug = false
18+
var help = false
19+
20+
var i = 0
21+
while (i < args.size) {
22+
when (args[i]) {
23+
"--model" -> {
24+
i++
25+
if (i < args.size) {
26+
model = args[i]
27+
}
28+
}
29+
"--weights-dir" -> {
30+
i++
31+
if (i < args.size) {
32+
weightsDir = args[i]
33+
}
34+
}
35+
"--model-file" -> {
36+
i++
37+
if (i < args.size) {
38+
modelFile = args[i]
39+
}
40+
}
41+
"--invert" -> invert = true
42+
"--debug" -> debug = true
43+
"--help", "-h" -> help = true
44+
else -> {
45+
if (!args[i].startsWith("-")) {
46+
imagePath = args[i]
47+
}
48+
}
49+
}
50+
i++
51+
}
52+
53+
return CliConfig(model, weightsDir, modelFile, imagePath, invert, debug, help)
54+
}
55+
}
56+
57+
/**
58+
* Configuration parsed from CLI arguments.
59+
*/
60+
data class CliConfig(
61+
val model: String = "mlp",
62+
val weightsDir: String? = null,
63+
val modelFile: String? = null,
64+
val imagePath: String? = null,
65+
val invert: Boolean = false,
66+
val debug: Boolean = false,
67+
val help: Boolean = false,
68+
)
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package sk.ainet.cli
2+
3+
import sk.ainet.clean.domain.port.InferenceModule
4+
import sk.ainet.clean.framework.inference.CnnInferenceModuleAdapter
5+
import sk.ainet.clean.framework.inference.MlpInferenceModuleAdapter
6+
import sk.ainet.cli.io.ImageLoader
7+
import java.io.File
8+
import kotlin.system.exitProcess
9+
10+
/**
11+
* CLI application for MNIST digit classification.
12+
*
13+
* Usage: mnist-cli [options] <image-path>
14+
*
15+
* Options:
16+
* --model <mlp|cnn> Model type to use for classification (default: mlp)
17+
* --model-file <file> Path to model weights file (.gguf)
18+
* --weights-dir <dir> Directory containing model weights (alternative to --model-file)
19+
* --invert Invert image colors (use for black-on-white images)
20+
* --debug Print debug information including image preview
21+
* --help Show this help message
22+
*/
23+
fun main(args: Array<String>) {
24+
val config = ArgsParser.parse(args)
25+
26+
if (config.help) {
27+
printHelp()
28+
exitProcess(0)
29+
}
30+
31+
if (config.imagePath == null) {
32+
System.err.println("Error: Image path is required")
33+
printHelp()
34+
exitProcess(1)
35+
}
36+
37+
val imageFile = File(config.imagePath)
38+
if (!imageFile.exists()) {
39+
System.err.println("Error: Image file not found: ${config.imagePath}")
40+
exitProcess(1)
41+
}
42+
43+
// Determine model type
44+
val modelType = when (config.model.lowercase()) {
45+
"mlp" -> ModelType.MLP
46+
"cnn" -> ModelType.CNN
47+
else -> {
48+
System.err.println("Error: Unknown model type '${config.model}'. Use 'mlp' or 'cnn'.")
49+
exitProcess(1)
50+
}
51+
}
52+
53+
// Determine model weights source
54+
val modelWeights: ByteArray = when {
55+
config.modelFile != null -> {
56+
val modelFile = File(config.modelFile)
57+
if (!modelFile.exists()) {
58+
System.err.println("Error: Model file not found: ${config.modelFile}")
59+
exitProcess(1)
60+
}
61+
modelFile.readBytes()
62+
}
63+
config.weightsDir != null -> {
64+
val weightsDir = File(config.weightsDir)
65+
if (!weightsDir.exists()) {
66+
System.err.println("Error: Weights directory not found: ${config.weightsDir}")
67+
exitProcess(1)
68+
}
69+
val modelFileName = when (modelType) {
70+
ModelType.MLP -> "files/mnist_mlp.gguf"
71+
ModelType.CNN -> "files/mnist_cnn.gguf"
72+
}
73+
val modelFile = File(weightsDir, modelFileName)
74+
if (!modelFile.exists()) {
75+
System.err.println("Error: Model file not found: ${modelFile.absolutePath}")
76+
exitProcess(1)
77+
}
78+
modelFile.readBytes()
79+
}
80+
else -> {
81+
System.err.println("Error: Either --model-file or --weights-dir is required")
82+
printHelp()
83+
exitProcess(1)
84+
}
85+
}
86+
87+
classify(imageFile, modelType, modelWeights, config.invert, config.debug)
88+
}
89+
90+
private enum class ModelType { MLP, CNN }
91+
92+
private fun classify(
93+
imageFile: File,
94+
modelType: ModelType,
95+
modelWeights: ByteArray,
96+
invert: Boolean,
97+
debug: Boolean,
98+
) {
99+
// Create inference module based on model type
100+
val inferenceModule: InferenceModule = when (modelType) {
101+
ModelType.MLP -> MlpInferenceModuleAdapter.create()
102+
ModelType.CNN -> CnnInferenceModuleAdapter.create()
103+
}
104+
105+
if (debug) {
106+
println("Loading model: ${modelType.name.lowercase()}")
107+
}
108+
109+
// Load weights directly into inference module
110+
inferenceModule.load(modelWeights)
111+
112+
if (debug) {
113+
println("Model loaded successfully (${modelWeights.size} bytes)")
114+
}
115+
116+
// Load and convert image
117+
val image = ImageLoader.load(imageFile, invert)
118+
119+
if (debug) {
120+
println("Image loaded: ${imageFile.absolutePath}")
121+
println("Image preview:")
122+
image.debugPrintInConsoleOutput()
123+
}
124+
125+
// Classify
126+
val digit = inferenceModule.infer(image)
127+
128+
println(digit)
129+
}
130+
131+
private fun printHelp() {
132+
println("""
133+
MNIST Digit Classifier CLI
134+
135+
Usage: mnist-cli [options] <image-path>
136+
137+
Arguments:
138+
<image-path> Path to the image file (PNG, JPG, etc.)
139+
140+
Options:
141+
--model <mlp|cnn> Model type to use for classification (default: mlp)
142+
--model-file <file> Path to model weights file (.gguf)
143+
--weights-dir <dir> Directory containing model weights (alternative to --model-file)
144+
Expects files: files/mnist_mlp.gguf or files/mnist_cnn.gguf
145+
--invert Invert image colors (use for black-on-white images,
146+
as MNIST expects white digit on black background)
147+
--debug Print debug information including image preview
148+
--help, -h Show this help message
149+
150+
Examples:
151+
# Using explicit model file
152+
mnist-cli --model mlp --model-file ./mnist_mlp.gguf digit.png
153+
154+
# Using weights directory
155+
mnist-cli --model cnn --weights-dir ./models digit.png
156+
157+
# With debug output
158+
mnist-cli --debug --model-file ./model.gguf test_image.png
159+
""".trimIndent())
160+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package sk.ainet.cli.io
2+
3+
import sk.ainet.clean.data.io.ResourceReader
4+
import java.io.File
5+
6+
/**
7+
* File-system based ResourceReader for CLI application.
8+
* Reads model weights from a specified base directory.
9+
*/
10+
class FileResourceReader(private val baseDir: File) : ResourceReader {
11+
12+
override suspend fun read(path: String): ByteArray? {
13+
val file = File(baseDir, path)
14+
return if (file.exists() && file.isFile) {
15+
file.readBytes()
16+
} else {
17+
null
18+
}
19+
}
20+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package sk.ainet.cli.io
2+
3+
import sk.ainet.clean.data.image.GrayScale28To28Image
4+
import java.awt.Color
5+
import java.awt.RenderingHints
6+
import java.awt.image.BufferedImage
7+
import java.io.File
8+
import javax.imageio.ImageIO
9+
10+
/**
11+
* Loads and converts images to GrayScale28To28Image format for MNIST classification.
12+
*/
13+
object ImageLoader {
14+
15+
/**
16+
* Loads an image from file and converts it to a 28x28 grayscale image.
17+
*
18+
* @param file The image file (PNG, JPG, etc.)
19+
* @param invert If true, inverts colors (useful for white-on-black MNIST style images)
20+
* @return A 28x28 grayscale image ready for classification
21+
*/
22+
fun load(file: File, invert: Boolean = false): GrayScale28To28Image {
23+
require(file.exists()) { "Image file does not exist: ${file.absolutePath}" }
24+
25+
val originalImage = ImageIO.read(file)
26+
?: throw IllegalArgumentException("Cannot read image file: ${file.absolutePath}")
27+
28+
return convertToGrayscale28x28(originalImage, invert)
29+
}
30+
31+
/**
32+
* Converts a BufferedImage to a 28x28 grayscale image.
33+
*/
34+
fun convertToGrayscale28x28(image: BufferedImage, invert: Boolean = false): GrayScale28To28Image {
35+
// Resize to 28x28
36+
val resized = resizeImage(image, 28, 28)
37+
38+
val result = GrayScale28To28Image()
39+
for (y in 0 until 28) {
40+
for (x in 0 until 28) {
41+
val rgb = resized.getRGB(x, y)
42+
val color = Color(rgb)
43+
44+
// Convert to grayscale using luminance formula
45+
val gray = (0.299 * color.red + 0.587 * color.green + 0.114 * color.blue) / 255.0
46+
47+
// MNIST expects white digit on black background (1.0 = digit, 0.0 = background)
48+
// Default: assume input is already white-on-black (MNIST style), keep as-is
49+
// With invert: for black digit on white background, invert colors
50+
val value = if (invert) (1.0 - gray) else gray
51+
52+
result.setPixel(x, y, value.toFloat().coerceIn(0f, 1f))
53+
}
54+
}
55+
return result
56+
}
57+
58+
private fun resizeImage(original: BufferedImage, width: Int, height: Int): BufferedImage {
59+
val resized = BufferedImage(width, height, BufferedImage.TYPE_INT_RGB)
60+
val g = resized.createGraphics()
61+
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR)
62+
g.setRenderingHint(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY)
63+
g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON)
64+
g.drawImage(original, 0, 0, width, height, null)
65+
g.dispose()
66+
return resized
67+
}
68+
}

0 commit comments

Comments
 (0)