Skip to content

Commit e37abd6

Browse files
authored
mtmd: add batching API (ggml-org#24384)
* mtmd: add batching API * wip * first working version (gemma4v) * add arg * nits * wire up support_batch() * fix 0.0 output embd * fix audio * nits * refactor a bit * nits * fix non-batching case * fix comment
1 parent f58bad4 commit e37abd6

14 files changed

Lines changed: 537 additions & 119 deletions

File tree

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2243,6 +2243,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
22432243
params.image_max_tokens = value;
22442244
}
22452245
).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MAX_TOKENS"));
2246+
add_opt(common_arg(
2247+
{"--mtmd-batch-max-tokens"}, "N",
2248+
string_format("maximum number of image tokens per batch when encoding images (default: %d)", params.mtmd_batch_max_tokens),
2249+
[](common_params & params, int value) {
2250+
params.mtmd_batch_max_tokens = value;
2251+
}
2252+
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MTMD_BATCH_MAX_TOKENS"));
22462253
if (llama_supports_rpc()) {
22472254
add_opt(common_arg(
22482255
{"--rpc"}, "SERVERS",

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ struct common_params {
575575
std::vector<std::string> image; // path to image file(s) ; TODO: change the name to "media"
576576
int image_min_tokens = -1;
577577
int image_max_tokens = -1;
578+
int mtmd_batch_max_tokens = 1024;
578579

579580
// finetune
580581
struct lr_opt lr;

tools/mtmd/clip-graph.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ struct clip_graph {
5454
virtual ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const;
5555
// TODO: build_mm(w, b, x) to support bias
5656

57+
virtual bool support_batch() const {
58+
return false;
59+
}
60+
5761
//
5862
// utility functions
5963
//

tools/mtmd/clip.cpp

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ struct clip_ctx {
171171
std::map<ggml_backend_dev_t, size_t> mem_usage;
172172
std::map<ggml_backend_dev_t, size_t> mem_compute;
173173

174+
bool support_batch = false;
175+
174176
clip_ctx(clip_context_params & ctx_params) {
175177
flash_attn_type = ctx_params.flash_attn_type;
176178
no_alloc = ctx_params.no_alloc;
@@ -314,7 +316,7 @@ ggml_tensor * clip_graph::build_vit(
314316
std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos,
315317
const build_vit_opts & opts
316318
) {
317-
// batch dim: inp is [n_embd, n_pos] (B==1) or [n_embd, n_pos, B] (multi-tile encode)
319+
// batch dim: inp is [n_embd, n_pos, B]
318320
const int64_t B = inp->ne[2];
319321

320322
if (learned_pos_embd) {
@@ -862,7 +864,7 @@ ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale
862864
return cur;
863865
}
864866

865-
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
867+
static std::unique_ptr<clip_graph> clip_get_graph_builder(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
866868
const clip_image_f32 & img = *imgs.entries[0];
867869
std::unique_ptr<clip_graph> builder;
868870

@@ -1025,7 +1027,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
10251027
// TODO [QWEN_VIDEO]: improve this in the future
10261028
builder->n_batch = imgs.entries.size();
10271029

1028-
return builder->build();
1030+
return builder;
10291031
}
10301032

10311033
//
@@ -2819,7 +2821,7 @@ struct clip_model_loader {
28192821
std::vector<support_info_op> ops;
28202822
};
28212823

2822-
static void warmup(clip_ctx & ctx_clip) {
2824+
static clip_image_f32_batch get_dummy_batch(clip_ctx & ctx_clip) {
28232825
// create a fake batch
28242826
const auto & hparams = ctx_clip.model.hparams;
28252827
clip_image_f32_batch batch;
@@ -2833,6 +2835,20 @@ struct clip_model_loader {
28332835
LOG_INF("%s: warmup with audio size = %d\n", __func__, hparams.warmup_audio_size);
28342836
}
28352837
batch.entries.push_back(std::move(img));
2838+
return batch;
2839+
}
2840+
2841+
static void init_ctx(clip_ctx & ctx_clip) {
2842+
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
2843+
2844+
// check batching support
2845+
auto batch = get_dummy_batch(ctx_clip);
2846+
auto builder = clip_get_graph_builder(&ctx_clip, batch);
2847+
ctx_clip.support_batch = builder->support_batch();
2848+
}
2849+
2850+
static void warmup(clip_ctx & ctx_clip) {
2851+
auto batch = get_dummy_batch(ctx_clip);
28362852
warmup(ctx_clip, batch);
28372853
}
28382854

@@ -2905,9 +2921,7 @@ struct clip_model_loader {
29052921

29062922
// only initialize backend buffers, but do not allocate them yet
29072923
static support_info_graph reserve_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
2908-
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
2909-
2910-
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
2924+
ggml_cgraph * gf = clip_get_graph_builder(&ctx_clip, batch)->build();
29112925
ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
29122926

29132927
ctx_clip.mem_compute.clear();
@@ -3070,6 +3084,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
30703084
ctx_vision = new clip_ctx(ctx_params);
30713085
loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
30723086
loader.load_tensors(*ctx_vision);
3087+
loader.init_ctx(*ctx_vision);
30733088
if (ctx_params.warmup) {
30743089
loader.warmup(*ctx_vision);
30753090
}
@@ -3083,6 +3098,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
30833098
ctx_audio = new clip_ctx(ctx_params);
30843099
loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
30853100
loader.load_tensors(*ctx_audio);
3101+
loader.init_ctx(*ctx_audio);
30863102
if (ctx_params.warmup) {
30873103
loader.warmup(*ctx_audio);
30883104
}
@@ -3484,25 +3500,22 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
34843500
return n_patches;
34853501
}
34863502

3487-
bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
3503+
bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, std::vector<float> & out_vec) {
34883504
clip_image_f32_batch imgs;
34893505
clip_image_f32_ptr img_copy(clip_image_f32_init());
34903506
*img_copy = *img;
34913507
imgs.entries.push_back(std::move(img_copy));
34923508

3493-
return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
3509+
return clip_image_batch_encode(ctx, n_threads, &imgs, out_vec);
34943510
}
34953511

3496-
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
3512+
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, std::vector<float> & out_batch_embd) {
34973513
const clip_image_f32_batch & imgs = *imgs_c_ptr;
34983514
int n_batch_cur = imgs.entries.size();
34993515

3500-
// maximum supported batch size, usually == 2 for qwen-vl-based models
3501-
int n_batch_max = clip_model_n_batch_max(ctx);
3502-
3503-
// TODO @ngxson : implement batch size > 1 as a loop
3504-
// we don't need true batching support because the cgraph will gonna be big anyway
3505-
if (n_batch_cur > n_batch_max) {
3516+
// [QWEN_VIDEO] for video models, the batch dimension is used as temporal dimension for merged frames
3517+
if (!ctx->support_batch && n_batch_cur > clip_model_n_temporal_merge(ctx)) {
3518+
LOG_ERR("%s: batch size %d exceeds maximum supported batch/temporal-merge size %d\n", __func__, n_batch_cur, clip_model_n_temporal_merge(ctx));
35063519
return false;
35073520
}
35083521

@@ -3513,7 +3526,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
35133526

35143527
// build the inference graph
35153528
ggml_backend_sched_reset(ctx->sched.get());
3516-
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
3529+
ggml_cgraph * gf = clip_get_graph_builder(ctx, imgs)->build();
35173530
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
35183531

35193532
// set inputs
@@ -3582,6 +3595,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
35823595
const int n = nx * ny;
35833596

35843597
for (int b = 0; b < n_batch_cur; b++) {
3598+
LOG_DBG("%s: copying image %d/%d to input buffer (nx=%d, ny=%d)\n", __func__, b+1, n_batch_cur, nx, ny);
35853599
const auto & buf = imgs.entries[b]->get_ro_buf();
35863600
float * batch_entry = inp_raw.data() + b * (3*n);
35873601
for (int y = 0; y < ny; y++) {
@@ -4416,24 +4430,34 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
44164430
// the last node is the embedding tensor
44174431
ggml_tensor * embeddings = ggml_graph_node(gf, -1);
44184432

4419-
// sanity check (only support batch size of 1 for now)
4433+
// sanity check (assuming that all images in batch have the same number of tokens, so we only check the first one)
44204434
const int n_tokens_out = embeddings->ne[1];
44214435
const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
44224436
if (n_tokens_out != expected_n_tokens_out) {
44234437
LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
44244438
GGML_ABORT("Invalid number of output tokens");
44254439
}
44264440

4427-
// copy the embeddings to the location passed by the user
4428-
if (vec != nullptr) {
4429-
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
4441+
LOG_DBG("%s: output embedding shape [%d, %d, %d]\n", __func__,
4442+
(int)embeddings->ne[0], (int)embeddings->ne[1], (int)embeddings->ne[2]);
4443+
4444+
// copy output to user buffer if provided
4445+
// if output is empty, skip the copy
4446+
if (!out_batch_embd.empty()) {
4447+
if (out_batch_embd.size() != (size_t)ggml_nelements(embeddings)) {
4448+
LOG_ERR("%s: output buffer has %zu elements but expected %zu\n", __func__, out_batch_embd.size(), (size_t)ggml_nelements(embeddings));
4449+
GGML_ABORT("Output buffer size mismatch");
4450+
}
4451+
ggml_backend_tensor_get(embeddings, out_batch_embd.data(), 0, ggml_nbytes(embeddings));
4452+
} else {
4453+
LOG_WRN("%s: output buffer is empty, skipping copy\n", __func__);
44304454
}
44314455

44324456
// Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set
44334457
if (ctx->debug_output_embeddings) {
44344458
const int64_t n_embd = embeddings->ne[0];
44354459
const int64_t n_tokens = embeddings->ne[1];
4436-
std::vector<float> emb_data(n_embd * n_tokens);
4460+
std::vector<float> emb_data(ggml_nelements(embeddings));
44374461
ggml_backend_tensor_get(embeddings, emb_data.data(), 0, ggml_nbytes(embeddings));
44384462

44394463
LOG_INF("\n=== MTMD_DEBUG_EMBEDDINGS ===\n");
@@ -4570,7 +4594,14 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
45704594
return ctx->model.modality == CLIP_MODALITY_AUDIO;
45714595
}
45724596

4573-
int clip_model_n_batch_max(const struct clip_ctx * ctx) {
4597+
bool clip_support_batch(const struct clip_ctx * ctx) {
4598+
return ctx->support_batch;
4599+
}
4600+
4601+
// TODO @ngxson : this is no longer correct with mtmd_batch API
4602+
// this was only meant to be used by qwen-vl-based models, to fuse 2 input images into one (qwen-vl video support)
4603+
// this logic should be refactored in near future to distinctly handle "merge frames" and "batching"
4604+
int clip_model_n_temporal_merge(const struct clip_ctx * ctx) {
45744605
switch (ctx->proj_type()) {
45754606
case PROJECTOR_TYPE_QWEN2VL:
45764607
case PROJECTOR_TYPE_QWEN25VL:

tools/mtmd/clip.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int id
9797
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
9898
struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
9999

100-
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
101-
bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
100+
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, std::vector<float> & out_vec);
101+
bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, std::vector<float> & out_batch_embd);
102102

103103
bool clip_is_llava(const struct clip_ctx * ctx);
104104
// note for contributor: this clip_is_(model) pattern is deprecated
@@ -107,7 +107,9 @@ bool clip_is_llava(const struct clip_ctx * ctx);
107107
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
108108
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
109109

110-
int clip_model_n_batch_max(const struct clip_ctx * ctx);
110+
bool clip_support_batch(const struct clip_ctx * ctx);
111+
112+
int clip_model_n_temporal_merge(const struct clip_ctx * ctx); // TODO @ngxson : remove, refactor this
111113

112114
std::map<ggml_backend_dev_t, size_t> clip_get_mem_usage(const struct clip_ctx * ctx);
113115

tools/mtmd/models/gemma4v.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ ggml_cgraph * clip_graph_gemma4v::build() {
1010
ggml_set_name(inp_raw, "inp_raw_scaled");
1111

1212
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
13-
inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
13+
inp = ggml_reshape_3d(ctx0, inp, n_patches, n_embd, n_batch);
1414
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
1515
ggml_set_name(inp, "inp");
1616
// note: no patch bias
@@ -51,10 +51,11 @@ ggml_cgraph * clip_graph_gemma4v::build() {
5151
// first half
5252
ggml_tensor * first;
5353
{
54-
first = ggml_view_3d(ctx0, cur,
55-
n_dim/2, n_head, n_pos,
54+
first = ggml_view_4d(ctx0, cur,
55+
n_dim/2, n_head, n_pos, n_batch,
5656
cur->nb[1],
5757
cur->nb[2],
58+
cur->nb[3],
5859
0);
5960
first = ggml_rope_ext(
6061
ctx0,
@@ -70,10 +71,11 @@ ggml_cgraph * clip_graph_gemma4v::build() {
7071
// second half
7172
ggml_tensor * second;
7273
{
73-
second = ggml_view_3d(ctx0, cur,
74-
n_dim/2, n_head, n_pos,
74+
second = ggml_view_4d(ctx0, cur,
75+
n_dim/2, n_head, n_pos, n_batch,
7576
cur->nb[1],
7677
cur->nb[2],
78+
cur->nb[3],
7779
n_dim/2 * ggml_element_size(cur));
7880
second = ggml_rope_ext(
7981
ctx0,
@@ -103,14 +105,14 @@ ggml_cgraph * clip_graph_gemma4v::build() {
103105
const int kernel_size = hparams.n_merge;
104106
GGML_ASSERT(kernel_size > 0);
105107

106-
// [n_embd, n_patches] -> [n_patches_x, n_patches_y, n_embd, 1]
107-
cur = ggml_cont_4d(ctx0, ggml_transpose(ctx0, cur), n_patches_x, n_patches_y, n_embd, 1);
108+
// [n_embd, n_patches] -> [n_patches_x, n_patches_y, n_embd, n_batch]
109+
cur = ggml_cont_4d(ctx0, ggml_transpose(ctx0, cur), n_patches_x, n_patches_y, n_embd, n_batch);
108110
cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG,
109111
kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
110112
const int out_x = n_patches_x / kernel_size;
111113
const int out_y = n_patches_y / kernel_size;
112-
// [out_x, out_y, n_embd, 1] -> [n_embd, out_x * out_y]
113-
cur = ggml_reshape_3d(ctx0, cur, out_x * out_y, n_embd, 1);
114+
// [out_x, out_y, n_embd, n_batch] -> [n_embd, out_x * out_y, n_batch]
115+
cur = ggml_reshape_3d(ctx0, cur, out_x * out_y, n_embd, n_batch);
114116
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
115117
cur = ggml_scale(ctx0, cur, sqrtf((float)n_embd));
116118
cb(cur, "pooled", -1);

tools/mtmd/models/models.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ struct clip_graph_gemma4v : clip_graph {
1616
clip_graph_gemma4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
1717
ggml_cgraph * build() override;
1818
ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const override;
19+
bool support_batch() const override { return true; }
1920
};
2021

2122
struct clip_graph_gemma4uv : clip_graph {

tools/mtmd/mtmd-helper.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image,
6767

6868
// helper function that automatically:
6969
// 1. run llama_decode() on text chunks
70-
// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
71-
// if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error
70+
// 2. run mtmd_encode_chunk() on image chunks, then mtmd_get_output_embd() and then llama_decode()
71+
// if any of the mtmd_encode_chunk() or llama_decode() calls return non-zero, stop and forward the error
7272
// otherwise, returns 0 on success
7373
// this function is NOT thread-safe
7474
MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
@@ -157,13 +157,16 @@ MTMD_API int32_t mtmd_helper_video_read_next(mtmd_helper_video * ctx,
157157
} // extern "C"
158158
#endif
159159

160+
#ifdef __cplusplus
161+
#include <set>
162+
#include <memory>
163+
164+
namespace mtmd_helper {
165+
160166
//
161167
// C++ wrappers
162168
//
163169

164-
#ifdef __cplusplus
165-
namespace mtmd_helper {
166-
167170
// video-related C++ wrappers
168171
struct mtmd_helper_video_deleter {
169172
void operator()(mtmd_helper_video * val) { mtmd_helper_video_free(val); }

0 commit comments

Comments
 (0)