Skip to content

Commit 707c0b7

Browse files
authored
mtmd: add mtmd_image_tokens_get_decoder_pos() API (ggml-org#21851)
* mtmd: add mtmd_image_tokens_get_decoder_pos() API * consistent naming * fix build
1 parent 1f30ac0 commit 707c0b7

5 files changed

Lines changed: 49 additions & 17 deletions

File tree

tests/test-mtmd-c-api.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ int main(void) {
4141
} else if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
4242
const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
4343
size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
44-
size_t nx = mtmd_image_tokens_get_nx(image_tokens);
45-
size_t ny = mtmd_image_tokens_get_ny(image_tokens);
44+
// get position of the last token, which should be (nx - 1, ny - 1)
45+
struct mtmd_decoder_pos pos = mtmd_image_tokens_get_decoder_pos(image_tokens, n_tokens - 1);
46+
size_t nx = pos.x + 1;
47+
size_t ny = pos.y + 1;
4648
const char * id = mtmd_image_tokens_get_id(image_tokens);
4749
assert(n_tokens > 0);
4850
assert(nx > 0);

tools/mtmd/mtmd-helper.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
114114
return n_pos;
115115
}
116116

117+
void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, mtmd_decoder_pos * out_pos) {
118+
size_t n_tokens = mtmd_image_tokens_get_n_tokens(chunks);
119+
for (size_t i = 0; i < n_tokens; i++) {
120+
out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, i);
121+
}
122+
}
123+
117124
// helper struct to make working with embd batch easier
118125
// note: this will be removed after llama_batch_ext refactoring
119126
struct decode_embd_batch {
@@ -156,18 +163,15 @@ struct decode_embd_batch {
156163
}
157164

158165
// M-RoPE for image
159-
void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
166+
void set_position_mrope_2d(llama_pos pos_0, const std::vector<mtmd_decoder_pos> & rel_pos, llama_seq_id seq_id) {
160167
GGML_ASSERT(n_pos_per_embd == 4);
161-
GGML_ASSERT(nx > 0 && ny > 0 && nx * ny == batch.n_tokens);
168+
GGML_ASSERT(!rel_pos.empty() && (int32_t)rel_pos.size() == batch.n_tokens);
162169
seq_id_0[0] = seq_id;
163-
for (int y = 0; y < ny; y++) {
164-
for (int x = 0; x < nx; x++) {
165-
int i = y * nx + x;
166-
pos[i ] = pos_0;
167-
pos[i + batch.n_tokens ] = pos_0 + y;
168-
pos[i + batch.n_tokens * 2] = pos_0 + x;
169-
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
170-
}
170+
for (int32_t i = 0; i < batch.n_tokens; i++) {
171+
pos[i ] = pos_0 + rel_pos[i].t;
172+
pos[i + batch.n_tokens ] = pos_0 + rel_pos[i].y;
173+
pos[i + batch.n_tokens * 2] = pos_0 + rel_pos[i].x;
174+
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
171175
}
172176
for (int i = 0; i < batch.n_tokens; i++) {
173177
batch.n_seq_id[i] = 1;
@@ -262,9 +266,10 @@ int32_t mtmd_helper_decode_image_chunk(
262266
LOG_ERR("failed to decode chunk: image tokens are null\n");
263267
return -1;
264268
}
265-
const int nx = mtmd_image_tokens_get_nx(image_tokens);
266-
const int ny = mtmd_image_tokens_get_ny(image_tokens);
267-
batch_embd.set_position_mrope_2d(n_past, nx, ny, seq_id);
269+
const auto n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
270+
std::vector<mtmd_decoder_pos> rel_pos(n_tokens);
271+
mtmd_helper_image_get_decoder_pos(image_tokens, rel_pos.data());
272+
batch_embd.set_position_mrope_2d(n_past, rel_pos, seq_id);
268273
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
269274
batch_embd.set_position_mrope_1d(n_past, seq_id);
270275
} else {

tools/mtmd/mtmd-helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
4747
// normally, n_pos is equal to n_tokens, but for M-RoPE it is different
4848
MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks);
4949

50+
// helper to get the list of relative positions corresponding to the embedding tokens, to be used by M-RoPE
51+
// out_pos must have length == mtmd_helper_get_n_tokens(image)
52+
MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, mtmd_decoder_pos * out_pos);
53+
5054
// helper function that automatically:
5155
// 1. run llama_decode() on text chunks
5256
// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()

tools/mtmd/mtmd.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,14 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
12491249
return image_tokens->ny;
12501250
}
12511251

1252+
mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i) {
1253+
mtmd_decoder_pos pos;
1254+
pos.t = 0;
1255+
pos.x = i % image_tokens->nx;
1256+
pos.y = i / image_tokens->nx;
1257+
return pos;
1258+
}
1259+
12521260
const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
12531261
return image_tokens->id.c_str();
12541262
}

tools/mtmd/mtmd.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,25 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
186186
// the instance will be constructed via mtmd_tokenize()
187187
// it will be freed along with mtmd_input_chunk
188188
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
189-
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
190-
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
191189
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
192190
// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
193191
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
194192

193+
DEPRECATED(MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens),
194+
"use mtmd_image_tokens_get_decoder_pos() instead");
195+
DEPRECATED(MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens),
196+
"use mtmd_image_tokens_get_decoder_pos() instead");
197+
198+
struct mtmd_decoder_pos {
199+
uint32_t t;
200+
uint32_t x;
201+
uint32_t y;
202+
};
203+
// get position for decoder attention, to be used by M-RoPE models
204+
// i is the index of the embedding token, ranging from 0 to mtmd_image_tokens_get_n_tokens() - 1
205+
// return relative position (for example, embedding 0 will have position (0, 0, 0); remember to adjust it to the current absolute position)
206+
MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i);
207+
195208
// tokenize an input text prompt and a list of bitmaps (images/audio)
196209
// the prompt must have the input image marker (default: "<__media__>") in it
197210
// the default marker is defined by mtmd_default_marker()

0 commit comments

Comments
 (0)