Skip to content

Commit 2e211e8

Browse files
authored
fix(local): harden embedding batch and context handling (#7)
Skip whitespace-only chunks before embedding. Size the local llama context from the configured chunk token window (clamped to the model's training context) instead of always using the full training context, and keep the context, batch, and micro-batch limits (n_ctx / n_batch / n_ubatch) equal to avoid encoder assertions. Capture llama diagnostics in thread-local storage instead of passing the engine through the process-global logger user_data, rebuild the local engine when the token window options change, and invalidate cached local embeddings after a context rebuild so stale embeddings are not reused.
1 parent 653ec34 commit 2e211e8

6 files changed

Lines changed: 304 additions & 75 deletions

File tree

src/dbmem-embed.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ typedef struct {
2222
float *embedding; // Pointer to embedding (points to engine's buffer, do not free)
2323
} embedding_result_t;
2424

25-
dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, char err_msg[DBMEM_ERRBUF_SIZE]);
25+
dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, int max_context_tokens, char err_msg[DBMEM_ERRBUF_SIZE]);
2626
int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *text, int text_len, embedding_result_t *result);
2727
bool dbmem_local_engine_warmup (dbmem_local_engine_t *engine);
2828
void dbmem_local_engine_free (dbmem_local_engine_t *engine);

src/dbmem-lembed.c

Lines changed: 154 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,21 @@
1313
#include <math.h>
1414
#include <string.h>
1515

