@@ -1205,40 +1205,57 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
12051205
12061206 if (split_state.n_segments != 1 ) {
12071207 GGML_ASSERT (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1208- GGML_ASSERT (offset == 0 );
1209- GGML_ASSERT (size == ggml_nbytes (tensor));
12101208 GGML_ASSERT (tensor->ne [3 ] == 1 );
1209+
12111210 size_t offset_data = 0 ;
12121211 std::vector<size_t > simple_offsets (n_bufs, 0 );
12131212 if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
12141213 GGML_ASSERT (tensor->ne [2 ] == 1 );
1214+
1215+ const size_t row_stride = tensor->nb [1 ];
1216+ GGML_ASSERT (offset % row_stride == 0 );
1217+ GGML_ASSERT (size % row_stride == 0 );
1218+ const int64_t r_start = offset / row_stride;
1219+ const int64_t r_count = size / row_stride;
1220+ GGML_ASSERT (r_start + r_count <= tensor->ne [1 ]);
1221+
12151222 const int64_t blck_size = ggml_blck_size (tensor->type );
12161223 for (size_t s = 0 ; s < split_state.n_segments ; s++) {
12171224 for (size_t j = 0 ; j < n_bufs; j++) {
12181225 ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor (tensor, j);
12191226 GGML_ASSERT (split_state.ne [s*n_bufs + j] % blck_size == 0 );
12201227 const size_t nbytes = split_state.ne [s*n_bufs + j]/blck_size * tensor->nb [0 ];
1221- ggml_backend_tensor_set_2d (simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes,
1222- tensor->ne [1 ], simple_tensor->nb [1 ], tensor->nb [1 ]);
1228+ ggml_backend_tensor_set_2d (simple_tensor, (const char *) data + offset_data,
1229+ simple_offsets[j] + r_start * simple_tensor->nb [1 ], nbytes,
1230+ r_count, simple_tensor->nb [1 ], tensor->nb [1 ]);
12231231 offset_data += nbytes;
12241232 simple_offsets[j] += nbytes;
12251233 }
12261234 }
1227- GGML_ASSERT (offset_data*tensor->ne [1 ] == size);
1235+ GGML_ASSERT (offset_data*r_count == size);
12281236 return ;
12291237 }
12301238 GGML_ASSERT (split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1239+
1240+ const size_t row_stride = tensor->nb [2 ];
1241+ GGML_ASSERT (offset % row_stride == 0 );
1242+ GGML_ASSERT (size % row_stride == 0 );
1243+ const int64_t r_start = offset / row_stride;
1244+ const int64_t r_count = size / row_stride;
1245+ GGML_ASSERT (r_start + r_count <= tensor->ne [2 ]);
1246+
12311247 for (size_t s = 0 ; s < split_state.n_segments ; s++) {
12321248 for (size_t j = 0 ; j < n_bufs; j++) {
12331249 ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor (tensor, j);
12341250 const size_t nbytes = split_state.ne [s*n_bufs + j] * tensor->nb [1 ];
1235- ggml_backend_tensor_set_2d (simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes,
1236- tensor->ne [2 ], simple_tensor->nb [2 ], tensor->nb [2 ]);
1251+ ggml_backend_tensor_set_2d (simple_tensor, (const char *) data + offset_data,
1252+ simple_offsets[j] + r_start * simple_tensor->nb [2 ], nbytes,
1253+ r_count, simple_tensor->nb [2 ], tensor->nb [2 ]);
12371254 offset_data += nbytes;
12381255 simple_offsets[j] += nbytes;
12391256 }
12401257 }
1241- GGML_ASSERT (offset_data*tensor->ne [2 ] == size);
1258+ GGML_ASSERT (offset_data*r_count == size);
12421259 return ;
12431260 }
12441261
@@ -1295,40 +1312,57 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co
12951312
12961313 if (split_state.n_segments != 1 ) {
12971314 GGML_ASSERT (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1298- GGML_ASSERT (offset == 0 );
1299- GGML_ASSERT (size == ggml_nbytes (tensor));
13001315 GGML_ASSERT (tensor->ne [3 ] == 1 );
1316+
13011317 size_t offset_data = 0 ;
13021318 std::vector<size_t > simple_offsets (n_bufs, 0 );
13031319 if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
13041320 GGML_ASSERT (tensor->ne [2 ] == 1 );
1321+
1322+ const size_t row_stride = tensor->nb [1 ];
1323+ GGML_ASSERT (offset % row_stride == 0 );
1324+ GGML_ASSERT (size % row_stride == 0 );
1325+ const int64_t r_start = offset / row_stride;
1326+ const int64_t r_count = size / row_stride;
1327+ GGML_ASSERT (r_start + r_count <= tensor->ne [1 ]);
1328+
13051329 const int64_t blck_size = ggml_blck_size (tensor->type );
13061330 for (size_t s = 0 ; s < split_state.n_segments ; s++) {
13071331 for (size_t j = 0 ; j < n_bufs; j++) {
13081332 const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor (tensor, j);
13091333 GGML_ASSERT (split_state.ne [s*n_bufs + j] % blck_size == 0 );
13101334 const size_t nbytes = split_state.ne [s*n_bufs + j]/blck_size * tensor->nb [0 ];
1311- ggml_backend_tensor_get_2d (simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes,
1312- tensor->ne [1 ], simple_tensor->nb [1 ], tensor->nb [1 ]);
1335+ ggml_backend_tensor_get_2d (simple_tensor, (char *) data + offset_data,
1336+ simple_offsets[j] + r_start * simple_tensor->nb [1 ], nbytes,
1337+ r_count, simple_tensor->nb [1 ], tensor->nb [1 ]);
13131338 offset_data += nbytes;
13141339 simple_offsets[j] += nbytes;
13151340 }
13161341 }
1317- GGML_ASSERT (offset_data*tensor->ne [1 ] == size);
1342+ GGML_ASSERT (offset_data*r_count == size);
13181343 return ;
13191344 }
13201345 GGML_ASSERT (split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1346+
1347+ const size_t row_stride = tensor->nb [2 ];
1348+ GGML_ASSERT (offset % row_stride == 0 );
1349+ GGML_ASSERT (size % row_stride == 0 );
1350+ const int64_t r_start = offset / row_stride;
1351+ const int64_t r_count = size / row_stride;
1352+ GGML_ASSERT (r_start + r_count <= tensor->ne [2 ]);
1353+
13211354 for (size_t s = 0 ; s < split_state.n_segments ; s++) {
13221355 for (size_t j = 0 ; j < n_bufs; j++) {
13231356 const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor (tensor, j);
13241357 const size_t nbytes = split_state.ne [s*n_bufs + j] * tensor->nb [1 ];
1325- ggml_backend_tensor_get_2d (simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes,
1326- tensor->ne [2 ], simple_tensor->nb [2 ], tensor->nb [2 ]);
1358+ ggml_backend_tensor_get_2d (simple_tensor, (char *) data + offset_data,
1359+ simple_offsets[j] + r_start * simple_tensor->nb [2 ], nbytes,
1360+ r_count, simple_tensor->nb [2 ], tensor->nb [2 ]);
13271361 offset_data += nbytes;
13281362 simple_offsets[j] += nbytes;
13291363 }
13301364 }
1331- GGML_ASSERT (offset_data*tensor->ne [2 ] == size);
1365+ GGML_ASSERT (offset_data*r_count == size);
13321366 return ;
13331367 }
13341368
0 commit comments