Skip to content

Commit 70df324

Browse files
committed
Add LoadProgressCallback for model-load progress (#113)
Exposes llama.cpp's llama_model_params.progress_callback as a Java functional interface. New constructor: new LlamaModel(parameters, progress -> { ... return true; }). The callback receives a float in [0.0, 1.0] on the loader thread (the same thread that called the constructor) and may return false to abort, in which case the constructor throws LlamaException. JNI: extracts the existing loadModel body into load_model_impl and adds a trampoline that forwards the float progress to a Java LoadProgressCallback.onProgress(float)Z via CallBooleanMethod. Trampoline state lives on the loader stack — its lifetime is bounded by the single load call. Two native entry points share the implementation: loadModel(String[]) — unchanged signature; and loadModelWithProgress(String[], LoadProgressCallback) — new. Tests in LoadProgressCallbackTest (model-gated): non-decreasing progress in [0,1] reaching ~1.0; returning false aborts with LlamaException; the null-callback overload delegates to plain loadModel. All 435 C++ unit tests still pass. mvn javadoc:jar BUILD SUCCESS. https://claude.ai/code/session_01R4ZrEy3ptJDLuUgUKuM4Gy
1 parent 85b3895 commit 70df324

4 files changed

Lines changed: 188 additions & 1 deletion

File tree

src/main/cpp/jllama.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,26 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) {
598598
llama_backend_free();
599599
}
600600

