Skip to content

Commit 9d45262

Browse files
Merge pull request #289 from vaiju1981/finetuner-api
Configurable fine-tuning API (TrainingParameters + Optimizer)
2 parents 0b4b58e + 07f7788 commit 9d45262

8 files changed

Lines changed: 545 additions & 0 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ endif()
308308
add_library(jllama SHARED
309309
src/main/cpp/jllama.cpp
310310
src/main/cpp/tts_engine.cpp
311+
src/main/cpp/train_engine.cpp
311312
${JLLAMA_TTS_GEN_CPP}
312313
src/main/cpp/utils.hpp
313314
${llama.cpp_SOURCE_DIR}/tools/server/server-common.cpp

src/main/cpp/train_engine.cpp

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
#include "train_engine.h"
6+
7+
#include "common.h"
8+
#include "ggml-opt.h"
9+
#include "llama.h"
10+
11+
#include <nlohmann/json.hpp>
12+
13+
#include <jni.h>
14+
15+
#include <exception>
16+
#include <fstream>
17+
#include <iterator>
18+
#include <string>
19+
#include <vector>
20+
21+
using json = nlohmann::json;
22+
23+
namespace jllama_train {
24+
25+
bool finetune(const finetune_config &cfg, std::string &err) {
26+
common_params params;
27+
params.escape = false;
28+
params.model.path = cfg.model_path;
29+
params.out_file = cfg.output_path;
30+
params.n_ctx = cfg.n_ctx;
31+
params.n_gpu_layers = cfg.n_gpu_layers;
32+
params.val_split = cfg.val_split;
33+
if (cfg.n_batch > 0) {
34+
params.n_batch = cfg.n_batch;
35+
}
36+
if (cfg.n_ubatch > 0) {
37+
params.n_ubatch = cfg.n_ubatch;
38+
}
39+
40+
params.optimizer =
41+
cfg.optimizer == 1 ? GGML_OPT_OPTIMIZER_TYPE_SGD : GGML_OPT_OPTIMIZER_TYPE_ADAMW;
42+
params.lr.lr0 = cfg.learning_rate;
43+
params.lr.lr_min = cfg.lr_min;
44+
params.lr.decay_epochs = cfg.decay_epochs;
45+
params.lr.wd = cfg.weight_decay;
46+
params.lr.epochs = static_cast<unsigned>(cfg.epochs > 0 ? cfg.epochs : 1);
47+
params.lr.init(); // required after setting lr fields, before the optimizer reads get_lr()
48+
49+
// The corpus is either read from a file or supplied inline.
50+
if (!cfg.training_file.empty()) {
51+
std::ifstream in(cfg.training_file, std::ios::binary);
52+
if (!in) {
53+
err = "cannot open training file: " + cfg.training_file;
54+
return false;
55+
}
56+
params.prompt.assign(std::istreambuf_iterator<char>(in), std::istreambuf_iterator<char>());
57+
} else {
58+
params.prompt = cfg.training_text;
59+
}
60+
61+
// Training needs writable weights (mmap yields read-only pointers) and an f32 KV cache
62+
// (OUT_PROD has no f16 support) — same forced settings as upstream finetune.cpp.
63+
params.use_mmap = false;
64+
params.cache_type_k = GGML_TYPE_F32;
65+
params.cache_type_v = GGML_TYPE_F32;
66+
67+
llama_backend_init();
68+
llama_numa_init(params.numa);
69+
70+
common_init_result_ptr llama_init = common_init_from_params(params);
71+
llama_model *model = llama_init->model();
72+
llama_context *ctx = llama_init->context();
73+
if (model == nullptr || ctx == nullptr) {
74+
err = "failed to load model for training: " + cfg.model_path;
75+
return false;
76+
}
77+
78+
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
79+
if (tokens.size() < 2) {
80+
err = "training corpus produced too few tokens (need at least 2)";
81+
return false;
82+
}
83+
84+
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2);
85+
86+
llama_opt_params lopt_params = {
87+
/*n_ctx_train =*/0,
88+
/*param_filter =*/llama_opt_param_filter_all,
89+
/*param_filter_ud =*/nullptr,
90+
/*get_opt_pars =*/common_opt_lr_pars,
91+
/*get_opt_pars_ud =*/&params.lr,
92+
/*optimizer_type =*/params.optimizer,
93+
};
94+
llama_opt_init(ctx, model, lopt_params);
95+
96+
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
97+
98+
ggml_opt_result_t result_train = ggml_opt_result_init();
99+
ggml_opt_result_t result_eval = ggml_opt_result_init();
100+
101+
for (params.lr.epoch = 0; params.lr.epoch < params.lr.epochs; ++params.lr.epoch) {
102+
llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
103+
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
104+
ggml_opt_result_reset(result_train);
105+
ggml_opt_result_reset(result_eval);
106+
}
107+
108+
ggml_opt_result_free(result_train);
109+
ggml_opt_result_free(result_eval);
110+
ggml_opt_dataset_free(dataset);
111+
112+
llama_model_save_to_file(model, params.out_file.c_str());
113+
114+
// Deliberately NOT calling llama_backend_free(): other live llama contexts in this JVM
115+
// (e.g. an inference LlamaModel) may still depend on the initialized backend.
116+
return true;
117+
}
118+
119+
} // namespace jllama_train
120+
121+
extern "C" JNIEXPORT jstring JNICALL
122+
Java_net_ladenthin_llama_LlamaTrainer_finetuneNative(JNIEnv *env, jclass, jstring jconfig) {
123+
std::string config_json;
124+
if (jconfig != nullptr) {
125+
const char *c = env->GetStringUTFChars(jconfig, nullptr);
126+
if (c != nullptr) {
127+
config_json = c;
128+
env->ReleaseStringUTFChars(jconfig, c);
129+
}
130+
}
131+
132+
jllama_train::finetune_config cfg;
133+
try {
134+
const json j = json::parse(config_json);
135+
cfg.model_path = j.value("model_path", std::string());
136+
cfg.training_text = j.value("training_text", std::string());
137+
cfg.training_file = j.value("training_file", std::string());
138+
cfg.output_path = j.value("output_path", std::string());
139+
cfg.epochs = j.value("epochs", 2);
140+
cfg.learning_rate = j.value("learning_rate", 1e-5f);
141+
cfg.lr_min = j.value("lr_min", -1.0f);
142+
cfg.decay_epochs = j.value("decay_epochs", -1.0f);
143+
cfg.weight_decay = j.value("weight_decay", 0.0f);
144+
cfg.optimizer = j.value("optimizer", 0);
145+
cfg.n_ctx = j.value("n_ctx", 0);
146+
cfg.n_gpu_layers = j.value("n_gpu_layers", -1);
147+
cfg.val_split = j.value("val_split", 0.05f);
148+
cfg.n_batch = j.value("n_batch", 0);
149+
cfg.n_ubatch = j.value("n_ubatch", 0);
150+
} catch (const std::exception &e) {
151+
return env->NewStringUTF((std::string("invalid training config: ") + e.what()).c_str());
152+
}
153+
154+
std::string err;
155+
try {
156+
if (jllama_train::finetune(cfg, err)) {
157+
return env->NewStringUTF(""); // empty == success
158+
}
159+
} catch (const std::exception &e) {
160+
err = e.what();
161+
} catch (...) {
162+
err = "unknown C++ exception during fine-tuning";
163+
}
164+
return env->NewStringUTF(err.c_str());
165+
}

