
Commit 78ee61f

mnachin authored and mergennachin committed
Apply Gemma 4 IT chat template in inference.py and C++ runner
Gemma 4 31B-IT is instruction-tuned and produces degenerate output without the chat-template wrapping. Auto-wrap --prompt with the IT template (<bos><|turn>user\n{prompt}<turn|>\n<|turn>model\n<|channel>thought\n<channel|>) by default; --raw-prompt (inference.py) / --raw_prompt (C++ runner) skip wrapping for pre-formatted input.
1 parent 54f1f28 commit 78ee61f
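
For illustration, a minimal Python sketch of the wrapping this commit adds (the template string is copied from the inference.py diff below; the prompt value is a made-up example):

```python
# Chat template as added in inference.py (see diff below).
_CHAT_TEMPLATE = (
    "<bos><|turn>user\n{prompt}<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"
)

# e.g. --prompt "What is 2+2?" is sent to the model as:
print(_CHAT_TEMPLATE.format(prompt="What is 2+2?"))
# <bos><|turn>user
# What is 2+2?<turn|>
# <|turn>model
# <|channel>thought
# <channel|>
```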

3 files changed: 41 additions & 1 deletion


examples/models/gemma4_31b/README.md

Lines changed: 6 additions & 0 deletions
@@ -79,6 +79,9 @@ Writes `model.pte` and `model.ptd` into `--output-dir`.
 
 ## Eager inference
 
+The prompt is automatically wrapped with the Gemma 4 IT chat template.
+Pass `--raw-prompt` to skip template wrapping for pre-formatted input.
+
 ```bash
 python examples/models/gemma4_31b/inference.py \
   --prequantized ./gemma4_31b_int4 \
@@ -109,6 +112,9 @@ The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.
 
 ## Run the .pte
 
+The prompt is automatically wrapped with the Gemma 4 IT chat template.
+Pass `--raw_prompt` to skip template wrapping for pre-formatted input.
+
 ```bash
 ./gemma4_31b_runner \
   --model_path ./gemma4_31b_exports/model.pte \

examples/models/gemma4_31b/inference.py

Lines changed: 23 additions & 1 deletion
@@ -13,6 +13,11 @@
 Packs for the target backend (--backend cuda), materializes runtime buffers,
 optionally compiles with ``torch.compile``, and generates text autoregressively.
 
+Gemma 4 31B-IT is instruction-tuned and requires chat-template formatting.
+The ``--prompt`` is automatically wrapped with the Gemma 4 chat template
+(``<bos><|turn>user\\n{prompt}<turn|>\\n<|turn>model\\n<|channel>thought\\n<channel|>``).
+Pass ``--raw-prompt`` to skip template wrapping (e.g., for pre-formatted input).
+
 Usage:
     python inference.py \\
         --prequantized ./gemma4_31b_int4 \\
@@ -63,6 +68,16 @@ def _move_to_cuda(model, config) -> None:
     materialize_runtime_buffers(model, dtype=torch.bfloat16, device="cuda")
 
 
+_CHAT_TEMPLATE = (
+    "<bos><|turn>user\n{prompt}<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"
+)
+
+
+def apply_chat_template(prompt: str) -> str:
+    """Wrap a user prompt in the Gemma 4 IT chat template."""
+    return _CHAT_TEMPLATE.format(prompt=prompt)
+
+
 def generate(
     model,
     tokenizer,
@@ -155,6 +170,11 @@ def main() -> None:
         default=4096,
         help="KV cache length to allocate for this run.",
     )
+    parser.add_argument(
+        "--raw-prompt",
+        action="store_true",
+        help="Skip chat-template wrapping (use if the prompt is already formatted).",
+    )
     parser.add_argument(
         "--no-compile",
         action="store_true",
@@ -204,14 +224,16 @@
     # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106).
     eos_token_ids = {1, 50, 106}
 
+    prompt = args.prompt if args.raw_prompt else apply_chat_template(args.prompt)
+
     print(f"\nPrompt: {args.prompt}")
     print("-" * 40)
 
     t0 = time.perf_counter()
     output = generate(
         model,
         tokenizer,
-        args.prompt,
+        prompt,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         eos_token_ids=eos_token_ids,
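
As a standalone sketch of how the new flag and helper interact (the template and helper are duplicated from the diff above so the snippet runs on its own; the prompt strings are made up):

```python
import argparse

_CHAT_TEMPLATE = (
    "<bos><|turn>user\n{prompt}<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"
)


def apply_chat_template(prompt: str) -> str:
    """Wrap a user prompt in the Gemma 4 IT chat template."""
    return _CHAT_TEMPLATE.format(prompt=prompt)


parser = argparse.ArgumentParser()
parser.add_argument("--prompt", required=True)
# argparse exposes --raw-prompt as args.raw_prompt, which the diff relies on.
parser.add_argument("--raw-prompt", action="store_true")

# Default run: the prompt is wrapped before tokenization.
args = parser.parse_args(["--prompt", "Hi"])
prompt = args.prompt if args.raw_prompt else apply_chat_template(args.prompt)
assert prompt == "<bos><|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"

# --raw-prompt run: a pre-formatted prompt passes through untouched.
args = parser.parse_args(["--prompt", prompt, "--raw-prompt"])
assert args.raw_prompt
assert (args.prompt if args.raw_prompt else apply_chat_template(args.prompt)) == prompt
```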

examples/models/gemma4_31b/main.cpp

Lines changed: 12 additions & 0 deletions
@@ -65,6 +65,10 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
 DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2).");
 DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
+DEFINE_bool(
+    raw_prompt,
+    false,
+    "Skip chat-template wrapping (use if the prompt is already formatted).");
 DEFINE_bool(
     cuda_graph,
     false,
@@ -232,6 +236,14 @@ int main(int argc, char** argv) {
       (std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
   }
 
+  // Wrap with Gemma 4 IT chat template unless --raw_prompt is set.
+  // BOS is prepended separately below; this adds the turn structure and the
+  // empty thought block required by the instruction-tuned model.
+  if (!FLAGS_raw_prompt) {
+    prompt_text = "<|turn>user\n" + prompt_text +
+        "<turn|>\n<|turn>model\n<|channel>thought\n<channel|>";
+  }
+
   // Encode prompt
   auto encode_result = tokenizer->encode(prompt_text);
   if (!encode_result.ok()) {
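
A quick Python sketch of why the C++ template omits `<bos>`: the runner prepends `--bos_id` at encode time, so adding it in the text as well would double it. The check below assumes token id 2 corresponds to the literal `<bos>` string, which the diff itself does not state:

```python
# The C++ runner wraps without <bos>; BOS (--bos_id, default 2) is prepended
# as a token id when the prompt is encoded.
def cpp_wrap(prompt_text: str) -> str:
    return (
        "<|turn>user\n" + prompt_text +
        "<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"
    )


PY_TEMPLATE = (
    "<bos><|turn>user\n{prompt}<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"
)

# Assuming "<bos>" is the text form of bos_id=2, both runners feed the model
# the same token stream for the same user prompt.
assert "<bos>" + cpp_wrap("Hi") == PY_TEMPLATE.format(prompt="Hi")
```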
