Skip to content

Commit 78433f6

Browse files
authored
Fix recurrent state serialization for partial reads and writes (#22362)
The previous code worked only for full tensor reads and writes and was hitting the `GGML_ASSERT(size == ggml_nbytes(tensor));` assertion when tested with llama-server.
1 parent 7ec36aa commit 78433f6

1 file changed

Lines changed: 50 additions & 16 deletions

File tree

ggml/src/ggml-backend-meta.cpp

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,40 +1205,57 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
12051205

12061206
if (split_state.n_segments != 1) {
12071207
GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1208-
GGML_ASSERT(offset == 0);
1209-
GGML_ASSERT(size == ggml_nbytes(tensor));
12101208
GGML_ASSERT(tensor->ne[3] == 1);
1209+
12111210
size_t offset_data = 0;
12121211
std::vector<size_t> simple_offsets(n_bufs, 0);
12131212
if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
12141213
GGML_ASSERT(tensor->ne[2] == 1);
1214+
1215+
const size_t row_stride = tensor->nb[1];
1216+
GGML_ASSERT(offset % row_stride == 0);
1217+
GGML_ASSERT(size % row_stride == 0);
1218+
const int64_t r_start = offset / row_stride;
1219+
const int64_t r_count = size / row_stride;
1220+
GGML_ASSERT(r_start + r_count <= tensor->ne[1]);
1221+
12151222
const int64_t blck_size = ggml_blck_size(tensor->type);
12161223
for (size_t s = 0; s < split_state.n_segments; s++) {
12171224
for (size_t j = 0; j < n_bufs; j++) {
12181225
ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
12191226
GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
12201227
const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
1221-
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes,
1222-
tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]);
1228+
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
1229+
simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes,
1230+
r_count, simple_tensor->nb[1], tensor->nb[1]);
12231231
offset_data += nbytes;
12241232
simple_offsets[j] += nbytes;
12251233
}
12261234
}
1227-
GGML_ASSERT(offset_data*tensor->ne[1] == size);
1235+
GGML_ASSERT(offset_data*r_count == size);
12281236
return;
12291237
}
12301238
GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1239+
1240+
const size_t row_stride = tensor->nb[2];
1241+
GGML_ASSERT(offset % row_stride == 0);
1242+
GGML_ASSERT(size % row_stride == 0);
1243+
const int64_t r_start = offset / row_stride;
1244+
const int64_t r_count = size / row_stride;
1245+
GGML_ASSERT(r_start + r_count <= tensor->ne[2]);
1246+
12311247
for (size_t s = 0; s < split_state.n_segments; s++) {
12321248
for (size_t j = 0; j < n_bufs; j++) {
12331249
ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
12341250
const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
1235-
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes,
1236-
tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]);
1251+
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
1252+
simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes,
1253+
r_count, simple_tensor->nb[2], tensor->nb[2]);
12371254
offset_data += nbytes;
12381255
simple_offsets[j] += nbytes;
12391256
}
12401257
}
1241-
GGML_ASSERT(offset_data*tensor->ne[2] == size);
1258+
GGML_ASSERT(offset_data*r_count == size);
12421259
return;
12431260
}
12441261

@@ -1295,40 +1312,57 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co
12951312

12961313
if (split_state.n_segments != 1) {
12971314
GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1298-
GGML_ASSERT(offset == 0);
1299-
GGML_ASSERT(size == ggml_nbytes(tensor));
13001315
GGML_ASSERT(tensor->ne[3] == 1);
1316+
13011317
size_t offset_data = 0;
13021318
std::vector<size_t> simple_offsets(n_bufs, 0);
13031319
if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
13041320
GGML_ASSERT(tensor->ne[2] == 1);
1321+
1322+
const size_t row_stride = tensor->nb[1];
1323+
GGML_ASSERT(offset % row_stride == 0);
1324+
GGML_ASSERT(size % row_stride == 0);
1325+
const int64_t r_start = offset / row_stride;
1326+
const int64_t r_count = size / row_stride;
1327+
GGML_ASSERT(r_start + r_count <= tensor->ne[1]);
1328+
13051329
const int64_t blck_size = ggml_blck_size(tensor->type);
13061330
for (size_t s = 0; s < split_state.n_segments; s++) {
13071331
for (size_t j = 0; j < n_bufs; j++) {
13081332
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
13091333
GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
13101334
const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
1311-
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes,
1312-
tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]);
1335+
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
1336+
simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes,
1337+
r_count, simple_tensor->nb[1], tensor->nb[1]);
13131338
offset_data += nbytes;
13141339
simple_offsets[j] += nbytes;
13151340
}
13161341
}
1317-
GGML_ASSERT(offset_data*tensor->ne[1] == size);
1342+
GGML_ASSERT(offset_data*r_count == size);
13181343
return;
13191344
}
13201345
GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1346+
1347+
const size_t row_stride = tensor->nb[2];
1348+
GGML_ASSERT(offset % row_stride == 0);
1349+
GGML_ASSERT(size % row_stride == 0);
1350+
const int64_t r_start = offset / row_stride;
1351+
const int64_t r_count = size / row_stride;
1352+
GGML_ASSERT(r_start + r_count <= tensor->ne[2]);
1353+
13211354
for (size_t s = 0; s < split_state.n_segments; s++) {
13221355
for (size_t j = 0; j < n_bufs; j++) {
13231356
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
13241357
const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
1325-
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes,
1326-
tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]);
1358+
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
1359+
simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes,
1360+
r_count, simple_tensor->nb[2], tensor->nb[2]);
13271361
offset_data += nbytes;
13281362
simple_offsets[j] += nbytes;
13291363
}
13301364
}
1331-
GGML_ASSERT(offset_data*tensor->ne[2] == size);
1365+
GGML_ASSERT(offset_data*r_count == size);
13321366
return;
13331367
}
13341368

0 commit comments

Comments
 (0)