@@ -114,6 +114,13 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
114114 return n_pos;
115115}
116116
117+ void mtmd_helper_image_get_decoder_pos (const mtmd_image_tokens * chunks, mtmd_decoder_pos * out_pos) {
118+ size_t n_tokens = mtmd_image_tokens_get_n_tokens (chunks);
119+ for (size_t i = 0 ; i < n_tokens; i++) {
120+ out_pos[i] = mtmd_image_tokens_get_decoder_pos (chunks, i);
121+ }
122+ }
123+
117124// helper struct to make working with embd batch easier
118125// note: this will be removed after llama_batch_ext refactoring
119126struct decode_embd_batch {
@@ -156,18 +163,15 @@ struct decode_embd_batch {
156163 }
157164
158165 // M-RoPE for image
159- void set_position_mrope_2d (llama_pos pos_0, int nx, int ny , llama_seq_id seq_id) {
166+ void set_position_mrope_2d (llama_pos pos_0, const std::vector<mtmd_decoder_pos> & rel_pos , llama_seq_id seq_id) {
160167 GGML_ASSERT (n_pos_per_embd == 4 );
161- GGML_ASSERT (nx > 0 && ny > 0 && nx * ny == batch.n_tokens );
168+ GGML_ASSERT (!rel_pos. empty () && ( int32_t )rel_pos. size () == batch.n_tokens );
162169 seq_id_0[0 ] = seq_id;
163- for (int y = 0 ; y < ny; y++) {
164- for (int x = 0 ; x < nx; x++) {
165- int i = y * nx + x;
166- pos[i ] = pos_0;
167- pos[i + batch.n_tokens ] = pos_0 + y;
168- pos[i + batch.n_tokens * 2 ] = pos_0 + x;
169- pos[i + batch.n_tokens * 3 ] = 0 ; // last pos dim is unused
170- }
170+ for (int32_t i = 0 ; i < batch.n_tokens ; i++) {
171+ pos[i ] = pos_0 + rel_pos[i].t ;
172+ pos[i + batch.n_tokens ] = pos_0 + rel_pos[i].y ;
173+ pos[i + batch.n_tokens * 2 ] = pos_0 + rel_pos[i].x ;
174+ pos[i + batch.n_tokens * 3 ] = 0 ; // last pos dim is unused
171175 }
172176 for (int i = 0 ; i < batch.n_tokens ; i++) {
173177 batch.n_seq_id [i] = 1 ;
@@ -262,9 +266,10 @@ int32_t mtmd_helper_decode_image_chunk(
262266 LOG_ERR (" failed to decode chunk: image tokens are null\n " );
263267 return -1 ;
264268 }
265- const int nx = mtmd_image_tokens_get_nx (image_tokens);
266- const int ny = mtmd_image_tokens_get_ny (image_tokens);
267- batch_embd.set_position_mrope_2d (n_past, nx, ny, seq_id);
269+ const auto n_tokens = mtmd_image_tokens_get_n_tokens (image_tokens);
270+ std::vector<mtmd_decoder_pos> rel_pos (n_tokens);
271+ mtmd_helper_image_get_decoder_pos (image_tokens, rel_pos.data ());
272+ batch_embd.set_position_mrope_2d (n_past, rel_pos, seq_id);
268273 } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
269274 batch_embd.set_position_mrope_1d (n_past, seq_id);
270275 } else {
0 commit comments