@@ -662,12 +662,16 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
662662
663663 size_t pos_gen_start = pos_offset;
664664 int token = prompt.at (pos_offset);
665+ stream_token (token, 0 );
665666 for (size_t generate_pos = 0 ;
666667 pos < max_tokens && generate_pos < max_generated_tokens;
667668 ++pos, ++pos_offset, ++generate_pos) {
668669 Transformer (token, pos, c_weights, activations, kv_cache, pool, inner_pool);
669670 float * final_activation = activations.x .data ();
670- if (pos_offset >= prompt_size) {
671+ // The condition below is always true if we are doing Prefill above.
672+ // We keep it here for clarity so that the code is correct even if Prefill
673+ // is disabled.
674+ if (pos_offset >= prompt_size - 1 ) {
671675 PROFILER_ZONE (" Gen.Embedding" );
672676 // Generation phase
673677 MatVec<kVocabSize , TConfig::kModelDim >(
@@ -677,9 +681,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
677681 Softmax (activations.logits .data (), kVocabSize );
678682 token = SampleTopK<TConfig::kTopK >(activations.logits .data (), kVocabSize ,
679683 gen, temperature, accept_token);
680- }
681- if (!stream_token (token, activations.logits [token])) {
682- token = EOS_ID ;
684+ if (!stream_token (token, activations.logits [token])) {
685+ token = EOS_ID ;
686+ }
687+ } else {
688+ // We would take this branch if we were not doing Prefill but would
689+ // process the tokens of the prompt one at a time.
690+ token = prompt.at (pos_offset + 1 );
691+ stream_token (token, 0 );
683692 }
684693 if (token == EOS_ID ) {
685694 if (verbosity >= 2 ) {
0 commit comments