Skip to content

Commit 98c35df

Browse files
ssam18ArberSephirotheca
authored andcommitted
ggml-backend-meta: add multi-segment read support in get_tensor (ggml-org#22063)
1 parent 66ffbec commit 98c35df

1 file changed

Lines changed: 39 additions & 1 deletion

File tree

ggml/src/ggml-backend-meta.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1270,7 +1270,45 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co
12701270
GGML_ASSERT(ggml_is_contiguous(tensor));
12711271

12721272
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1273-
GGML_ASSERT(split_state.n_segments == 1);
1273+
1274+
if (split_state.n_segments != 1) {
1275+
GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1276+
GGML_ASSERT(offset == 0);
1277+
GGML_ASSERT(size == ggml_nbytes(tensor));
1278+
GGML_ASSERT(tensor->ne[3] == 1);
1279+
size_t offset_data = 0;
1280+
std::vector<size_t> simple_offsets(n_bufs, 0);
1281+
if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
1282+
GGML_ASSERT(tensor->ne[2] == 1);
1283+
const int64_t blck_size = ggml_blck_size(tensor->type);
1284+
for (size_t s = 0; s < split_state.n_segments; s++) {
1285+
for (size_t j = 0; j < n_bufs; j++) {
1286+
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1287+
GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
1288+
const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
1289+
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes,
1290+
tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]);
1291+
offset_data += nbytes;
1292+
simple_offsets[j] += nbytes;
1293+
}
1294+
}
1295+
GGML_ASSERT(offset_data*tensor->ne[1] == size);
1296+
return;
1297+
}
1298+
GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1299+
for (size_t s = 0; s < split_state.n_segments; s++) {
1300+
for (size_t j = 0; j < n_bufs; j++) {
1301+
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1302+
const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
1303+
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes,
1304+
tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]);
1305+
offset_data += nbytes;
1306+
simple_offsets[j] += nbytes;
1307+
}
1308+
}
1309+
GGML_ASSERT(offset_data*tensor->ne[2] == size);
1310+
return;
1311+
}
12741312

12751313
switch (split_state.axis) {
12761314
case GGML_BACKEND_SPLIT_AXIS_0:

0 commit comments

Comments
 (0)