16+
#define DBMEM_LOCAL_MIN_CONTEXT_TOKENS 128
17+
#define DBMEM_LOCAL_MAX_CONTEXT_TOKENS 8192
18+
19+
#if defined(_MSC_VER)
20+
#define DBMEM_THREAD_LOCAL __declspec(thread)
21+
#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
22+
#define DBMEM_THREAD_LOCAL _Thread_local
23+
#else
24+
#define DBMEM_THREAD_LOCAL __thread
25+
#endif
26+
27+
static DBMEM_THREAD_LOCAL bool dbmem_llama_diag_enabled = false;
28+
static DBMEM_THREAD_LOCAL char dbmem_llama_diag[DBMEM_ERRBUF_SIZE];
29+
static DBMEM_THREAD_LOCAL size_t dbmem_llama_diag_len = 0;
30+
1631
struct dbmem_local_engine_t {
1732
dbmem_context *context;
1833

@@ -26,14 +41,17 @@ struct dbmem_local_engine_t {
2641
// Model info
2742
int n_embd; // Embedding dimension (e.g., 768 for nomic-embed)
2843
int n_ctx; // Maximum context length in tokens
44+
int n_ubatch; // Maximum physical batch size for encoder input
2945
bool is_encoder_only; // True for BERT-style models, false for GPT-style
3046

3147
// Settings
3248
bool normalize; // Whether to L2 normalize output embeddings
3349

3450
// Reusable buffers (avoid repeated allocations)
3551
llama_token *tokens; // Pre-allocated buffer for tokenized input
36-
int tokens_capacity; // Size of tokens buffer (equals n_ctx)
52+
int tokens_capacity; // Size of tokens buffer, capped by n_ubatch
53+
struct llama_batch batch; // Pre-allocated llama.cpp batch with sequence metadata
54+
bool batch_initialized; // True when batch must be freed
3755
float *embedding; // Pre-allocated buffer for output embedding (n_embd floats)
3856

3957
// Statistics
@@ -75,75 +93,130 @@ static void dbmem_embedding_normalize (float *vec, int n) {
7593
}
7694

7795
void dbmem_logger (enum ggml_log_level level, const char *text, void *user_data) {
78-
dbmem_local_engine_t *engine = (dbmem_local_engine_t *)user_data;
79-
//if (ai->db == NULL) return;
80-
//if ((level == GGML_LOG_LEVEL_INFO) && (ai->options.log_info == false)) return;
81-
82-
const char *type = NULL;
83-
switch (level) {
84-
case GGML_LOG_LEVEL_NONE: type = "NONE"; break;
85-
case GGML_LOG_LEVEL_DEBUG: type = "DEBUG"; break;
86-
case GGML_LOG_LEVEL_INFO: type = "INFO"; break;
87-
case GGML_LOG_LEVEL_WARN: type = "WARNING"; break;
88-
case GGML_LOG_LEVEL_ERROR: type = "ERROR"; break;
89-
case GGML_LOG_LEVEL_CONT: type = NULL; break;
96+
UNUSED_PARAM(user_data);
97+
if (!dbmem_llama_diag_enabled || !text) return;
98+
99+
if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_CONT) {
100+
size_t remaining = sizeof(dbmem_llama_diag) - dbmem_llama_diag_len;
101+
if (remaining > 1) {
102+
int written = snprintf(dbmem_llama_diag + dbmem_llama_diag_len, remaining, "%s", text);
103+
if (written > 0) {
104+
size_t used = (size_t)written;
105+
if (used >= remaining) {
106+
dbmem_llama_diag_len = sizeof(dbmem_llama_diag) - 1;
107+
} else {
108+
dbmem_llama_diag_len += used;
109+
}
110+
}
111+
}
90112
}
91-
92-
// DEBUG
93-
// printf("%s %s\n", type, text);
94-
95-
//const char *values[] = {type, text};
96-
//int types[] = {(type == NULL) ? SQLITE_NULL : SQLITE_TEXT, SQLITE_TEXT};
97-
//int lens[] = {-1, -1};
98-
//sqlite_db_write(NULL, ai->db, LOG_TABLE_INSERT_STMT, values, types, lens, 2);
99113
}
100114

101115
// MARK: -
102116

117+
static void dbmem_llama_diag_begin(void) {
118+
dbmem_llama_diag[0] = 0;
119+
dbmem_llama_diag_len = 0;
120+
dbmem_llama_diag_enabled = true;
121+
}
122+
123+
static void dbmem_llama_diag_end(void) {
124+
dbmem_llama_diag_enabled = false;
125+
}
126+
127+
static const char *dbmem_llama_diag_message(void) {
128+
return dbmem_llama_diag[0] ? dbmem_llama_diag : NULL;
129+
}
130+
103131
static void dbmem_local_set_error(dbmem_local_engine_t *engine, const char *message) {
104132
if (!engine || !engine->context) return;
105133
dbmem_context_set_error(engine->context, message);
106134
}
107135

108-
dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, char err_msg[DBMEM_ERRBUF_SIZE]) {
136+
static bool dbmem_local_batch_prepare(dbmem_local_engine_t *engine, int n_tokens) {
137+
if (!engine || !engine->batch_initialized) return false;
138+
if (n_tokens <= 0 || n_tokens > engine->tokens_capacity) return false;
139+
140+
engine->batch.n_tokens = 0;
141+
for (int i = 0; i < n_tokens; i++) {
142+
engine->batch.token[i] = engine->tokens[i];
143+
engine->batch.pos[i] = i;
144+
engine->batch.n_seq_id[i] = 1;
145+
engine->batch.seq_id[i][0] = 0;
146+
engine->batch.logits[i] = 1;
147+
engine->batch.n_tokens++;
148+
}
149+
150+
return true;
151+
}
152+
153+
static int dbmem_local_process_batch(dbmem_local_engine_t *engine) {
154+
if (engine->is_encoder_only) {
155+
return llama_encode(engine->ctx, engine->batch);
156+
}
157+
return llama_decode(engine->ctx, engine->batch);
158+
}
159+
160+
dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, int max_context_tokens, char err_msg[DBMEM_ERRBUF_SIZE]) {
109161
dbmem_local_engine_t *engine = (dbmem_local_engine_t *)dbmemory_zeroalloc(sizeof(dbmem_local_engine_t));
110162
if (!engine) return NULL;
111163
engine->context = (dbmem_context *)ctx;
112164

113165
// set logger
114-
llama_log_set(dbmem_logger, engine);
166+
llama_log_set(dbmem_logger, NULL);
167+
dbmem_llama_diag_begin();
115168

116169
// Initialize backend
117170
llama_backend_init();
118171

119172
// Load model
120173
struct llama_model_params model_params = llama_model_default_params();
174+
model_params.n_gpu_layers = 0;
175+
model_params.split_mode = LLAMA_SPLIT_MODE_NONE;
176+
model_params.main_gpu = -1;
121177
engine->model = llama_model_load_from_file(model_path, model_params);
122178
if (!engine->model) {
123-
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to load model: %s", model_path);
179+
const char *diag = dbmem_llama_diag_message();
180+
if (diag) {
181+
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to load model: %s: %s", model_path, diag);
182+
} else {
183+
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to load model: %s", model_path);
184+
}
124185
goto cleanup;
125186
}
126187

127188
// Get model's native context length
128189
int n_ctx_train = llama_model_n_ctx_train(engine->model);
190+
int n_ctx = max_context_tokens * 4;
191+
if (n_ctx < DBMEM_LOCAL_MIN_CONTEXT_TOKENS) n_ctx = DBMEM_LOCAL_MIN_CONTEXT_TOKENS;
192+
if (n_ctx > DBMEM_LOCAL_MAX_CONTEXT_TOKENS) n_ctx = DBMEM_LOCAL_MAX_CONTEXT_TOKENS;
193+
if (n_ctx_train > 0 && n_ctx > n_ctx_train) n_ctx = n_ctx_train;
129194

130195
// Create context
131196
struct llama_context_params ctx_params = llama_context_default_params();
132197
ctx_params.embeddings = true;
133-
ctx_params.n_ctx = n_ctx_train;
134-
ctx_params.n_batch = n_ctx_train;
135-
ctx_params.n_ubatch = n_ctx_train;
198+
ctx_params.n_ctx = n_ctx;
199+
ctx_params.n_batch = n_ctx;
200+
ctx_params.n_ubatch = n_ctx;
201+
ctx_params.offload_kqv = false;
202+
ctx_params.op_offload = false;
136203

137204
engine->ctx = llama_init_from_model(engine->model, ctx_params);
138205
if (!engine->ctx) {
139-
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to create context");
206+
const char *diag = dbmem_llama_diag_message();
207+
if (diag) {
208+
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to create context: %s", diag);
209+
} else {
210+
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to create context");
211+
}
140212
goto cleanup;
141213
}
142214

143215
// Get model info
144216
engine->vocab = llama_model_get_vocab(engine->model);
145-
engine->n_embd = llama_model_n_embd(engine->model);
217+
engine->n_embd = llama_model_n_embd_out(engine->model);
146218
engine->n_ctx = llama_n_ctx(engine->ctx);
219+
engine->n_ubatch = llama_n_ubatch(engine->ctx);
147220
engine->pooling = llama_pooling_type(engine->ctx);
148221
engine->mem = llama_get_memory(engine->ctx);
149222

@@ -159,12 +232,22 @@ dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path
159232

160233
// Allocate token buffer
161234
engine->tokens_capacity = engine->n_ctx;
235+
if (engine->n_ubatch > 0 && engine->tokens_capacity > engine->n_ubatch) {
236+
engine->tokens_capacity = engine->n_ubatch;
237+
}
162238
engine->tokens = (llama_token *)dbmemory_alloc(sizeof(llama_token) * engine->tokens_capacity);
163239
if (!engine->tokens) {
164240
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to allocate token buffer");
165241
goto cleanup;
166242
}
167243

244+
engine->batch = llama_batch_init(engine->tokens_capacity, 0, 1);
245+
engine->batch_initialized = true;
246+
if (!engine->batch.token || !engine->batch.pos || !engine->batch.n_seq_id || !engine->batch.seq_id || !engine->batch.logits) {
247+
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to allocate llama batch");
248+
goto cleanup;
249+
}
250+
168251
// Allocate single embedding buffer
169252
engine->embedding = (float *)dbmemory_alloc(sizeof(float) * engine->n_embd);
170253
if (!engine->embedding) {
@@ -177,9 +260,11 @@ dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path
177260
engine->total_tokens_processed = 0;
178261
engine->total_embeddings_generated = 0;
179262

263+
dbmem_llama_diag_end();
180264
return engine;
181265

182266
cleanup:
267+
dbmem_llama_diag_end();
183268
dbmem_local_engine_free(engine);
184269
return NULL;
185270
}
@@ -190,17 +275,8 @@ bool dbmem_local_engine_warmup (dbmem_local_engine_t *engine) {
190275

191276
const char *warmup_text = "Warmup";
192277
int warmup_tokens = llama_tokenize(engine->vocab, warmup_text, (int32_t)strlen(warmup_text), engine->tokens, engine->tokens_capacity, true, true);
193-
if (warmup_tokens > 0) {
194-
struct llama_batch batch = {
195-
.n_tokens = warmup_tokens,
196-
.token = engine->tokens,
197-
.embd = NULL,
198-
.pos = NULL,
199-
.n_seq_id = NULL,
200-
.seq_id = NULL,
201-
.logits = NULL,
202-
};
203-
llama_encode(engine->ctx, batch);
278+
if (warmup_tokens > 0 && dbmem_local_batch_prepare(engine, warmup_tokens)) {
279+
dbmem_local_process_batch(engine);
204280

205281
if (engine->mem != NULL) {
206282
llama_memory_clear(engine->mem, true);
@@ -215,40 +291,56 @@ int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *tex
215291
if (text_len == -1) text_len = (int)strlen(text);
216292
if (text_len == 0) return 0;
217293

294+
bool truncated = false;
295+
218296
// Tokenize
219297
int n_tokens = llama_tokenize(engine->vocab, text, text_len, engine->tokens, engine->tokens_capacity, true, true);
220298
if (n_tokens < 0) {
221-
dbmem_local_set_error(engine, "Tokenization failed (text too long?)");
222-
return -1;
299+
int needed = -n_tokens;
300+
if (needed <= 0) {
301+
dbmem_local_set_error(engine, "Tokenization failed");
302+
return -1;
303+
}
304+
305+
llama_token *all_tokens = (llama_token *)dbmemory_alloc(sizeof(llama_token) * needed);
306+
if (!all_tokens) {
307+
dbmem_local_set_error(engine, "Failed to allocate token overflow buffer");
308+
return -1;
309+
}
310+
311+
int full_tokens = llama_tokenize(engine->vocab, text, text_len, all_tokens, needed, true, true);
312+
if (full_tokens < 0) {
313+
dbmemory_free(all_tokens);
314+
dbmem_local_set_error(engine, "Tokenization failed");
315+
return -1;
316+
}
317+
318+
n_tokens = engine->tokens_capacity;
319+
memcpy(engine->tokens, all_tokens, sizeof(llama_token) * n_tokens);
320+
dbmemory_free(all_tokens);
321+
truncated = true;
223322
}
224323

225324
// Handle token overflow: truncate to max context size
226-
bool truncated = false;
227-
if (n_tokens > engine->n_ctx) {
325+
if (n_tokens > engine->tokens_capacity) {
228326
truncated = true;
229-
n_tokens = engine->n_ctx;
327+
n_tokens = engine->tokens_capacity;
230328
}
231329

232-
// Create batch
233-
struct llama_batch batch = {
234-
.n_tokens = n_tokens,
235-
.token = engine->tokens,
236-
.embd = NULL,
237-
.pos = NULL,
238-
.n_seq_id = NULL,
239-
.seq_id = NULL,
240-
.logits = NULL,
241-
};
330+
if (!dbmem_local_batch_prepare(engine, n_tokens)) {
331+
dbmem_local_set_error(engine, "Failed to prepare llama batch");
332+
return -1;
333+
}
242334

243335
// Clear memory
244336
if (engine->mem != NULL) {
245337
llama_memory_clear(engine->mem, true);
246338
}
247339

248340
// Encode
249-
int ret = llama_encode(engine->ctx, batch);
341+
int ret = dbmem_local_process_batch(engine);
250342
if (ret != 0) {
251-
dbmem_local_set_error(engine, "Llama_encode failed");
343+
dbmem_local_set_error(engine, "llama batch processing failed");
252344
return -1;
253345
}
254346

@@ -297,6 +389,11 @@ void dbmem_local_engine_free (dbmem_local_engine_t *engine) {
297389
dbmemory_free(engine->tokens);
298390
engine->tokens = NULL;
299391
}
392+
if (engine->batch_initialized) {
393+
llama_batch_free(engine->batch);
394+
memset(&engine->batch, 0, sizeof(engine->batch));
395+
engine->batch_initialized = false;
396+
}
300397
if (engine->ctx) {
301398
llama_free(engine->ctx);
302399
engine->ctx = NULL;

src/dbmem-parser.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,8 +1110,12 @@ int dbmem_parse (const char *md, size_t md_len, dbmem_parse_settings *settings)
11101110
src_len = src_end - src_off;
11111111
}
11121112

1113-
// Invoke callback
1114-
if (settings->callback) {
1113+
// Invoke callback (skip whitespace-only chunks)
1114+
bool has_text = false;
1115+
for (size_t k = 0; k < chunk_len; k++) {
1116+
if (!isspace((unsigned char)chunk_text[k])) { has_text = true; break; }
1117+
}
1118+
if (has_text && settings->callback) {
11151119
rc = settings->callback(chunk_text, chunk_len, src_off, src_len, settings->xdata, i);
11161120
if (rc != 0) break;
11171121
}

0 commit comments

Comments
 (0)