Commit caaff1b

feat: add embedding support to webgpu bridge

1 parent 79d91c2 commit caaff1b

3 files changed

Lines changed: 244 additions & 1 deletion


CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ set(LLAMADART_WEBGPU_LINK_OPTIONS
   "-sEXPORT_NAME=createLlamaWebGpuCoreModule"
   "-sENVIRONMENT=web,worker"
   "-sEXPORTED_RUNTIME_METHODS=['FS','ccall','UTF8ToString']"
-  "-sEXPORTED_FUNCTIONS=['_main','_llamadart_webgpu_probe','_llamadart_webgpu_backends_json','_llamadart_webgpu_last_error','_llamadart_webgpu_set_log_level','_llamadart_webgpu_load_model','_llamadart_webgpu_load_model_from_url','_llamadart_webgpu_mmproj_load','_llamadart_webgpu_mmproj_free','_llamadart_webgpu_mmproj_supports_vision','_llamadart_webgpu_mmproj_supports_audio','_llamadart_webgpu_media_clear_pending','_llamadart_webgpu_media_add_file','_llamadart_webgpu_media_add_encoded','_llamadart_webgpu_media_add_rgb','_llamadart_webgpu_media_add_audio_f32','_llamadart_webgpu_tokenize_to_json','_llamadart_webgpu_last_tokens_json','_llamadart_webgpu_detokenize_from_json','_llamadart_webgpu_last_detokenized','_llamadart_webgpu_generate','_llamadart_webgpu_begin_generation','_llamadart_webgpu_next_token','_llamadart_webgpu_last_piece','_llamadart_webgpu_end_generation','_llamadart_webgpu_request_cancel','_llamadart_webgpu_last_output','_llamadart_webgpu_get_context_size','_llamadart_webgpu_model_meta_json','_llamadart_webgpu_shutdown']"
+  "-sEXPORTED_FUNCTIONS=['_main','_llamadart_webgpu_probe','_llamadart_webgpu_backends_json','_llamadart_webgpu_last_error','_llamadart_webgpu_set_log_level','_llamadart_webgpu_load_model','_llamadart_webgpu_load_model_from_url','_llamadart_webgpu_mmproj_load','_llamadart_webgpu_mmproj_free','_llamadart_webgpu_mmproj_supports_vision','_llamadart_webgpu_mmproj_supports_audio','_llamadart_webgpu_media_clear_pending','_llamadart_webgpu_media_add_file','_llamadart_webgpu_media_add_encoded','_llamadart_webgpu_media_add_rgb','_llamadart_webgpu_media_add_audio_f32','_llamadart_webgpu_tokenize_to_json','_llamadart_webgpu_last_tokens_json','_llamadart_webgpu_detokenize_from_json','_llamadart_webgpu_last_detokenized','_llamadart_webgpu_embed_to_json','_llamadart_webgpu_last_embedding_json','_llamadart_webgpu_generate','_llamadart_webgpu_begin_generation','_llamadart_webgpu_next_token','_llamadart_webgpu_last_piece','_llamadart_webgpu_end_generation','_llamadart_webgpu_request_cancel','_llamadart_webgpu_last_output','_llamadart_webgpu_get_context_size','_llamadart_webgpu_model_meta_json','_llamadart_webgpu_shutdown']"
   "-lwasmfs_fetch.js"
 )
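
The two new exports, _llamadart_webgpu_embed_to_json and _llamadart_webgpu_last_embedding_json, are what make the embedding path reachable from JavaScript. A minimal sketch of calling them directly through Emscripten's ccall, mirroring what the bridge below does; this assumes the build supports async ccall (e.g. ASYNCIFY/JSPI) and that a model has already been loaded:

// Sketch only, not the bridge itself. ccall drops the leading underscore
// from the exported C symbol names.
const core = await createLlamaWebGpuCoreModule();

const rc = await core.ccall(
  'llamadart_webgpu_embed_to_json',   // returns embedding dimension, or < 0 on error
  'number',
  ['string', 'number'],               // (text, normalize flag: 1 = L2-normalize)
  ['hello world', 1],
  { async: true },
);

if (rc > 0) {
  // The vector itself is fetched out-of-band as a JSON string.
  const vector = JSON.parse(
    core.ccall('llamadart_webgpu_last_embedding_json', 'string', [], []),
  );
  console.log(`dim=${rc}`, vector.slice(0, 4));
}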

js/llama_webgpu_bridge.js

Lines changed: 75 additions & 0 deletions
@@ -2700,6 +2700,52 @@ class LlamaWebGpuBridgeRuntime {
     return this._core.ccall('llamadart_webgpu_last_detokenized', 'string', [], []) || '';
   }

+  async embed(text, options = {}) {
+    if (this._modelBytes <= 0) {
+      throw new Error('No model loaded. Call loadModelFromUrl first.');
+    }
+
+    const normalize = options?.normalize !== false;
+    const rc = Number(
+      await this._core.ccall(
+        'llamadart_webgpu_embed_to_json',
+        'number',
+        ['string', 'number'],
+        [String(text), normalize ? 1 : 0],
+        { async: true },
+      ),
+    );
+
+    if (rc < 0) {
+      throw new Error(this._coreErrorMessage('Embedding generation failed', rc));
+    }
+
+    const raw = this._core.ccall('llamadart_webgpu_last_embedding_json', 'string', [], []) || '[]';
+    const parsed = JSON.parse(raw);
+    return Array.isArray(parsed)
+      ? parsed.map((v) => {
+          const numeric = Number(v);
+          return Number.isFinite(numeric) ? numeric : 0;
+        })
+      : [];
+  }
+
+  async embedBatch(texts, options = {}) {
+    const normalized = Array.isArray(texts)
+      ? texts
+      : Array.from(texts || []);
+    if (normalized.length === 0) {
+      return [];
+    }
+
+    const normalize = options?.normalize !== false;
+    const vectors = [];
+    for (const text of normalized) {
+      vectors.push(await this.embed(String(text), { normalize }));
+    }
+    return vectors;
+  }
+
   getModelMetadata() {
     let modelMetadata = {};

@@ -3202,6 +3248,35 @@ export class LlamaWebGpuBridge {
     }
   }

+  async embed(text, options = {}) {
+    if (!this._workerProxy) {
+      return this._runtime.embed(text, options);
+    }
+
+    try {
+      return await this._callWorker('embed', [text, options]);
+    } catch (error) {
+      this._disableWorkerFallback(error);
+      return this._runtime.embed(text, options);
+    }
+  }
+
+  async embedBatch(texts, options = {}) {
+    const normalized = Array.isArray(texts)
+      ? texts
+      : Array.from(texts || []);
+    if (!this._workerProxy) {
+      return this._runtime.embedBatch(normalized, options);
+    }
+
+    try {
+      return await this._callWorker('embedBatch', [normalized, options]);
+    } catch (error) {
+      this._disableWorkerFallback(error);
+      return this._runtime.embedBatch(normalized, options);
+    }
+  }
+
   getModelMetadata() {
     if (this._workerProxy) {
       return {
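
Both layers now expose the same promise-based surface: the bridge tries the worker proxy first and falls back to the in-page runtime when a worker call fails. A hypothetical usage sketch; the constructor call and model URL are illustrative, while embed, embedBatch, and the { normalize } option are what this commit adds:

import { LlamaWebGpuBridge } from './js/llama_webgpu_bridge.js';

const bridge = new LlamaWebGpuBridge();   // construction details assumed
await bridge.loadModelFromUrl('https://example.com/embedding-model.gguf');

// Single text -> number[]; vectors are unit-length unless { normalize: false }.
const query = await bridge.embed('what is webgpu?');

// Batch -> number[][], one vector per input, in input order.
const docs = await bridge.embedBatch(['a graphics api', 'a cooking recipe']);

// With normalized vectors, cosine similarity reduces to a dot product.
const dot = (a, b) => a.reduce((sum, v, i) => sum + v * b[i], 0);
console.log(docs.map((d) => dot(query, d)));   // higher = more similar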

src/llama_webgpu_core.cpp

Lines changed: 168 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include <atomic>
 #include <cerrno>
 #include <cctype>
+#include <cmath>
 #include <cstdlib>
 #include <cstdint>
 #include <cstring>
@@ -46,6 +47,7 @@ std::string g_last_output;
 std::string g_last_piece;
 std::string g_last_tokens_json = "[]";
 std::string g_last_detokenized;
+std::string g_last_embedding_json = "[]";
 std::string g_backend_json = "[]";
 std::string g_model_meta_json = "{}";
 std::vector<llama_token> g_cached_prompt_tokens;
@@ -185,6 +187,7 @@ void free_runtime() {
     g_last_piece.clear();
     g_last_tokens_json = "[]";
     g_last_detokenized.clear();
+    g_last_embedding_json = "[]";
     g_model_meta_json = "{}";
     g_cached_prompt_tokens.clear();
 }
@@ -447,6 +450,35 @@ std::string serialize_tokens_json(const std::vector<llama_token> & tokens) {
     return json;
 }

+std::string serialize_embedding_json(const std::vector<float> & embedding) {
+    std::string json = "[";
+    for (size_t i = 0; i < embedding.size(); ++i) {
+        if (i > 0) {
+            json += ",";
+        }
+        json += std::to_string(static_cast<double>(embedding[i]));
+    }
+    json += "]";
+    return json;
+}
+
+void normalize_embedding_inplace(std::vector<float> & embedding) {
+    double norm_squared = 0.0;
+    for (const float value : embedding) {
+        const double dv = static_cast<double>(value);
+        norm_squared += dv * dv;
+    }
+
+    if (norm_squared <= 0.0) {
+        return;
+    }
+
+    const double scale = 1.0 / std::sqrt(norm_squared);
+    for (float & value : embedding) {
+        value = static_cast<float>(static_cast<double>(value) * scale);
+    }
+}
+
 void parse_token_list(const char * token_text, std::vector<llama_token> & out_tokens) {
     out_tokens.clear();
     if (token_text == nullptr) {
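
serialize_embedding_json emits a plain JSON array, and normalize_embedding_inplace is standard L2 normalization with the sum of squares accumulated in double to limit round-off; the early return leaves a zero-norm vector untouched rather than dividing by zero. In LaTeX terms:

\hat{v} = \frac{v}{\lVert v \rVert_2}, \qquad \lVert v \rVert_2 = \sqrt{\sum_i v_i^2}
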
@@ -1267,6 +1299,142 @@ EMSCRIPTEN_KEEPALIVE const char * llamadart_webgpu_last_detokenized() {
     return g_last_detokenized.c_str();
 }

+EMSCRIPTEN_KEEPALIVE int32_t llamadart_webgpu_embed_to_json(
+        const char * text,
+        int32_t normalize) {
+    clear_error();
+    g_last_embedding_json = "[]";
+
+    if (!ensure_loaded()) {
+        return -1;
+    }
+
+    if (text == nullptr) {
+        set_error("Text is null");
+        return -2;
+    }
+
+    const bool has_encoder = llama_model_has_encoder(g_state.model);
+    const bool has_decoder = llama_model_has_decoder(g_state.model);
+    if (has_encoder && has_decoder) {
+        set_error("Embedding extraction for encoder-decoder models is not supported");
+        return -3;
+    }
+    const bool use_encoder_path = has_encoder && !has_decoder;
+
+    std::vector<llama_token> tokens;
+    if (!tokenize_text(std::string(text), true, tokens)) {
+        return -4;
+    }
+
+    if (tokens.empty()) {
+        set_error("Embedding input tokenized to an empty sequence");
+        return -5;
+    }
+
+    int32_t embedding_size = llama_model_n_embd_out(g_state.model);
+    if (embedding_size <= 0) {
+        embedding_size = llama_model_n_embd(g_state.model);
+    }
+    if (embedding_size <= 0) {
+        set_error("Failed to resolve embedding dimension");
+        return -6;
+    }
+
+    int32_t max_batch = static_cast<int32_t>(llama_n_batch(g_state.ctx));
+    if (max_batch <= 0) {
+        max_batch = static_cast<int32_t>(tokens.size());
+    }
+    max_batch = std::max<int32_t>(1, std::min<int32_t>(max_batch, static_cast<int32_t>(tokens.size())));
+
+    llama_batch batch = llama_batch_init(max_batch, 0, 1);
+    if (batch.token == nullptr || batch.pos == nullptr ||
+            batch.n_seq_id == nullptr || batch.seq_id == nullptr ||
+            batch.logits == nullptr) {
+        llama_batch_free(batch);
+        set_error("Failed to allocate embedding batch buffers");
+        return -7;
+    }
+
+    int32_t rc = embedding_size;
+
+    llama_synchronize(g_state.ctx);
+    auto * memory = llama_get_memory(g_state.ctx);
+    if (memory != nullptr) {
+        llama_memory_clear(memory, false);
+    }
+    g_cached_prompt_tokens.clear();
+    llama_set_embeddings(g_state.ctx, true);
+
+    int32_t decoded_tokens = 0;
+    while (decoded_tokens < static_cast<int32_t>(tokens.size())) {
+        const int32_t remaining = static_cast<int32_t>(tokens.size()) - decoded_tokens;
+        const int32_t chunk_token_count = std::min(max_batch, remaining);
+        batch.n_tokens = chunk_token_count;
+
+        for (int32_t i = 0; i < chunk_token_count; ++i) {
+            const int32_t token_index = decoded_tokens + i;
+            batch.token[i] = tokens[static_cast<size_t>(token_index)];
+            batch.pos[i] = token_index;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id[i][0] = 0;
+            batch.logits[i] = 1;
+        }
+
+        const int status = use_encoder_path
+            ? llama_encode(g_state.ctx, batch)
+            : llama_decode(g_state.ctx, batch);
+        if (status != 0) {
+            set_error("Embedding forward pass failed");
+            rc = -8;
+            break;
+        }
+
+        decoded_tokens += chunk_token_count;
+    }
+
+    if (rc > 0) {
+        const enum llama_pooling_type pooling_type = llama_pooling_type(g_state.ctx);
+        float * embedding_ptr = nullptr;
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            embedding_ptr = llama_get_embeddings_ith(g_state.ctx, batch.n_tokens - 1);
+            if (embedding_ptr == nullptr) {
+                embedding_ptr = llama_get_embeddings(g_state.ctx);
+            }
+        } else {
+            embedding_ptr = llama_get_embeddings_seq(g_state.ctx, 0);
+            if (embedding_ptr == nullptr) {
+                embedding_ptr = llama_get_embeddings(g_state.ctx);
+            }
+        }
+
+        if (embedding_ptr == nullptr) {
+            set_error("Embedding output is unavailable");
+            rc = -9;
+        } else {
+            std::vector<float> embedding(
+                embedding_ptr,
+                embedding_ptr + static_cast<size_t>(embedding_size));
+            if (normalize != 0) {
+                normalize_embedding_inplace(embedding);
+            }
+
+            g_last_embedding_json = serialize_embedding_json(embedding);
+        }
+    }
+
+    {
+        llama_set_embeddings(g_state.ctx, false);
+        llama_batch_free(batch);
+    }
+
+    return rc;
+}
+
+EMSCRIPTEN_KEEPALIVE const char * llamadart_webgpu_last_embedding_json() {
+    return g_last_embedding_json.c_str();
+}
+
 EMSCRIPTEN_KEEPALIVE int32_t llamadart_webgpu_generate(
     const char * prompt,
     int32_t n_predict,

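For reference, the new entry point returns the embedding dimension on success and a distinct negative code per failure: -1 no model loaded, -2 null text, -3 encoder-decoder model, -4 tokenization failure, -5 empty token sequence, -6 unresolved embedding dimension, -7 batch allocation failure, -8 failed forward pass, -9 embedding output unavailable. The JS bridge folds all of these into its 'Embedding generation failed' error; a raw caller can recover the detail, as in this sketch (assuming the already-exported _llamadart_webgpu_last_error returns the string set by set_error):

// Sketch only: surface the native error text after a negative return code.
if (rc < 0) {
  const detail = core.ccall('llamadart_webgpu_last_error', 'string', [], []);
  throw new Error(`embed failed (rc=${rc}): ${detail || 'unknown error'}`);
}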