Skip to content

Commit 845ae4e

Browse files
ggerganov, cnsiva
authored and committed
server : avoid checkpoint data host copies (ggml-org#22558)
* server : avoid checkpoint data host copies
* llama : refactor llama_io_read_i
1 parent ea9dc03 commit 845ae4e

6 files changed

Lines changed: 132 additions & 72 deletions

File tree

src/llama-context.cpp

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2253,6 +2253,28 @@ class llama_io_write_buffer : public llama_io_write_i {
22532253
llama_io_write_buffer(
22542254
uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
22552255

2256+
~llama_io_write_buffer() {
2257+
#if 1
2258+
// TODO: add backend support to batch tensor_get? or some other way to speed this up
2259+
for (const auto & info : winfos) {
2260+
ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size);
2261+
}
2262+
#else
2263+
// flush the writes asynchronously
2264+
// this helps on Macs, but on other devices - it does not. just an example
2265+
std::vector<std::future<void>> futures;
2266+
futures.reserve(winfos.size());
2267+
for (const auto & info : winfos) {
2268+
futures.push_back(std::async(std::launch::async, [info]() {
2269+
ggml_backend_tensor_get(info.tensor, info.ptr, info.offset, info.size);
2270+
}));
2271+
}
2272+
for (auto & f : futures) {
2273+
f.wait();
2274+
}
2275+
#endif
2276+
}
2277+
22562278
void write(const void * src, size_t size) override {
22572279
if (size > buf_size) {
22582280
throw std::runtime_error("unexpectedly reached end of buffer");
@@ -2267,7 +2289,10 @@ class llama_io_write_buffer : public llama_io_write_i {
22672289
if (size > buf_size) {
22682290
throw std::runtime_error("unexpectedly reached end of buffer");
22692291
}
2270-
ggml_backend_tensor_get(tensor, ptr, offset, size);
2292+
2293+
// save the write for later during destruction
2294+
winfos.push_back({tensor, ptr, size, offset});
2295+
22712296
ptr += size;
22722297
size_written += size;
22732298
buf_size -= size;
@@ -2281,25 +2306,48 @@ class llama_io_write_buffer : public llama_io_write_i {
22812306
uint8_t * ptr;
22822307
size_t buf_size = 0;
22832308
size_t size_written = 0;
2309+
2310+
struct write_info {
2311+
const ggml_tensor * tensor;
2312+
uint8_t * ptr;
2313+
size_t size;
2314+
size_t offset;
2315+
};
2316+
std::vector<write_info> winfos;
22842317
};
22852318

22862319
class llama_io_read_buffer : public llama_io_read_i {
22872320
public:
22882321
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
22892322

2290-
const uint8_t * read(size_t size) override {
2291-
const uint8_t * base_ptr = ptr;
2323+
~llama_io_read_buffer() {
2324+
// flush the reads
2325+
for (const auto & info : rinfos) {
2326+
ggml_backend_tensor_set(info.tensor, info.ptr, info.offset, info.size);
2327+
}
2328+
}
2329+
2330+
void read(void * dst, size_t size) override {
22922331
if (size > buf_size) {
22932332
throw std::runtime_error("unexpectedly reached end of buffer");
22942333
}
2334+
memcpy(dst, ptr, size);
22952335
ptr += size;
22962336
size_read += size;
22972337
buf_size -= size;
2298-
return base_ptr;
22992338
}
23002339

2301-
void read_to(void * dst, size_t size) override {
2302-
memcpy(dst, read(size), size);
2340+
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
2341+
if (size > buf_size) {
2342+
throw std::runtime_error("unexpectedly reached end of buffer");
2343+
}
2344+
2345+
// save for later during destruction
2346+
rinfos.push_back({tensor, ptr, size, offset});
2347+
2348+
ptr += size;
2349+
size_read += size;
2350+
buf_size -= size;
23032351
}
23042352

23052353
size_t n_bytes() override {
@@ -2310,6 +2358,14 @@ class llama_io_read_buffer : public llama_io_read_i {
23102358
const uint8_t * ptr;
23112359
size_t buf_size = 0;
23122360
size_t size_read = 0;
2361+
2362+
struct read_info {
2363+
ggml_tensor * tensor;
2364+
const uint8_t * ptr;
2365+
size_t size;
2366+
size_t offset;
2367+
};
2368+
std::vector<read_info> rinfos;
23132369
};
23142370

23152371
class llama_io_write_file : public llama_io_write_i {
@@ -2341,15 +2397,15 @@ class llama_io_read_file : public llama_io_read_i {
23412397
public:
23422398
llama_io_read_file(llama_file * f) : file(f) {}
23432399

2344-
void read_to(void * dst, size_t size) override {
2400+
void read(void * dst, size_t size) override {
23452401
file->read_raw(dst, size);
23462402
size_read += size;
23472403
}
23482404

2349-
const uint8_t * read(size_t size) override {
2405+
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
23502406
temp_buffer.resize(size);
2351-
read_to(temp_buffer.data(), size);
2352-
return temp_buffer.data();
2407+
read(temp_buffer.data(), size);
2408+
ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size);
23532409
}
23542410

23552411
size_t n_bytes() override {

src/llama-io.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "llama-io.h"
22

3+
#include <vector>
4+
35
void llama_io_write_i::write_string(const std::string & str) {
46
uint32_t str_size = str.size();
57

@@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) {
911

1012
void llama_io_read_i::read_string(std::string & str) {
1113
uint32_t str_size;
12-
read_to(&str_size, sizeof(str_size));
14+
read(&str_size, sizeof(str_size));
15+
16+
std::vector<char> buf(str_size);
17+
read(buf.data(), str_size);
1318

14-
str.assign((const char *) read(str_size), str_size);
19+
str.assign(buf.data(), str_size);
1520
}

src/llama-io.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class llama_io_read_i {
2525
llama_io_read_i() = default;
2626
virtual ~llama_io_read_i() = default;
2727

28-
virtual const uint8_t * read(size_t size) = 0;
29-
virtual void read_to(void * dst, size_t size) = 0;
28+
virtual void read(void * dst, size_t size) = 0;
29+
virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0;
3030

3131
// bytes read so far
3232
virtual size_t n_bytes() = 0;

src/llama-kv-cache.cpp

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,14 +1900,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
19001900
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
19011901

19021902
uint32_t n_stream_cur;
1903-
io.read_to(&n_stream_cur, sizeof(n_stream_cur));
1903+
io.read(&n_stream_cur, sizeof(n_stream_cur));
19041904
if (n_stream_cur != n_stream) {
19051905
throw std::runtime_error("n_stream mismatch");
19061906
}
19071907

19081908
for (uint32_t s = 0; s < n_stream; ++s) {
19091909
uint32_t cell_count;
1910-
io.read_to(&cell_count, sizeof(cell_count));
1910+
io.read(&cell_count, sizeof(cell_count));
19111911

19121912
if (cell_count == 0) {
19131913
continue;
@@ -2082,8 +2082,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
20822082
llama_pos pos;
20832083
uint32_t n_seq_id;
20842084

2085-
io.read_to(&pos, sizeof(pos));
2086-
io.read_to(&n_seq_id, sizeof(n_seq_id));
2085+
io.read(&pos, sizeof(pos));
2086+
io.read(&n_seq_id, sizeof(n_seq_id));
20872087

20882088
if (n_seq_id != 1) {
20892089
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
@@ -2092,7 +2092,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
20922092

20932093
if (hparams.n_pos_per_embd() > 1) {
20942094
llama_kv_cell_ext ext;
2095-
io.read_to(&ext, sizeof(ext));
2095+
io.read(&ext, sizeof(ext));
20962096

20972097
ubatch.pos[i + ubatch.n_tokens] = ext.y;
20982098
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
@@ -2101,7 +2101,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
21012101
// read the sequence id, but directly discard it - we will use dest_seq_id instead
21022102
{
21032103
llama_seq_id seq_id;
2104-
io.read_to(&seq_id, sizeof(seq_id));
2104+
io.read(&seq_id, sizeof(seq_id));
21052105
}
21062106

21072107
ubatch.pos[i] = pos;
@@ -2143,20 +2143,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
21432143
llama_pos pos;
21442144
uint32_t n_seq_id;
21452145

2146-
io.read_to(&pos, sizeof(pos));
2147-
io.read_to(&n_seq_id, sizeof(n_seq_id));
2146+
io.read(&pos, sizeof(pos));
2147+
io.read(&n_seq_id, sizeof(n_seq_id));
21482148

21492149
cells.pos_set(i, pos);
21502150

21512151
if (hparams.n_pos_per_embd() > 1) {
21522152
llama_kv_cell_ext ext;
2153-
io.read_to(&ext, sizeof(ext));
2153+
io.read(&ext, sizeof(ext));
21542154
cells.ext_set(i, ext);
21552155
}
21562156

21572157
for (uint32_t j = 0; j < n_seq_id; ++j) {
21582158
llama_seq_id seq_id;
2159-
io.read_to(&seq_id, sizeof(seq_id));
2159+
io.read(&seq_id, sizeof(seq_id));
21602160

21612161
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
21622162
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
@@ -2189,8 +2189,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
21892189
uint32_t v_trans;
21902190
uint32_t n_layer;
21912191

2192-
io.read_to(&v_trans, sizeof(v_trans));
2193-
io.read_to(&n_layer, sizeof(n_layer));
2192+
io.read(&v_trans, sizeof(v_trans));
2193+
io.read(&n_layer, sizeof(n_layer));
21942194

21952195
if (n_layer != layers.size()) {
21962196
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
@@ -2217,7 +2217,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
22172217

22182218
// Read type of key
22192219
int32_t k_type_i_ref;
2220-
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
2220+
io.read(&k_type_i_ref, sizeof(k_type_i_ref));
22212221
const int32_t k_type_i = (int32_t) k->type;
22222222
if (k_type_i != k_type_i_ref) {
22232223
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
@@ -2226,7 +2226,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
22262226

22272227
// Read row size of key
22282228
uint64_t k_size_row_ref;
2229-
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
2229+
io.read(&k_size_row_ref, sizeof(k_size_row_ref));
22302230
const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
22312231
if (k_size_row != k_size_row_ref) {
22322232
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
@@ -2236,13 +2236,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
22362236
if (cell_count) {
22372237
if (sinfo.is_contiguous()) {
22382238
// Fast path: contiguous cells, single memcpy
2239-
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
2239+
io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row);
22402240
} else {
22412241
// Slow path: scatter to non-contiguous positions
2242-
const void * src = io.read(cell_count * k_size_row);
22432242
for (uint32_t i = 0; i < cell_count; ++i) {
22442243
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
2245-
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
2244+
io.read_tensor(k, dst_offset, k_size_row);
22462245
}
22472246
}
22482247
}
@@ -2261,7 +2260,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
22612260

22622261
// Read type of value
22632262
int32_t v_type_i_ref;
2264-
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
2263+
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
22652264
const int32_t v_type_i = (int32_t) v->type;
22662265
if (v_type_i != v_type_i_ref) {
22672266
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
@@ -2270,7 +2269,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
22702269

22712270
// Read row size of value
22722271
uint64_t v_size_row_ref;
2273-
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
2272+
io.read(&v_size_row_ref, sizeof(v_size_row_ref));
22742273
const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
22752274
if (v_size_row != v_size_row_ref) {
22762275
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
@@ -2280,13 +2279,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
22802279
if (cell_count) {
22812280
if (sinfo.is_contiguous()) {
22822281
// Fast path: contiguous cells, single memcpy
2283-
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
2282+
io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row);
22842283
} else {
22852284
// Slow path: scatter to non-contiguous positions
2286-
const void * src = io.read(cell_count * v_size_row);
22872285
for (uint32_t i = 0; i < cell_count; ++i) {
22882286
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
2289-
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
2287+
io.read_tensor(v, dst_offset, v_size_row);
22902288
}
22912289
}
22922290
}
@@ -2305,7 +2303,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
23052303

23062304
// Read type of value
23072305
int32_t v_type_i_ref;
2308-
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
2306+
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
23092307
const int32_t v_type_i = (int32_t) v->type;
23102308
if (v_type_i != v_type_i_ref) {
23112309
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
@@ -2314,7 +2312,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
23142312

23152313
// Read element size of value
23162314
uint32_t v_size_el_ref;
2317-
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
2315+
io.read(&v_size_el_ref, sizeof(v_size_el_ref));
23182316
const size_t v_size_el = ggml_type_size(v->type);
23192317
if (v_size_el != v_size_el_ref) {
23202318
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
@@ -2323,7 +2321,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
23232321

23242322
// Read GQA embedding size
23252323
uint32_t n_embd_v_gqa_ref;
2326-
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
2324+
io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
23272325
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
23282326
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
23292327
return false;
@@ -2335,15 +2333,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
23352333
const uint32_t h = sinfo.head();
23362334
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
23372335
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
2338-
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
2336+
io.read_tensor(v, dst_offset, cell_count * v_size_el);
23392337
}
23402338
} else {
23412339
// Slow path: scatter to non-contiguous positions
23422340
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
2343-
const void * src = io.read(cell_count * v_size_el);
23442341
for (uint32_t i = 0; i < cell_count; ++i) {
23452342
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
2346-
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
2343+
io.read_tensor(v, dst_offset, v_size_el);
23472344
}
23482345
}
23492346
}

0 commit comments

Comments (0)