|
#include <atomic>
#include <cctype>
#include <cerrno>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
@@ -46,6 +47,7 @@ std::string g_last_output; |
46 | 47 | std::string g_last_piece; |
47 | 48 | std::string g_last_tokens_json = "[]"; |
48 | 49 | std::string g_last_detokenized; |
| 50 | +std::string g_last_embedding_json = "[]"; |
49 | 51 | std::string g_backend_json = "[]"; |
50 | 52 | std::string g_model_meta_json = "{}"; |
51 | 53 | std::vector<llama_token> g_cached_prompt_tokens; |
@@ -185,6 +187,7 @@ void free_runtime() { |
185 | 187 | g_last_piece.clear(); |
186 | 188 | g_last_tokens_json = "[]"; |
187 | 189 | g_last_detokenized.clear(); |
| 190 | + g_last_embedding_json = "[]"; |
188 | 191 | g_model_meta_json = "{}"; |
189 | 192 | g_cached_prompt_tokens.clear(); |
190 | 193 | } |
@@ -447,6 +450,35 @@ std::string serialize_tokens_json(const std::vector<llama_token> & tokens) { |
447 | 450 | return json; |
448 | 451 | } |
449 | 452 |
|
// Serializes an embedding vector as a JSON array of numbers.
//
// Uses the "%.9g" conversion instead of std::to_string: 9 significant
// digits are enough to round-trip any IEEE-754 binary32 value, whereas
// std::to_string emits fixed 6-decimal output, which flattens small
// embedding components to "0.000000" and pads integers with ".000000".
// Non-finite values (NaN/Inf) are emitted as 0, since "nan"/"inf" are
// not valid JSON number literals and would make the payload unparseable.
std::string serialize_embedding_json(const std::vector<float> & embedding) {
    std::string json = "[";
    char buf[32];
    for (size_t i = 0; i < embedding.size(); ++i) {
        if (i > 0) {
            json += ",";
        }
        const double dv = static_cast<double>(embedding[i]);
        if (std::isfinite(dv)) {
            std::snprintf(buf, sizeof(buf), "%.9g", dv);
            json += buf;
        } else {
            json += "0";
        }
    }
    json += "]";
    return json;
}
| 464 | + |
// Rescales `embedding` in place to unit L2 norm.
//
// The squared-norm accumulation and the reciprocal square root are done
// in double precision to limit rounding error before converting each
// component back to float. Empty or all-zero vectors are left untouched
// so we never divide by zero.
void normalize_embedding_inplace(std::vector<float> & embedding) {
    double sum_sq = 0.0;
    for (size_t i = 0; i < embedding.size(); ++i) {
        const double component = static_cast<double>(embedding[i]);
        sum_sq += component * component;
    }

    // Nothing to do for a zero-length or zero-magnitude vector.
    if (sum_sq <= 0.0) {
        return;
    }

    const double inv_norm = 1.0 / std::sqrt(sum_sq);
    for (size_t i = 0; i < embedding.size(); ++i) {
        embedding[i] = static_cast<float>(static_cast<double>(embedding[i]) * inv_norm);
    }
}
| 481 | + |
450 | 482 | void parse_token_list(const char * token_text, std::vector<llama_token> & out_tokens) { |
451 | 483 | out_tokens.clear(); |
452 | 484 | if (token_text == nullptr) { |
@@ -1267,6 +1299,142 @@ EMSCRIPTEN_KEEPALIVE const char * llamadart_webgpu_last_detokenized() { |
1267 | 1299 | return g_last_detokenized.c_str(); |
1268 | 1300 | } |
1269 | 1301 |
|
// Computes an embedding for `text` and stores it as a JSON float array in
// g_last_embedding_json (readable via llamadart_webgpu_last_embedding_json).
// Pass normalize != 0 to L2-normalize the result before serialization.
//
// Returns the embedding dimension (> 0) on success, or a negative code:
//   -1 runtime not loaded              -6 could not resolve embedding size
//   -2 text is null                    -7 batch buffer allocation failed
//   -3 encoder-decoder model           -8 encode/decode call failed
//   -4 tokenization failed             -9 embedding output unavailable
//   -5 tokenized to an empty sequence
EMSCRIPTEN_KEEPALIVE int32_t llamadart_webgpu_embed_to_json(
    const char * text,
    int32_t normalize) {
    clear_error();
    // Reset up front so a failed call never exposes a stale embedding.
    g_last_embedding_json = "[]";

    if (!ensure_loaded()) {
        return -1;
    }

    if (text == nullptr) {
        set_error("Text is null");
        return -2;
    }

    // Encoder-decoder (e.g. T5-style) models need a different extraction
    // flow; only encoder-only and decoder-only models are handled here.
    const bool has_encoder = llama_model_has_encoder(g_state.model);
    const bool has_decoder = llama_model_has_decoder(g_state.model);
    if (has_encoder && has_decoder) {
        set_error("Embedding extraction for encoder-decoder models is not supported");
        return -3;
    }
    const bool use_encoder_path = has_encoder && !has_decoder;

    // Tokenize with special tokens enabled (the `true` argument —
    // presumably add_special; confirm against tokenize_text's signature).
    std::vector<llama_token> tokens;
    if (!tokenize_text(std::string(text), true, tokens)) {
        return -4;
    }

    if (tokens.empty()) {
        set_error("Embedding input tokenized to an empty sequence");
        return -5;
    }

    // Prefer the model's dedicated output-embedding size; fall back to the
    // hidden size when the former is unavailable (<= 0).
    int32_t embedding_size = llama_model_n_embd_out(g_state.model);
    if (embedding_size <= 0) {
        embedding_size = llama_model_n_embd(g_state.model);
    }
    if (embedding_size <= 0) {
        set_error("Failed to resolve embedding dimension");
        return -6;
    }

    // Chunk size: clamp to the context's configured batch size, the token
    // count, and at least 1.
    int32_t max_batch = static_cast<int32_t>(llama_n_batch(g_state.ctx));
    if (max_batch <= 0) {
        max_batch = static_cast<int32_t>(tokens.size());
    }
    max_batch = std::max<int32_t>(1, std::min<int32_t>(max_batch, static_cast<int32_t>(tokens.size())));

    // One sequence, no extra embedding inputs. All buffers are checked
    // because llama_batch_init can partially fail.
    llama_batch batch = llama_batch_init(max_batch, 0, 1);
    if (batch.token == nullptr || batch.pos == nullptr ||
        batch.n_seq_id == nullptr || batch.seq_id == nullptr ||
        batch.logits == nullptr) {
        llama_batch_free(batch);
        set_error("Failed to allocate embedding batch buffers");
        return -7;
    }

    // rc doubles as the success value (embedding size) and, on failure
    // past this point, the error code.
    int32_t rc = embedding_size;

    // Start from a clean context: finish in-flight work, drop any KV/memory
    // state, and invalidate the generation-path prompt cache since this call
    // clobbers the context state.
    llama_synchronize(g_state.ctx);
    auto * memory = llama_get_memory(g_state.ctx);
    if (memory != nullptr) {
        llama_memory_clear(memory, false);
    }
    g_cached_prompt_tokens.clear();
    llama_set_embeddings(g_state.ctx, true);

    // Feed the prompt in max_batch-sized chunks. Every token requests
    // output (logits[i] = 1) so embeddings are available for any position.
    int32_t decoded_tokens = 0;
    while (decoded_tokens < static_cast<int32_t>(tokens.size())) {
        const int32_t remaining = static_cast<int32_t>(tokens.size()) - decoded_tokens;
        const int32_t chunk_token_count = std::min(max_batch, remaining);
        batch.n_tokens = chunk_token_count;

        for (int32_t i = 0; i < chunk_token_count; ++i) {
            const int32_t token_index = decoded_tokens + i;
            batch.token[i] = tokens[static_cast<size_t>(token_index)];
            batch.pos[i] = token_index;
            batch.n_seq_id[i] = 1;
            batch.seq_id[i][0] = 0;
            batch.logits[i] = 1;
        }

        const int status = use_encoder_path
            ? llama_encode(g_state.ctx, batch)
            : llama_decode(g_state.ctx, batch);
        if (status != 0) {
            set_error("Embedding forward pass failed");
            rc = -8;
            break;
        }

        decoded_tokens += chunk_token_count;
    }

    if (rc > 0) {
        // Pick the output that matches the context's pooling mode:
        // - no pooling: take the last token of the final chunk
        //   (batch.n_tokens still holds the last chunk's size here);
        // - pooled: take the per-sequence embedding for sequence 0.
        // Each path falls back to llama_get_embeddings() if the primary
        // accessor returns null.
        const enum llama_pooling_type pooling_type = llama_pooling_type(g_state.ctx);
        float * embedding_ptr = nullptr;
        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
            embedding_ptr = llama_get_embeddings_ith(g_state.ctx, batch.n_tokens - 1);
            if (embedding_ptr == nullptr) {
                embedding_ptr = llama_get_embeddings(g_state.ctx);
            }
        } else {
            embedding_ptr = llama_get_embeddings_seq(g_state.ctx, 0);
            if (embedding_ptr == nullptr) {
                embedding_ptr = llama_get_embeddings(g_state.ctx);
            }
        }

        if (embedding_ptr == nullptr) {
            set_error("Embedding output is unavailable");
            rc = -9;
        } else {
            // Copy out before serializing: the pointer refers to
            // context-owned storage.
            std::vector<float> embedding(
                embedding_ptr,
                embedding_ptr + static_cast<size_t>(embedding_size));
            if (normalize != 0) {
                normalize_embedding_inplace(embedding);
            }

            g_last_embedding_json = serialize_embedding_json(embedding);
        }
    }

    // Cleanup runs on success and failure alike: restore logits mode for
    // the generation path and release the batch buffers.
    {
        llama_set_embeddings(g_state.ctx, false);
        llama_batch_free(batch);
    }

    return rc;
}
| 1433 | + |
| 1434 | +EMSCRIPTEN_KEEPALIVE const char * llamadart_webgpu_last_embedding_json() { |
| 1435 | + return g_last_embedding_json.c_str(); |
| 1436 | +} |
| 1437 | + |
1270 | 1438 | EMSCRIPTEN_KEEPALIVE int32_t llamadart_webgpu_generate( |
1271 | 1439 | const char * prompt, |
1272 | 1440 | int32_t n_predict, |
|
0 commit comments