Skip to content

Commit 71ead04

Browse files
committed
Fix off-by-one errors in generation code and token streaming callback.
In the generation code we were feeding the last token of the prompt twice through the transformer. The new version fixes that and also works in the case where Prefill is completely disabled.
1 parent ede337f commit 71ead04

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

gemma.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,12 +666,16 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
666666

667667
size_t pos_gen_start = pos_offset;
668668
int token = prompt.at(pos_offset);
669+
stream_token(token, 0);
669670
for (size_t generate_pos = 0;
670671
pos < max_tokens && generate_pos < max_generated_tokens;
671672
++pos, ++pos_offset, ++generate_pos) {
672673
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
673674
float* final_activation = activations.x.data();
674-
if (pos_offset >= prompt_size) {
675+
// The condition below is always true if we are doing Prefill above.
676+
// We keep it here for clarity so that the code is correct even if Prefill
677+
// is disabled.
678+
if (pos_offset >= prompt_size - 1) {
675679
PROFILER_ZONE("Gen.Embedding");
676680
// Generation phase
677681
MatVec<kVocabSize, TConfig::kModelDim>(
@@ -681,9 +685,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
681685
Softmax(activations.logits.data(), kVocabSize);
682686
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
683687
gen, temperature, accept_token);
684-
}
685-
if (!stream_token(token, activations.logits[token])) {
686-
token = EOS_ID;
688+
if (!stream_token(token, activations.logits[token])) {
689+
token = EOS_ID;
690+
}
691+
} else {
692+
// We would take this branch if we were not doing Prefill but would
693+
// process the tokens of the prompt one at a time.
694+
token = prompt.at(pos_offset + 1);
695+
stream_token(token, 0);
687696
}
688697
if (token == EOS_ID) {
689698
if (verbosity >= 2) {

run.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
116116
verbosity](int token, float) {
117117
++abs_pos;
118118
++current_pos;
119-
if (current_pos < prompt_size) {
119+
// <= since position is incremented before
120+
if (current_pos <= prompt_size) {
120121
std::cerr << "." << std::flush;
121122
} else if (token == gcpp::EOS_ID) {
122123
if (!args.multiturn) {

0 commit comments

Comments
 (0)