
Commit 3731c76

ngxson authored and xiukeding committed
mtmd: add clip_graph::build_mm() (ggml-org#20751)
* clip: add build_mm()
* apply to all models
* add TODO for bias overload
1 parent 4af0d28 commit 3731c76

15 files changed: 75 additions & 66 deletions

tools/mtmd/clip-graph.h

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,11 @@ struct clip_graph {
     virtual ~clip_graph() = default;
     virtual ggml_cgraph * build() = 0;

+    // wrapper around ggml_mul_mat, allow hooking (e.g. LoRA, clamping) depending on the model
+    // tensor w should be the weight matrix, and tensor x should be the input
+    virtual ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const;
+    // TODO: build_mm(w, b, x) to support bias
+
     //
     // utility functions
     //
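The virtual hook lets a model customize every matmul without touching the shared graph builders. As a minimal sketch of what an override could look like, a hypothetical subclass might clamp activations to the FP16 range (the clip_graph_example struct and the clamp are illustrative assumptions, not part of this commit):

// hypothetical subclass: route every matmul through a clamp,
// the kind of model-specific hook the comment above alludes to
struct clip_graph_example : clip_graph {
    using clip_graph::clip_graph; // inherit the base constructor

    ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const override {
        ggml_tensor * cur = ggml_mul_mat(ctx0, w, x);
        // keep activations inside the FP16 representable range
        return ggml_clamp(ctx0, cur, -65504.0f, 65504.0f);
    }

    ggml_cgraph * build() override {
        // ... build the model graph; shared helpers call build_mm() ...
        return gf;
    }
};

Because build_vit(), build_ffn() and build_attn() below now call build_mm() instead of ggml_mul_mat() directly, such an override applies to all projections in the shared code paths automatically.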

tools/mtmd/clip.cpp

Lines changed: 12 additions & 8 deletions
@@ -255,6 +255,10 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
     gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
 }

+ggml_tensor * clip_graph::build_mm(ggml_tensor * w, ggml_tensor * x) const {
+    return ggml_mul_mat(ctx0, w, x);
+}
+
 void clip_graph::cb(ggml_tensor * cur, const char * name, int il) const {
     if (il >= 0) {
         ggml_format_name(cur, "%s-%d", name, il);

@@ -326,7 +330,7 @@ ggml_tensor * clip_graph::build_vit(
         ggml_tensor * Vcur = nullptr;
         if (layer.qkv_w != nullptr) {
             // fused qkv
-            cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+            cur = build_mm(layer.qkv_w, cur);
             if (layer.qkv_b != nullptr) {
                 cur = ggml_add(ctx0, cur, layer.qkv_b);
             }

@@ -360,17 +364,17 @@ ggml_tensor * clip_graph::build_vit(

         } else {
             // separate q, k, v
-            Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+            Qcur = build_mm(layer.q_w, cur);
             if (layer.q_b) {
                 Qcur = ggml_add(ctx0, Qcur, layer.q_b);
             }

-            Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+            Kcur = build_mm(layer.k_w, cur);
             if (layer.k_b) {
                 Kcur = ggml_add(ctx0, Kcur, layer.k_b);
             }

-            Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+            Vcur = build_mm(layer.v_w, cur);
             if (layer.v_b) {
                 Vcur = ggml_add(ctx0, Vcur, layer.v_b);
             }

@@ -517,7 +521,7 @@ ggml_tensor * clip_graph::build_ffn(
         ffn_op_type type_op,
         int il) const {

-    ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur;
+    ggml_tensor * tmp = up ? build_mm(up, cur) : cur;
     cb(tmp, "ffn_up", il);

     if (up_b) {

@@ -526,7 +530,7 @@ ggml_tensor * clip_graph::build_ffn(
     }

     if (gate) {
-        cur = ggml_mul_mat(ctx0, gate, cur);
+        cur = build_mm(gate, cur);
         cb(cur, "ffn_gate", il);

         if (gate_b) {

@@ -580,7 +584,7 @@ ggml_tensor * clip_graph::build_ffn(
     }

     if (down) {
-        cur = ggml_mul_mat(ctx0, down, cur);
+        cur = build_mm(down, cur);
     }

     if (down_b) {

@@ -646,7 +650,7 @@ ggml_tensor * clip_graph::build_attn(
     cb(cur, "kqv_out", il);

     if (wo) {
-        cur = ggml_mul_mat(ctx0, wo, cur);
+        cur = build_mm(wo, cur);
     }

     if (wo_b) {
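The TODO in clip-graph.h points at a follow-up overload that folds the bias add into the same hook, so call sites like the fused qkv projection above become a single call. A possible shape for it, assuming it reuses the two-argument virtual hook for the matmul itself (a sketch of the TODO, not code from this commit):

// hypothetical build_mm(w, b, x) overload sketched from the TODO;
// b may be nullptr for layers that have no bias tensor
ggml_tensor * clip_graph::build_mm(ggml_tensor * w, ggml_tensor * b, ggml_tensor * x) const {
    ggml_tensor * cur = build_mm(w, x); // matmul still goes through the virtual hook
    if (b != nullptr) {
        cur = ggml_add(ctx0, cur, b);
    }
    return cur;
}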

tools/mtmd/models/cogvlm.cpp

Lines changed: 5 additions & 5 deletions
@@ -19,7 +19,7 @@ ggml_cgraph * clip_graph_cogvlm::build() {
         auto & layer = model.layers[il];
         ggml_tensor * cur = inpL;

-        cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+        cur = build_mm(layer.qkv_w, cur);

         cur = ggml_add(ctx0, cur, layer.qkv_b);

@@ -67,7 +67,7 @@ ggml_cgraph * clip_graph_cogvlm::build() {
             ggml_row_size(inpL->type, n_embd), 0);

     // Multiply with mm_model_proj
-    cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+    cur = build_mm(model.mm_model_proj, cur);

     // Apply layernorm, weight, bias
     cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);

@@ -76,16 +76,16 @@ ggml_cgraph * clip_graph_cogvlm::build() {
     cur = ggml_gelu_inplace(ctx0, cur);

     // Branch 1: multiply with mm_h_to_4h_w
-    ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur);
+    ggml_tensor * h_to_4h = build_mm(model.mm_h_to_4h_w, cur);

     // Branch 2: multiply with mm_gate_w
-    ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur);
+    ggml_tensor * gate = build_mm(model.mm_gate_w, cur);

     // Apply silu
     gate = ggml_swiglu_split(ctx0, gate, h_to_4h);

     // Apply mm_4h_to_h_w
-    cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate);
+    cur = build_mm(model.mm_4h_to_h_w, gate);

     // Concatenate with boi and eoi
     cur = ggml_concat(ctx0, model.mm_boi, cur, 1);

tools/mtmd/models/conformer.cpp

Lines changed: 8 additions & 8 deletions
@@ -56,7 +56,7 @@ ggml_cgraph * clip_graph_conformer::build() {
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);

         // calculate out
-        cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur);
+        cur = build_mm(model.pre_encode_out_w, cur);
         cur = ggml_add(ctx0, cur, model.pre_encode_out_b);
         cb(cur, "conformer.pre_encode.out", -1);
     }

@@ -87,7 +87,7 @@ ggml_cgraph * clip_graph_conformer::build() {
         cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il);
         cb(cur, "conformer.layers.{}.norm_self_att", il);

-        ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+        ggml_tensor * Qcur = build_mm(layer.q_w, cur);
         Qcur = ggml_add(ctx0, Qcur, layer.q_b);
         Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]);
         ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u);

@@ -96,12 +96,12 @@ ggml_cgraph * clip_graph_conformer::build() {
         Q_bias_v = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3);

         // TODO @ngxson : some cont can/should be removed when ggml_mul_mat support these cases
-        ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+        ggml_tensor * Kcur = build_mm(layer.k_w, cur);
         Kcur = ggml_add(ctx0, Kcur, layer.k_b);
         Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, Kcur->ne[1]);
         Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));

-        ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+        ggml_tensor * Vcur = build_mm(layer.v_w, cur);
         Vcur = ggml_add(ctx0, Vcur, layer.v_b);
         Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]);
         Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3));

@@ -111,7 +111,7 @@ ggml_cgraph * clip_graph_conformer::build() {
         matrix_ac = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3));
         cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il);

-        auto * p = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb);
+        auto * p = build_mm(layer.linear_pos_w, pos_emb);
         cb(p, "conformer.layers.{}.self_attn.linear_pos", il);
         p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]);
         p = ggml_permute(ctx0, p, 0, 2, 1, 3);

@@ -143,7 +143,7 @@ ggml_cgraph * clip_graph_conformer::build() {
         x = ggml_permute(ctx0, x, 2, 0, 1, 3);
         x = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]);

-        ggml_tensor * out = ggml_mul_mat(ctx0, layer.o_w, x);
+        ggml_tensor * out = build_mm(layer.o_w, x);
         out = ggml_add(ctx0, out, layer.o_b);
         cb(out, "conformer.layers.{}.self_attn.linear_out", il);

@@ -157,7 +157,7 @@ ggml_cgraph * clip_graph_conformer::build() {
         // conv
         {
             auto * x = cur;
-            x = ggml_mul_mat(ctx0, layer.conv_pw1_w, x);
+            x = build_mm(layer.conv_pw1_w, x);
             x = ggml_add(ctx0, x, layer.conv_pw1_b);
             cb(x, "conformer.layers.{}.conv.pointwise_conv1", il);

@@ -181,7 +181,7 @@ ggml_cgraph * clip_graph_conformer::build() {
             x = ggml_silu(ctx0, x);

             // pointwise_conv2
-            x = ggml_mul_mat(ctx0, layer.conv_pw2_w, x);
+            x = build_mm(layer.conv_pw2_w, x);
             x = ggml_add(ctx0, x, layer.conv_pw2_b);

             cur = x;

tools/mtmd/models/glm4v.cpp

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ ggml_cgraph * clip_graph_glm4v::build() {

     // FC projector
     {
-        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        cur = build_mm(model.projection, cur);
         // default LayerNorm (post_projection_norm)
         cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
         cur = ggml_gelu_erf(ctx0, cur);

tools/mtmd/models/llama4.cpp

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@ ggml_cgraph * clip_graph_llama4::build() {
         ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
                 patch_size, patch_size, 3, n_embd);
         inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
-        inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+        inp = build_mm(model.patch_embeddings_0, inp);
         inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
         cb(inp, "patch_conv", -1);
     }

@@ -78,15 +78,15 @@ ggml_cgraph * clip_graph_llama4::build() {

     // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
     {
-        cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+        cur = build_mm(model.mm_model_mlp_1_w, cur);
         cur = ggml_gelu(ctx0, cur);
-        cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
+        cur = build_mm(model.mm_model_mlp_2_w, cur);
         cur = ggml_gelu(ctx0, cur);
         cb(cur, "adapter_mlp", -1);
     }

     // Llama4MultiModalProjector
-    cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+    cur = build_mm(model.mm_model_proj, cur);
     cb(cur, "projected", -1);

     // build the graph

tools/mtmd/models/llava.cpp

Lines changed: 21 additions & 21 deletions
@@ -70,17 +70,17 @@ ggml_cgraph * clip_graph_llava::build() {

         // self-attention
         {
-            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+            ggml_tensor * Qcur = build_mm(layer.q_w, cur);
             if (layer.q_b) {
                 Qcur = ggml_add(ctx0, Qcur, layer.q_b);
             }

-            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+            ggml_tensor * Kcur = build_mm(layer.k_w, cur);
             if (layer.k_b) {
                 Kcur = ggml_add(ctx0, Kcur, layer.k_b);
             }

-            ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+            ggml_tensor * Vcur = build_mm(layer.v_w, cur);
             if (layer.v_b) {
                 Vcur = ggml_add(ctx0, Vcur, layer.v_b);
             }

@@ -164,17 +164,17 @@ ggml_cgraph * clip_graph_llava::build() {

     // llava projector
     if (proj_type == PROJECTOR_TYPE_MLP) {
-        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = build_mm(model.mm_0_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);

         embeddings = ggml_gelu(ctx0, embeddings);
         if (model.mm_2_w) {
-            embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+            embeddings = build_mm(model.mm_2_w, embeddings);
             embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
         }
     }
     else if (proj_type == PROJECTOR_TYPE_MLP_NORM) {
-        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = build_mm(model.mm_0_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
         // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
         // First LayerNorm

@@ -186,7 +186,7 @@ ggml_cgraph * clip_graph_llava::build() {
         embeddings = ggml_gelu(ctx0, embeddings);

         // Second linear layer
-        embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
+        embeddings = build_mm(model.mm_3_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);

         // Second LayerNorm

@@ -197,10 +197,10 @@ ggml_cgraph * clip_graph_llava::build() {
     else if (proj_type == PROJECTOR_TYPE_LDP) {
         // MobileVLM projector
         int n_patch = 24;
-        ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
+        ggml_tensor * mlp_1 = build_mm(model.mm_model_mlp_1_w, embeddings);
         mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
         mlp_1 = ggml_gelu(ctx0, mlp_1);
-        ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
+        ggml_tensor * mlp_3 = build_mm(model.mm_model_mlp_3_w, mlp_1);
         mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
         // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]

@@ -229,10 +229,10 @@ ggml_cgraph * clip_graph_llava::build() {
         // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
         // pointwise conv
         block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
-        block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
+        block_1 = build_mm(model.mm_model_block_1_block_1_fc1_w, block_1);
         block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
         block_1 = ggml_relu(ctx0, block_1);
-        block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
+        block_1 = build_mm(model.mm_model_block_1_block_1_fc2_w, block_1);
         block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
         block_1 = ggml_hardsigmoid(ctx0, block_1);
         // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]

@@ -244,7 +244,7 @@ ggml_cgraph * clip_graph_llava::build() {
         block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));

         // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
-        block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
+        block_1 = build_mm(model.mm_model_block_1_block_2_0_w, block_1);
         block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);

         // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]

@@ -277,10 +277,10 @@ ggml_cgraph * clip_graph_llava::build() {
         // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
         // pointwise conv
         block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
-        block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
+        block_1 = build_mm(model.mm_model_block_2_block_1_fc1_w, block_1);
         block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
         block_1 = ggml_relu(ctx0, block_1);
-        block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
+        block_1 = build_mm(model.mm_model_block_2_block_1_fc2_w, block_1);
         block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
         block_1 = ggml_hardsigmoid(ctx0, block_1);

@@ -292,7 +292,7 @@ ggml_cgraph * clip_graph_llava::build() {
         block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
         block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
         // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
-        block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
+        block_1 = build_mm(model.mm_model_block_2_block_2_0_w, block_1);
         block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);

@@ -307,10 +307,10 @@ ggml_cgraph * clip_graph_llava::build() {
     else if (proj_type == PROJECTOR_TYPE_LDPV2)
     {
         int n_patch = 24;
-        ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+        ggml_tensor * mlp_0 = build_mm(model.mm_model_mlp_0_w, embeddings);
         mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
         mlp_0 = ggml_gelu(ctx0, mlp_0);
-        ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
+        ggml_tensor * mlp_2 = build_mm(model.mm_model_mlp_2_w, mlp_0);
         mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
         // mlp_2 ne = [2048, 576, 1, 1]
         // // AVG Pool Layer 2*2, strides = 2

@@ -344,15 +344,15 @@ ggml_cgraph * clip_graph_llava::build() {
         embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
         // GLU
         {
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+            embeddings = build_mm(model.mm_model_mlp_0_w, embeddings);
             embeddings = ggml_norm(ctx0, embeddings, eps);
             embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
             embeddings = ggml_gelu_inplace(ctx0, embeddings);
             ggml_tensor * x = embeddings;
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
-            x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, x);
+            embeddings = build_mm(model.mm_model_mlp_2_w, embeddings);
+            x = build_mm(model.mm_model_mlp_1_w, x);
             embeddings = ggml_swiglu_split(ctx0, embeddings, x);
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
+            embeddings = build_mm(model.mm_model_mlp_3_w, embeddings);
         }
         // arrangement of BOI/EOI token embeddings
         // note: these embeddings are not present in text model, hence we cannot process them as text tokens