601-
JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jobjectArray jparams) {
601+
// Trampoline state for llama.cpp's load_progress_callback. The native loader runs
602+
// on the calling JNI thread so we can capture JNIEnv directly. Lifetime is bounded
603+
// by the single load_model_impl call.
604+
namespace {
605+
struct load_progress_ud {
606+
JNIEnv *env;
607+
jobject callback;
608+
jmethodID on_progress;
609+
};
610+
611+
bool jni_load_progress_trampoline(float progress, void *user_data) {
612+
auto *ud = static_cast<load_progress_ud *>(user_data);
613+
return ud->env->CallBooleanMethod(ud->callback, ud->on_progress, progress) == JNI_TRUE;
614+
}
615+
} // namespace
616+
617+
// Shared implementation of loadModel and loadModelWithProgress. When `progress` is
618+
// non-null, installs a load-progress trampoline; otherwise behaves identically to
619+
// the no-callback path.
620+
static void load_model_impl(JNIEnv *env, jobject obj, jobjectArray jparams, jobject progress) {
602621
common_params params;
603622

604623
const jsize argc = env->GetArrayLength(jparams);
@@ -662,6 +681,21 @@ JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_loadModel(JNIEnv *env
662681

663682
LOG_INF("%s: loading model\n", __func__);
664683

684+
// Install the load-progress trampoline if the caller supplied a callback.
685+
load_progress_ud progress_ud{};
686+
if (progress != nullptr) {
687+
jclass cb_cls = env->GetObjectClass(progress);
688+
progress_ud.env = env;
689+
progress_ud.callback = progress;
690+
progress_ud.on_progress = env->GetMethodID(cb_cls, "onProgress", "(F)Z");
691+
if (progress_ud.on_progress == nullptr) {
692+
fail_load("LoadProgressCallback.onProgress(float) not found");
693+
return;
694+
}
695+
params.load_progress_callback = jni_load_progress_trampoline;
696+
params.load_progress_callback_user_data = &progress_ud;
697+
}
698+
665699
if (!jctx->server.load_model(params)) {
666700
fail_load("could not load model from given file path");
667701
return;
@@ -706,6 +740,16 @@ JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_loadModel(JNIEnv *env
706740
env->SetLongField(obj, f_model_pointer, reinterpret_cast<jlong>(jctx));
707741
}
708742

743+
// JNI entry point for LlamaModel.loadModel(String[]).
// Delegates to the shared load_model_impl with no progress callback (nullptr).
JNIEXPORT void JNICALL
Java_net_ladenthin_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jobjectArray jparams) {
    jobject const no_progress_callback = nullptr;
    load_model_impl(env, obj, jparams, no_progress_callback);
}
746+
747+
// JNI entry point for LlamaModel.loadModelWithProgress(String[], LoadProgressCallback).
// Delegates to the shared load_model_impl, handing it the Java callback object.
JNIEXPORT void JNICALL
Java_net_ladenthin_llama_LlamaModel_loadModelWithProgress(JNIEnv *env, jobject obj, jobjectArray jparams,
                                                          jobject callback) {
    jobject const progress_sink = callback;
    load_model_impl(env, obj, jparams, progress_sink);
}
752+
709753
JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_getModelMetaJson(JNIEnv *env, jobject obj) {
710754
REQUIRE_SERVER_CONTEXT(nullptr);
711755
if (jctx->vocab_only) {

src/main/java/net/ladenthin/llama/LlamaModel.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,24 @@ public LlamaModel(ModelParameters parameters) {
5858
loadModel(parameters.toArray());
5959
}
6060

61+
/**
62+
* Load the model and forward progress updates to {@code progress}. The callback is
63+
* invoked synchronously on the constructor thread by the native loader and may
64+
* return {@code false} to abort the load (in which case this constructor throws
65+
* {@link LlamaException}).
66+
*
67+
* @param parameters the set of options
68+
* @param progress load progress sink; {@code null} disables the callback
69+
* @throws LlamaException if loading fails or the callback aborts
70+
*/
71+
public LlamaModel(ModelParameters parameters, LoadProgressCallback progress) {
72+
if (progress == null) {
73+
loadModel(parameters.toArray());
74+
} else {
75+
loadModelWithProgress(parameters.toArray(), progress);
76+
}
77+
}
78+
6179
/**
6280
* Generate and return a whole answer with custom parameters. Note, that the prompt isn't preprocessed in any
6381
* way, nothing like "User: ", "###Instruction", etc. is added.
@@ -257,6 +275,8 @@ public void close() {
257275

258276
private native void loadModel(String... parameters) throws LlamaException;
259277

278+
private native void loadModelWithProgress(String[] parameters, LoadProgressCallback callback) throws LlamaException;
279+
260280
private native void delete();
261281

262282
native void releaseTask(int taskId);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
package net.ladenthin.llama;
6+
7+
/**
8+
* Receives model-load progress updates from the native loader.
9+
* <p>
10+
* Pass an instance to {@link LlamaModel#LlamaModel(ModelParameters, LoadProgressCallback)}
11+
* to observe the {@code llama_model_params.progress_callback} hook from llama.cpp. The
12+
* callback is invoked synchronously on the loader thread (the same thread that called
13+
* the constructor) with a value in {@code [0.0, 1.0]}.
14+
* </p>
15+
* <p>
16+
* Return {@code false} to abort the load. When {@code false} is returned, the constructor
17+
* throws {@link LlamaException} because the native loader aborts and reports failure.
18+
* </p>
19+
*/
20+
@FunctionalInterface
public interface LoadProgressCallback {

    /**
     * Called with the current model-load progress.
     *
     * @param progress completed fraction of the load, in {@code [0.0, 1.0]}
     * @return {@code true} to keep loading; {@code false} to abort the load
     */
    boolean onProgress(float progress);
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
package net.ladenthin.llama;
6+
7+
import org.junit.Assume;
8+
import org.junit.Test;
9+
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
13+
import static org.junit.Assert.assertFalse;
14+
import static org.junit.Assert.assertNotEquals;
15+
import static org.junit.Assert.assertTrue;
16+
import static org.junit.Assert.fail;
17+
18+
@ClaudeGenerated(
19+
purpose = "Verify LoadProgressCallback receives non-decreasing progress values in [0,1] "
20+
+ "during a real model load, and that returning false from the callback aborts the load."
21+
)
22+
public class LoadProgressCallbackTest {
23+
24+
@Test
25+
public void receivesProgressUpdates() {
26+
Assume.assumeTrue("Model file not found", new java.io.File(TestConstants.MODEL_PATH).exists());
27+
28+
List<Float> updates = new ArrayList<Float>();
29+
int gpuLayers = Integer.getInteger(TestConstants.PROP_TEST_NGL, TestConstants.DEFAULT_TEST_NGL);
30+
31+
try (LlamaModel m = new LlamaModel(
32+
new ModelParameters()
33+
.setCtxSize(128)
34+
.setModel(TestConstants.MODEL_PATH)
35+
.setGpuLayers(gpuLayers)
36+
.setFit(false),
37+
progress -> {
38+
updates.add(progress);
39+
return true;
40+
})) {
41+
// model load completed
42+
}
43+
44+
assertFalse("expected at least one progress update", updates.isEmpty());
45+
for (Float p : updates) {
46+
assertTrue("progress out of range: " + p, p >= 0.0f && p <= 1.0f);
47+
}
48+
// Last update should reach (or be very close to) 1.0
49+
assertTrue("last progress should reach completion, got " + updates.get(updates.size() - 1),
50+
updates.get(updates.size() - 1) >= 0.9f);
51+
// Non-decreasing
52+
for (int i = 1; i < updates.size(); i++) {
53+
assertTrue("progress decreased at index " + i + ": " + updates.get(i - 1) + " -> " + updates.get(i),
54+
updates.get(i) >= updates.get(i - 1));
55+
}
56+
// Sanity: progress actually advanced
57+
assertNotEquals("progress never advanced", updates.get(0), updates.get(updates.size() - 1));
58+
}
59+
60+
@Test
61+
public void returningFalseAbortsLoad() {
62+
Assume.assumeTrue("Model file not found", new java.io.File(TestConstants.MODEL_PATH).exists());
63+
64+
int gpuLayers = Integer.getInteger(TestConstants.PROP_TEST_NGL, TestConstants.DEFAULT_TEST_NGL);
65+
try {
66+
new LlamaModel(
67+
new ModelParameters()
68+
.setCtxSize(128)
69+
.setModel(TestConstants.MODEL_PATH)
70+
.setGpuLayers(gpuLayers)
71+
.setFit(false),
72+
progress -> false).close();
73+
fail("expected LlamaException when callback aborts load");
74+
} catch (LlamaException expected) {
75+
// pass
76+
}
77+
}
78+
79+
@Test
80+
public void nullCallbackBehavesAsDefault() {
81+
Assume.assumeTrue("Model file not found", new java.io.File(TestConstants.MODEL_PATH).exists());
82+
int gpuLayers = Integer.getInteger(TestConstants.PROP_TEST_NGL, TestConstants.DEFAULT_TEST_NGL);
83+
try (LlamaModel m = new LlamaModel(
84+
new ModelParameters()
85+
.setCtxSize(128)
86+
.setModel(TestConstants.MODEL_PATH)
87+
.setGpuLayers(gpuLayers)
88+
.setFit(false),
89+
null)) {
90+
// no callback wired; just verifies the null-overload routes to plain loadModel
91+
}
92+
}
93+
}

0 commit comments

Comments
 (0)