Skip to content

Commit bedd3c6

Browse files
howard0suCopilot
andcommitted
gemma4: save/restore target_feat in prefix cache snapshot
Matching Qwen35's approach: save target_feat (BF16 feature ring buffer) and last_tok as part of the KV snapshot. On restore, target_feat is copied back to GPU before the delta prefill + feature mirror resync. Previously, only K/V tensors were snapshotted. After restore, the feature mirror contained stale data from the previous request's decode phase, causing the draft model to make poor predictions and halving speculative decode acceptance rate (52% → 24%). With this fix, the full feature state is correctly restored, and the subsequent draft_feature_mirror_sync_tail ensures the mirror matches. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 8e24fc5 commit bedd3c6

3 files changed

Lines changed: 35 additions & 4 deletions

File tree

dflash/src/gemma4/gemma4_backend.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,15 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot,
633633
}
634634
}
635635

636+
// Restore target_feat from snapshot
637+
if (snap.feat_snap && cache_.target_feat) {
638+
const size_t feat_nbytes = ggml_nbytes(snap.feat_snap);
639+
ggml_backend_tensor_set(cache_.target_feat, snap.feat_snap->data, 0, feat_nbytes);
640+
}
641+
636642
const int snap_pos = snap.cur_pos;
637643
cache_.cur_pos = snap_pos;
644+
cache_.last_tok = snap.last_tok;
638645

639646
// Set up sampler
640647
sampler_ = req.sampler;
@@ -766,8 +773,9 @@ bool Gemma4Backend::snapshot_save(int slot) {
766773
if (needs_alloc) {
767774
free_gemma4_snapshot(snap);
768775

776+
const int n_feat_tensors = (cache_.target_feat && cache_.target_feat_cap > 0) ? 1 : 0;
769777
ggml_init_params ip{};
770-
ip.mem_size = ggml_tensor_overhead() * (size_t)(n_layer * 2 + 4) + 4096;
778+
ip.mem_size = ggml_tensor_overhead() * (size_t)(n_layer * 2 + n_feat_tensors + 4) + 4096;
771779
ip.no_alloc = true;
772780
snap.ctx = ggml_init(ip);
773781
if (!snap.ctx) return false;
@@ -787,10 +795,21 @@ bool Gemma4Backend::snapshot_save(int slot) {
787795
}
788796
}
789797

798+
// target_feat: save min(snap_pos, target_feat_cap) positions
799+
snap.feat_snap = nullptr;
800+
snap.feat_cap = 0;
801+
if (cache_.target_feat && cache_.target_feat_cap > 0) {
802+
const int feat_len = std::min(snap_pos, cache_.target_feat_cap);
803+
snap.feat_snap = ggml_new_tensor_2d(snap.ctx, cache_.target_feat->type,
804+
cache_.target_feat->ne[0], feat_len);
805+
snap.feat_cap = cache_.target_feat_cap;
806+
}
807+
790808
snap.buf = ggml_backend_alloc_ctx_tensors(snap.ctx, snap_backend_);
791809
if (!snap.buf) {
792810
ggml_free(snap.ctx); snap.ctx = nullptr;
793811
snap.k_snap.clear(); snap.v_snap.clear();
812+
snap.feat_snap = nullptr;
794813
return false;
795814
}
796815
}
@@ -820,9 +839,15 @@ bool Gemma4Backend::snapshot_save(int slot) {
820839
}
821840
}
822841
snap.cur_pos = snap_pos;
842+
snap.last_tok = cache_.last_tok;
823843

824-
std::printf("[gemma4] snapshot saved slot=%d pos=%d\n", slot, snap.cur_pos);
825-
std::fflush(stdout);
844+
// target_feat: copy min(snap_pos, cap) positions from GPU to snapshot
845+
if (snap.feat_snap && cache_.target_feat) {
846+
const size_t feat_nbytes = ggml_nbytes(snap.feat_snap);
847+
ggml_backend_tensor_get(cache_.target_feat, snap.feat_snap->data, 0, feat_nbytes);
848+
}
849+
850+
std::fprintf(stderr, "[gemma4] snapshot saved slot=%d pos=%d\n", slot, snap.cur_pos);
826851
return true;
827852
}
828853

dflash/src/gemma4/gemma4_internal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,11 @@ bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache,
193193
// Snapshot
194194
struct Gemma4Snapshot {
195195
int cur_pos = 0;
196+
int32_t last_tok = -1;
196197
std::vector<ggml_tensor *> k_snap;
197198
std::vector<ggml_tensor *> v_snap;
199+
ggml_tensor * feat_snap = nullptr; // [fc_in, feat_len]
200+
int feat_cap = 0;
198201
ggml_context * ctx = nullptr;
199202
ggml_backend_buffer_t buf = nullptr;
200203
};

dflash/src/gemma4/gemma4_loader.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,10 @@ void free_gemma4_snapshot(Gemma4Snapshot & s) {
478478
if (s.buf) { ggml_backend_buffer_free(s.buf); s.buf = nullptr; }
479479
if (s.ctx) { ggml_free(s.ctx); s.ctx = nullptr; }
480480
s.k_snap.clear(); s.v_snap.clear();
481-
s.cur_pos = 0;
481+
s.feat_snap = nullptr;
482+
s.feat_cap = 0;
483+
s.cur_pos = 0;
484+
s.last_tok = -1;
482485
}
483486

484487
} // namespace dflash27b

0 commit comments

Comments
 (0)