1313#include <math.h>
1414#include <string.h>
1515
16+ #define DBMEM_LOCAL_MIN_CONTEXT_TOKENS 128
17+ #define DBMEM_LOCAL_MAX_CONTEXT_TOKENS 8192
18+
19+ #if defined(_MSC_VER )
20+ #define DBMEM_THREAD_LOCAL __declspec(thread)
21+ #elif defined(__STDC_VERSION__ ) && __STDC_VERSION__ >= 201112L
22+ #define DBMEM_THREAD_LOCAL _Thread_local
23+ #else
24+ #define DBMEM_THREAD_LOCAL __thread
25+ #endif
26+
27+ static DBMEM_THREAD_LOCAL bool dbmem_llama_diag_enabled = false;
28+ static DBMEM_THREAD_LOCAL char dbmem_llama_diag [DBMEM_ERRBUF_SIZE ];
29+ static DBMEM_THREAD_LOCAL size_t dbmem_llama_diag_len = 0 ;
30+
1631struct dbmem_local_engine_t {
1732 dbmem_context * context ;
1833
@@ -26,14 +41,17 @@ struct dbmem_local_engine_t {
2641 // Model info
2742 int n_embd ; // Embedding dimension (e.g., 768 for nomic-embed)
2843 int n_ctx ; // Maximum context length in tokens
44+ int n_ubatch ; // Maximum physical batch size for encoder input
2945 bool is_encoder_only ; // True for BERT-style models, false for GPT-style
3046
3147 // Settings
3248 bool normalize ; // Whether to L2 normalize output embeddings
3349
3450 // Reusable buffers (avoid repeated allocations)
3551 llama_token * tokens ; // Pre-allocated buffer for tokenized input
36- int tokens_capacity ; // Size of tokens buffer (equals n_ctx)
52+ int tokens_capacity ; // Size of tokens buffer, capped by n_ubatch
53+ struct llama_batch batch ; // Pre-allocated llama.cpp batch with sequence metadata
54+ bool batch_initialized ; // True when batch must be freed
3755 float * embedding ; // Pre-allocated buffer for output embedding (n_embd floats)
3856
3957 // Statistics
@@ -75,75 +93,130 @@ static void dbmem_embedding_normalize (float *vec, int n) {
7593}
7694
7795void dbmem_logger (enum ggml_log_level level , const char * text , void * user_data ) {
78- dbmem_local_engine_t * engine = (dbmem_local_engine_t * )user_data ;
79- //if (ai->db == NULL) return;
80- //if ((level == GGML_LOG_LEVEL_INFO) && (ai->options.log_info == false)) return;
81-
82- const char * type = NULL ;
83- switch (level ) {
84- case GGML_LOG_LEVEL_NONE : type = "NONE" ; break ;
85- case GGML_LOG_LEVEL_DEBUG : type = "DEBUG" ; break ;
86- case GGML_LOG_LEVEL_INFO : type = "INFO" ; break ;
87- case GGML_LOG_LEVEL_WARN : type = "WARNING" ; break ;
88- case GGML_LOG_LEVEL_ERROR : type = "ERROR" ; break ;
89- case GGML_LOG_LEVEL_CONT : type = NULL ; break ;
96+ UNUSED_PARAM (user_data );
97+ if (!dbmem_llama_diag_enabled || !text ) return ;
98+
99+ if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_CONT ) {
100+ size_t remaining = sizeof (dbmem_llama_diag ) - dbmem_llama_diag_len ;
101+ if (remaining > 1 ) {
102+ int written = snprintf (dbmem_llama_diag + dbmem_llama_diag_len , remaining , "%s" , text );
103+ if (written > 0 ) {
104+ size_t used = (size_t )written ;
105+ if (used >= remaining ) {
106+ dbmem_llama_diag_len = sizeof (dbmem_llama_diag ) - 1 ;
107+ } else {
108+ dbmem_llama_diag_len += used ;
109+ }
110+ }
111+ }
90112 }
91-
92- // DEBUG
93- // printf("%s %s\n", type, text);
94-
95- //const char *values[] = {type, text};
96- //int types[] = {(type == NULL) ? SQLITE_NULL : SQLITE_TEXT, SQLITE_TEXT};
97- //int lens[] = {-1, -1};
98- //sqlite_db_write(NULL, ai->db, LOG_TABLE_INSERT_STMT, values, types, lens, 2);
99113}
100114
101115// MARK: -
102116
117+ static void dbmem_llama_diag_begin (void ) {
118+ dbmem_llama_diag [0 ] = 0 ;
119+ dbmem_llama_diag_len = 0 ;
120+ dbmem_llama_diag_enabled = true;
121+ }
122+
123+ static void dbmem_llama_diag_end (void ) {
124+ dbmem_llama_diag_enabled = false;
125+ }
126+
127+ static const char * dbmem_llama_diag_message (void ) {
128+ return dbmem_llama_diag [0 ] ? dbmem_llama_diag : NULL ;
129+ }
130+
103131static void dbmem_local_set_error (dbmem_local_engine_t * engine , const char * message ) {
104132 if (!engine || !engine -> context ) return ;
105133 dbmem_context_set_error (engine -> context , message );
106134}
107135
108- dbmem_local_engine_t * dbmem_local_engine_init (void * ctx , const char * model_path , char err_msg [DBMEM_ERRBUF_SIZE ]) {
136+ static bool dbmem_local_batch_prepare (dbmem_local_engine_t * engine , int n_tokens ) {
137+ if (!engine || !engine -> batch_initialized ) return false;
138+ if (n_tokens <= 0 || n_tokens > engine -> tokens_capacity ) return false;
139+
140+ engine -> batch .n_tokens = 0 ;
141+ for (int i = 0 ; i < n_tokens ; i ++ ) {
142+ engine -> batch .token [i ] = engine -> tokens [i ];
143+ engine -> batch .pos [i ] = i ;
144+ engine -> batch .n_seq_id [i ] = 1 ;
145+ engine -> batch .seq_id [i ][0 ] = 0 ;
146+ engine -> batch .logits [i ] = 1 ;
147+ engine -> batch .n_tokens ++ ;
148+ }
149+
150+ return true;
151+ }
152+
153+ static int dbmem_local_process_batch (dbmem_local_engine_t * engine ) {
154+ if (engine -> is_encoder_only ) {
155+ return llama_encode (engine -> ctx , engine -> batch );
156+ }
157+ return llama_decode (engine -> ctx , engine -> batch );
158+ }
159+
160+ dbmem_local_engine_t * dbmem_local_engine_init (void * ctx , const char * model_path , int max_context_tokens , char err_msg [DBMEM_ERRBUF_SIZE ]) {
109161 dbmem_local_engine_t * engine = (dbmem_local_engine_t * )dbmemory_zeroalloc (sizeof (dbmem_local_engine_t ));
110162 if (!engine ) return NULL ;
111163 engine -> context = (dbmem_context * )ctx ;
112164
113165 // set logger
114- llama_log_set (dbmem_logger , engine );
166+ llama_log_set (dbmem_logger , NULL );
167+ dbmem_llama_diag_begin ();
115168
116169 // Initialize backend
117170 llama_backend_init ();
118171
119172 // Load model
120173 struct llama_model_params model_params = llama_model_default_params ();
174+ model_params .n_gpu_layers = 0 ;
175+ model_params .split_mode = LLAMA_SPLIT_MODE_NONE ;
176+ model_params .main_gpu = -1 ;
121177 engine -> model = llama_model_load_from_file (model_path , model_params );
122178 if (!engine -> model ) {
123- snprintf (err_msg , DBMEM_ERRBUF_SIZE , "Failed to load model: %s" , model_path );
179+ const char * diag = dbmem_llama_diag_message ();
180+ if (diag ) {
181+ snprintf (err_msg , DBMEM_ERRBUF_SIZE , "Failed to load model: %s: %s" , model_path , diag );
182+ } else {
183+ snprintf (err_msg , DBMEM_ERRBUF_SIZE , "Failed to load model: %s" , model_path );
184+ }
124185 goto cleanup ;
125186 }
126187
127188 // Get model's native context length
128189 int n_ctx_train = llama_model_n_ctx_train (engine -> model );
190+ int n_ctx = max_context_tokens * 4 ;
191+ if (n_ctx < DBMEM_LOCAL_MIN_CONTEXT_TOKENS ) n_ctx = DBMEM_LOCAL_MIN_CONTEXT_TOKENS ;
192+ if (n_ctx > DBMEM_LOCAL_MAX_CONTEXT_TOKENS ) n_ctx = DBMEM_LOCAL_MAX_CONTEXT_TOKENS ;
193+ if (n_ctx_train > 0 && n_ctx > n_ctx_train ) n_ctx = n_ctx_train ;
129194
130195 // Create context
131196 struct llama_context_params ctx_params = llama_context_default_params ();
132197 ctx_params .embeddings = true;
133- ctx_params .n_ctx = n_ctx_train ;
134- ctx_params .n_batch = n_ctx_train ;
135- ctx_params .n_ubatch = n_ctx_train ;
198+ ctx_params .n_ctx = n_ctx ;
199+ ctx_params .n_batch = n_ctx ;
200+ ctx_params .n_ubatch = n_ctx ;
201+ ctx_params .offload_kqv = false;
202+ ctx_params .op_offload = false;
136203
137204 engine -> ctx = llama_init_from_model (engine -> model , ctx_params );
138205 if (!engine -> ctx ) {
139- snprintf (err_msg , DBMEM_ERRBUF_SIZE , "Failed to create context" );
206+ const char * diag = dbmem_llama_diag_message ();
207+ if (diag ) {
208+ snprintf (err_msg , DBMEM_ERRBUF_SIZE , "Failed to create context: %s" , diag );
209+ } else {
210+ snprintf (err_msg , DBMEM_ERRBUF_SIZE , "Failed to create context" );
211+ }
140212 goto cleanup ;
141213 }
142214
143215 // Get model info
144216 engine -> vocab = llama_model_get_vocab (engine -> model );
145- engine -> n_embd = llama_model_n_embd (engine -> model );
217+ engine -> n_embd = llama_model_n_embd_out (engine -> model );
146218 engine -> n_ctx = llama_n_ctx (engine -> ctx );
219+ engine -> n_ubatch = llama_n_ubatch (engine -> ctx );
147220 engine -> pooling = llama_pooling_type (engine -> ctx );
148221 engine -> mem = llama_get_memory (engine -> ctx );
149222
@@ -159,12 +232,22 @@ dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path
159232
160233 // Allocate token buffer
161234 engine -> tokens_capacity = engine -> n_ctx ;
235+ if (engine -> n_ubatch > 0 && engine -> tokens_capacity > engine -> n_ubatch ) {
236+ engine -> tokens_capacity = engine -> n_ubatch ;
237+ }
162238 engine -> tokens = (llama_token * )dbmemory_alloc (sizeof (llama_token ) * engine -> tokens_capacity );
163239 if (!engine -> tokens ) {
164240 snprintf (err_msg , DBMEM_ERRBUF_SIZE , "Failed to allocate token buffer" );
165241 goto cleanup ;
166242 }
167243
244+ engine -> batch = llama_batch_init (engine -> tokens_capacity , 0 , 1 );
245+ engine -> batch_initialized = true;
246+ if (!engine -> batch .token || !engine -> batch .pos || !engine -> batch .n_seq_id || !engine -> batch .seq_id || !engine -> batch .logits ) {
247+ snprintf (err_msg , DBMEM_ERRBUF_SIZE , "Failed to allocate llama batch" );
248+ goto cleanup ;
249+ }
250+
168251 // Allocate single embedding buffer
169252 engine -> embedding = (float * )dbmemory_alloc (sizeof (float ) * engine -> n_embd );
170253 if (!engine -> embedding ) {
@@ -177,9 +260,11 @@ dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path
177260 engine -> total_tokens_processed = 0 ;
178261 engine -> total_embeddings_generated = 0 ;
179262
263+ dbmem_llama_diag_end ();
180264 return engine ;
181265
182266cleanup :
267+ dbmem_llama_diag_end ();
183268 dbmem_local_engine_free (engine );
184269 return NULL ;
185270}
@@ -190,17 +275,8 @@ bool dbmem_local_engine_warmup (dbmem_local_engine_t *engine) {
190275
191276 const char * warmup_text = "Warmup" ;
192277 int warmup_tokens = llama_tokenize (engine -> vocab , warmup_text , (int32_t )strlen (warmup_text ), engine -> tokens , engine -> tokens_capacity , true, true);
193- if (warmup_tokens > 0 ) {
194- struct llama_batch batch = {
195- .n_tokens = warmup_tokens ,
196- .token = engine -> tokens ,
197- .embd = NULL ,
198- .pos = NULL ,
199- .n_seq_id = NULL ,
200- .seq_id = NULL ,
201- .logits = NULL ,
202- };
203- llama_encode (engine -> ctx , batch );
278+ if (warmup_tokens > 0 && dbmem_local_batch_prepare (engine , warmup_tokens )) {
279+ dbmem_local_process_batch (engine );
204280
205281 if (engine -> mem != NULL ) {
206282 llama_memory_clear (engine -> mem , true);
@@ -215,40 +291,56 @@ int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *tex
215291 if (text_len == -1 ) text_len = (int )strlen (text );
216292 if (text_len == 0 ) return 0 ;
217293
294+ bool truncated = false;
295+
218296 // Tokenize
219297 int n_tokens = llama_tokenize (engine -> vocab , text , text_len , engine -> tokens , engine -> tokens_capacity , true, true);
220298 if (n_tokens < 0 ) {
221- dbmem_local_set_error (engine , "Tokenization failed (text too long?)" );
222- return -1 ;
299+ int needed = - n_tokens ;
300+ if (needed <= 0 ) {
301+ dbmem_local_set_error (engine , "Tokenization failed" );
302+ return -1 ;
303+ }
304+
305+ llama_token * all_tokens = (llama_token * )dbmemory_alloc (sizeof (llama_token ) * needed );
306+ if (!all_tokens ) {
307+ dbmem_local_set_error (engine , "Failed to allocate token overflow buffer" );
308+ return -1 ;
309+ }
310+
311+ int full_tokens = llama_tokenize (engine -> vocab , text , text_len , all_tokens , needed , true, true);
312+ if (full_tokens < 0 ) {
313+ dbmemory_free (all_tokens );
314+ dbmem_local_set_error (engine , "Tokenization failed" );
315+ return -1 ;
316+ }
317+
318+ n_tokens = engine -> tokens_capacity ;
319+ memcpy (engine -> tokens , all_tokens , sizeof (llama_token ) * n_tokens );
320+ dbmemory_free (all_tokens );
321+ truncated = true;
223322 }
224323
225324 // Handle token overflow: truncate to max context size
226- bool truncated = false;
227- if (n_tokens > engine -> n_ctx ) {
325+ if (n_tokens > engine -> tokens_capacity ) {
228326 truncated = true;
229- n_tokens = engine -> n_ctx ;
327+ n_tokens = engine -> tokens_capacity ;
230328 }
231329
232- // Create batch
233- struct llama_batch batch = {
234- .n_tokens = n_tokens ,
235- .token = engine -> tokens ,
236- .embd = NULL ,
237- .pos = NULL ,
238- .n_seq_id = NULL ,
239- .seq_id = NULL ,
240- .logits = NULL ,
241- };
330+ if (!dbmem_local_batch_prepare (engine , n_tokens )) {
331+ dbmem_local_set_error (engine , "Failed to prepare llama batch" );
332+ return -1 ;
333+ }
242334
243335 // Clear memory
244336 if (engine -> mem != NULL ) {
245337 llama_memory_clear (engine -> mem , true);
246338 }
247339
248340 // Encode
249- int ret = llama_encode (engine -> ctx , batch );
341+ int ret = dbmem_local_process_batch (engine );
250342 if (ret != 0 ) {
251- dbmem_local_set_error (engine , "Llama_encode failed" );
343+ dbmem_local_set_error (engine , "llama batch processing failed" );
252344 return -1 ;
253345 }
254346
@@ -297,6 +389,11 @@ void dbmem_local_engine_free (dbmem_local_engine_t *engine) {
297389 dbmemory_free (engine -> tokens );
298390 engine -> tokens = NULL ;
299391 }
392+ if (engine -> batch_initialized ) {
393+ llama_batch_free (engine -> batch );
394+ memset (& engine -> batch , 0 , sizeof (engine -> batch ));
395+ engine -> batch_initialized = false;
396+ }
300397 if (engine -> ctx ) {
301398 llama_free (engine -> ctx );
302399 engine -> ctx = NULL ;
0 commit comments