Skip to content

Commit 08948f1

Browse files
Merge pull request #127 from szabadka:gemma3
PiperOrigin-RevId: 621815677
2 parents 44e6274 + 71ead04 commit 08948f1

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

gemma.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -662,12 +662,16 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
662662

663663
size_t pos_gen_start = pos_offset;
664664
int token = prompt.at(pos_offset);
665+
stream_token(token, 0);
665666
for (size_t generate_pos = 0;
666667
pos < max_tokens && generate_pos < max_generated_tokens;
667668
++pos, ++pos_offset, ++generate_pos) {
668669
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
669670
float* final_activation = activations.x.data();
670-
if (pos_offset >= prompt_size) {
671+
// The condition below is always true if we are doing Prefill above.
672+
// We keep it here for clarity so that the code is correct even if Prefill
673+
// is disabled.
674+
if (pos_offset >= prompt_size - 1) {
671675
PROFILER_ZONE("Gen.Embedding");
672676
// Generation phase
673677
MatVec<kVocabSize, TConfig::kModelDim>(
@@ -677,9 +681,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
677681
Softmax(activations.logits.data(), kVocabSize);
678682
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
679683
gen, temperature, accept_token);
680-
}
681-
if (!stream_token(token, activations.logits[token])) {
682-
token = EOS_ID;
684+
if (!stream_token(token, activations.logits[token])) {
685+
token = EOS_ID;
686+
}
687+
} else {
688+
// We would take this branch if we were not doing Prefill but would
689+
// process the tokens of the prompt one at a time.
690+
token = prompt.at(pos_offset + 1);
691+
stream_token(token, 0);
683692
}
684693
if (token == EOS_ID) {
685694
if (verbosity >= 2) {

run.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
116116
verbosity](int token, float) {
117117
++abs_pos;
118118
++current_pos;
119-
if (current_pos < prompt_size) {
119+
// Use <= because current_pos is incremented before this comparison.
120+
if (current_pos <= prompt_size) {
120121
std::cerr << "." << std::flush;
121122
} else if (token == gcpp::EOS_ID) {
122123
if (!args.multiturn) {

0 commit comments

Comments
 (0)