Commit 8d9b992

jan-wassenberg authored and copybara-github committed
Fix VLM prefill batch size - prompt+tokens
PiperOrigin-RevId: 879159709
1 parent 5081341 commit 8d9b992

1 file changed

Lines changed: 3 additions & 4 deletions

gemma/run.cc

@@ -200,15 +200,14 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
                                 config.wrapping, abs_pos, prompt_string,
                                 image_tokens.Rows());
       runtime_config.image_tokens = &image_tokens;
+      // PrefixLM sees/attends to all tokens.
+      runtime_config.prefill_tbatch_size = prompt.size();
+
       prompt_size = prompt.size() - image_tokens.Rows();
       if (config.wrapping == PromptWrapping::PALIGEMMA) {
         // The end of the prefix for prefix-LM style attention in Paligemma.
         // See Figure 2 of https://arxiv.org/abs/2407.07726.
         prefix_end = prompt_size;
-        // We need to look at all the tokens for the prefix.
-        // NOTE: Online softmax is on the roadmap, after which this requirement
-        // can be lifted.
-        runtime_config.prefill_tbatch_size = prompt_size;
       }
     } else {
       prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
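
The effect of the change can be illustrated with a minimal standalone sketch. It is not gemma.cpp code: `VlmPromptSizes`, `OldPrefillBatch`, `NewPrefillBatch`, and `NumPrefillBatches` are hypothetical names introduced here. It assumes, as the `prompt.size() - image_tokens.Rows()` line in the diff suggests, that `prompt.size()` counts image tokens plus text tokens, and that prefill processes the prompt in chunks of at most `prefill_tbatch_size` tokens.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the two sizes that appear in the diff.
struct VlmPromptSizes {
  std::size_t total_tokens;  // role of prompt.size(): image + text tokens (assumed)
  std::size_t image_tokens;  // role of image_tokens.Rows()
};

// Mirrors `prompt_size = prompt.size() - image_tokens.Rows()`.
inline std::size_t TextTokenCount(const VlmPromptSizes& s) {
  return s.total_tokens - s.image_tokens;
}

// Before the commit: text tokens only, and only set in the PaliGemma branch.
inline std::size_t OldPrefillBatch(const VlmPromptSizes& s) {
  return TextTokenCount(s);
}

// After the commit: the whole prompt, image tokens included.
inline std::size_t NewPrefillBatch(const VlmPromptSizes& s) {
  return s.total_tokens;
}

// Chunks needed to cover `total` tokens with chunks of at most `batch` tokens
// (ceiling division). Assumes chunked prefill; requires batch > 0.
inline std::size_t NumPrefillBatches(std::size_t total, std::size_t batch) {
  assert(batch > 0);
  return (total + batch - 1) / batch;
}
```

With, say, 256 image tokens and 10 text tokens, the old value gives a batch size of 10, so the 266-token prompt would be split into several chunks. The new value of 266 covers the prompt in a single chunk, which is what a prefix-LM that attends to all prefix tokens at once needs.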