src/main/cpp/train_engine.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
//
5+
// Native fine-tuning engine (proof-of-concept): a self-contained wrapper over llama.cpp's
6+
// ggml-opt training path (llama_opt_init / llama_opt_epoch), mirroring upstream
7+
// examples/training/finetune.cpp. Loads its own model + context (independent of the inference
8+
// server_context in jllama.cpp), fine-tunes on a text corpus, and writes a new GGUF via
9+
// llama_model_save_to_file. Kept out of jllama.cpp so the JNI layer stays thin.
10+
11+
#ifndef JLLAMA_TRAIN_ENGINE_H
12+
#define JLLAMA_TRAIN_ENGINE_H
13+
14+
#include <string>
15+
16+
namespace jllama_train {
17+
18+
// One fine-tuning run's inputs.
19+
struct finetune_config {
20+
std::string model_path; // base GGUF to fine-tune
21+
std::string training_text; // corpus supplied inline (used when training_file is empty)
22+
std::string training_file; // corpus read from this path instead of training_text
23+
std::string output_path; // where the fine-tuned GGUF is written
24+
int epochs; // number of passes over the corpus (>= 1)
25+
float learning_rate; // lr at the first epoch
26+
float lr_min; // minimum lr for decay; < 0 = no decay
27+
float decay_epochs; // decay lr0 -> lr_min over this many epochs; <= 0 = disabled
28+
float weight_decay; // weight decay; 0 = disabled
29+
int optimizer; // ggml_opt_optimizer_type: 0 = AdamW, 1 = SGD
30+
int n_ctx; // context size; 0 = the model's trained context
31+
int n_gpu_layers; // layers offloaded to the GPU; -1 = auto
32+
float val_split; // fraction of the corpus held out for validation
33+
int n_batch; // logical batch size; 0 = native default
34+
int n_ubatch; // physical (micro) batch size; 0 = native default
35+
};
36+
37+
// Run one fine-tuning job end to end. Returns true on success; on failure returns false and sets
38+
// `err`. Not re-entrant; intended to be called off the JVM's critical threads (it blocks for the
39+
// full training run).
40+
bool finetune(const finetune_config &cfg, std::string &err);
41+
42+
} // namespace jllama_train
43+
44+
#endif // JLLAMA_TRAIN_ENGINE_H
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
package net.ladenthin.llama;
6+
7+
import java.nio.file.Path;
8+
import net.ladenthin.llama.exception.LlamaException;
9+
import net.ladenthin.llama.loader.LlamaLoader;
10+
import net.ladenthin.llama.parameters.TrainingParameters;
11+
12+
/**
13+
* In-process fine-tuning entry point, wrapping llama.cpp's ggml-opt training path
14+
* ({@code llama_opt_init} / {@code llama_opt_epoch}) the same way the upstream
15+
* {@code examples/training/finetune.cpp} tool does. Loads its own model and context (independent of
16+
* {@link LlamaModel}), fine-tunes on a text corpus, and writes a new GGUF.
17+
*
18+
* <p>Configure a run with {@link TrainingParameters} and pass it to {@link #finetune(TrainingParameters)}.
19+
* Full-model fine-tuning is compute- and memory-intensive and blocks for the whole run; upstream
20+
* training support is itself experimental.
21+
*/
22+
public final class LlamaTrainer {
23+
24+
static {
25+
LlamaLoader.initialize();
26+
}
27+
28+
private LlamaTrainer() {}
29+
30+
/**
31+
* Run one fine-tuning job to completion.
32+
*
33+
* @param parameters the training configuration (model, corpus, output, optimizer, schedule, ...)
34+
* @throws LlamaException if the model cannot be loaded or training fails
35+
*/
36+
public static void finetune(TrainingParameters parameters) {
37+
String error = finetuneNative(parameters.toJson());
38+
if (error != null && !error.isEmpty()) {
39+
throw new LlamaException(error);
40+
}
41+
}
42+
43+
/**
44+
* Convenience fine-tune with inline text and otherwise-default settings.
45+
*
46+
* @param model the base GGUF model to fine-tune
47+
* @param trainingText the training corpus (tokenized in-process)
48+
* @param output the path the fine-tuned GGUF is written to
49+
* @param epochs number of passes over the corpus (at least 1)
50+
* @param learningRate the AdamW learning rate at the first epoch (e.g. {@code 1e-5f})
51+
* @throws LlamaException if the model cannot be loaded or training fails
52+
*/
53+
public static void finetune(Path model, String trainingText, Path output, int epochs, float learningRate) {
54+
finetune(
55+
TrainingParameters.builder()
56+
.modelPath(model)
57+
.trainingText(trainingText)
58+
.outputPath(output)
59+
.epochs(epochs)
60+
.learningRate(learningRate)
61+
.build());
62+
}
63+
64+
private static native String finetuneNative(String configJson);
65+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
package net.ladenthin.llama.args;
6+
7+
/**
8+
* Optimizer used by {@link net.ladenthin.llama.LlamaTrainer} fine-tuning, mapping to llama.cpp's
9+
* {@code ggml_opt_optimizer_type}.
10+
*/
11+
public enum Optimizer {
12+
13+
/** Adam with decoupled weight decay ({@code GGML_OPT_OPTIMIZER_TYPE_ADAMW}). The default. */
14+
ADAMW(0),
15+
16+
/** Stochastic gradient descent ({@code GGML_OPT_OPTIMIZER_TYPE_SGD}). */
17+
SGD(1);
18+
19+
private final int nativeValue;
20+
21+
Optimizer(int nativeValue) {
22+
this.nativeValue = nativeValue;
23+
}
24+
25+
/**
26+
* The integer value passed to the native layer (matches the {@code ggml_opt_optimizer_type} enum).
27+
*
28+
* @return the native optimizer-type ordinal
29+
*/
30+
public int getNativeValue() {
31+
return nativeValue;
32+
}
33+
}

0 commit comments

Comments
 (0)