Skip to content

Commit d6e7b03

Browse files
authored
llama : add option to save memory in device buffers (#22679)
* llama : add option to save memory in device buffers * tests : extend llama-save-load-state
1 parent fa59546 commit d6e7b03

11 files changed

Lines changed: 402 additions & 58 deletions

File tree

common/speculative.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,14 @@ struct common_speculative_state_draft : public common_speculative_state {
252252

253253
size_t create_checkpoint(int n_tokens_prompt) {
254254
int slot_id = 0;
255-
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
255+
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
256256

257257
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
258258
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
259259
ckpt.n_tokens = n_tokens_prompt;
260260
ckpt.data.resize(checkpoint_size);
261261

262-
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
262+
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
263263
if (n != checkpoint_size) {
264264
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
265265
}
@@ -272,7 +272,7 @@ struct common_speculative_state_draft : public common_speculative_state {
272272
size_t restore_checkpoint() {
273273
int slot_id = 0;
274274
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
275-
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
275+
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
276276
if (n != ckpt.size()) {
277277
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
278278
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());

examples/save-load-state/save-load-state.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ int main(int argc, char ** argv) {
3838
std::string result0;
3939
std::string result1;
4040
std::string result2;
41+
std::string result3;
4142

4243
// init
4344
auto llama_init = common_init_from_params(params);
@@ -213,11 +214,83 @@ int main(int argc, char ** argv) {
213214
n_past += 1;
214215
}
215216

217+
// test on-device state save/load
218+
auto params_ctx4 = common_context_params_to_llama(params);
219+
params_ctx4.n_seq_max = 2;
220+
llama_context * ctx4 = llama_init_from_model(model, params_ctx4);
221+
222+
llama_sampler * smpl4 = llama_sampler_chain_init(sparams);
223+
224+
llama_sampler_chain_add(smpl4, llama_sampler_init_dist(params.sampling.seed));
225+
226+
printf("\nsingle seq run: %s", params.prompt.c_str());
227+
228+
// load state (rng, logits, embedding and kv_cache) from file
229+
n_token_count_out = 0;
230+
231+
if (!llama_state_load_file(ctx4, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
232+
fprintf(stderr, "\n%s : failed to load state\n", __func__);
233+
return 1;
234+
}
235+
236+
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
237+
238+
// restore state (last tokens)
239+
n_past = n_token_count_out;
240+
if (!common_replay_last_token(ctx4, tokens.back(), n_past)) {
241+
return 1;
242+
}
243+
++n_past;
244+
245+
// save seq 0 and load into seq 1
246+
{
247+
// save kv of seq 0
248+
std::vector<uint8_t> seq_store(llama_state_seq_get_size_ext(ctx4, 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));
249+
const size_t ncopy = llama_state_seq_get_data_ext(ctx4, seq_store.data(), seq_store.size(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
250+
if (ncopy != seq_store.size()) {
251+
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
252+
return 1;
253+
}
254+
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
255+
256+
// erase whole kv
257+
llama_memory_clear(llama_get_memory(ctx4), true);
258+
fprintf(stderr, "%s : kv cache cleared\n", __func__);
259+
260+
// restore kv into seq 1
261+
const size_t nset = llama_state_seq_set_data_ext(ctx4, seq_store.data(), seq_store.size(), 1, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
262+
if (nset != seq_store.size()) {
263+
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
264+
return 1;
265+
}
266+
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
267+
}
268+
269+
// fourth run
270+
for (auto i = 0; i < params.n_predict; i++) {
271+
auto next_token = llama_sampler_sample(smpl4, ctx4, -1);
272+
auto next_token_str = common_token_to_piece(ctx4, next_token);
273+
274+
printf("%s", next_token_str.c_str());
275+
result3 += next_token_str;
276+
277+
common_batch_clear(batch);
278+
common_batch_add(batch, next_token, n_past, {1}, true);
279+
280+
if (llama_decode(ctx4, batch)) {
281+
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
282+
llama_batch_free(batch);
283+
return 1;
284+
}
285+
n_past += 1;
286+
}
287+
216288
printf("\n");
217289

218290
llama_sampler_free(smpl);
219291
llama_sampler_free(smpl2);
220292
llama_sampler_free(smpl3);
293+
llama_sampler_free(smpl4);
221294

222295
llama_batch_free(batch);
223296

@@ -226,12 +299,18 @@ int main(int argc, char ** argv) {
226299

227300
llama_free(ctx2);
228301
llama_free(ctx3);
302+
llama_free(ctx4);
229303

230304
if (result0 != result2) {
231305
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
232306
return 1;
233307
}
234308

309+
if (result0 != result3) {
310+
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
311+
return 1;
312+
}
313+
235314
fprintf(stderr, "\n%s : success\n", __func__);
236315

237316
return 0;

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);
282282
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
283283
void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
284284
void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
285+
bool ggml_metal_buffer_cpy_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * src, struct ggml_tensor * dst);
285286
void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value);
286287

287288
// finds the Metal buffer that contains the tensor data on the GPU device

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#import "ggml-metal-device.h"
22

33
#import "ggml-impl.h"
4+
#import "ggml-backend-impl.h"
45

56
#include <Foundation/Foundation.h>
67

@@ -1737,6 +1738,47 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten
17371738
}
17381739
}
17391740

1741+
// Copy the bytes of `src` into `dst`, where `dst` lives in `buf_dst`.
// Returns true on success; returns false when either tensor cannot be
// resolved to a Metal buffer id (the caller is expected to fall back to a
// host-memory copy path in that case).
// NOTE(review): assumes `src->buffer->context` is a ggml_metal_buffer_t,
// i.e. the caller has already verified src resides in a Metal buffer —
// confirm against the call sites in ggml-metal.cpp.
bool ggml_metal_buffer_cpy_tensor(ggml_metal_buffer_t buf_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
    ggml_metal_buffer_t buf_src = (ggml_metal_buffer_t)src->buffer->context;

    // number of bytes to transfer is determined by the source tensor
    const size_t size = ggml_nbytes(src);

    // if both buffers are shared, we can use memcpy directly
    if (buf_dst->is_shared && buf_src->is_shared) {
        memcpy(dst->data, src->data, size);
        return true;
    }

    // for private buffers, we need to use Metal blit commands
    @autoreleasepool {
        // resolve the Metal buffer + offset backing each tensor
        struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf_src, src);
        struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf_dst, dst);

        // either tensor may not be backed by a Metal buffer - signal the
        // caller so it can use a different copy path
        if (bid_src.metal == nil || bid_dst.metal == nil) {
            return false;
        }

        // unretained references are safe here: we wait for completion below,
        // so all referenced objects outlive the command buffer
        id<MTLCommandBuffer> cmd_buf = [buf_dst->dev->mtl_queue commandBufferWithUnretainedReferences];

        {
            id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

            [encoder copyFromBuffer:bid_src.metal
                       sourceOffset:bid_src.offs
                           toBuffer:bid_dst.metal
                  destinationOffset:bid_dst.offs
                               size:size];

            [encoder endEncoding];
        }

        // synchronous copy: block until the GPU has finished the blit
        [cmd_buf commit];
        [cmd_buf waitUntilCompleted];
    }

    return true;
}
1781+
17401782
void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {
17411783
if (buf->is_shared) {
17421784
memset(buf->all_data, value, buf->all_size);

ggml/src/ggml-metal/ggml-metal.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
// note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices
1818
static int g_devices = 1;
1919

20+
// forward declaration
21+
static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer);
22+
2023
////////////////////////////////////////////////////////////////////////////////
2124
// backend interface
2225
////////////////////////////////////////////////////////////////////////////////
@@ -68,11 +71,11 @@ static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t bu
6871

6972
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
7073

71-
GGML_UNUSED(buffer);
72-
GGML_UNUSED(src);
73-
GGML_UNUSED(dst);
74+
if (!ggml_backend_buffer_is_metal(src->buffer)) {
75+
return false;
76+
}
7477

75-
return false;
78+
return ggml_metal_buffer_cpy_tensor(ctx, src, dst);
7679
}
7780

7881
static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -144,11 +147,11 @@ static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t b
144147

145148
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
146149

147-
GGML_UNUSED(buffer);
148-
GGML_UNUSED(src);
149-
GGML_UNUSED(dst);
150+
if (!ggml_backend_buffer_is_metal(src->buffer)) {
151+
return false;
152+
}
150153

151-
return false;
154+
return ggml_metal_buffer_cpy_tensor(ctx, src, dst);
152155
}
153156

154157
static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,9 @@ extern "C" {
864864
// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
865865
#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
866866

867+
// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load)
868+
#define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2
869+
867870
typedef uint32_t llama_state_seq_flags;
868871

869872
LLAMA_API size_t llama_state_seq_get_size_ext(

0 commit comments

Comments
 (0)