Skip to content

Commit bae4e37

Browse files
committed
Fix Gemma RMSNorm +1 offset missing on --checkpoint path
The `--checkpoint` code path skipped the Gemma-specific RMSNorm weight adjustment (`weight + 1`). Gemma stores norm weights as deviations from 1 and computes `(1 + w) * x`, but ExecuTorch's RMSNorm computes `w * x`. The HF download path applied the +1 offset correctly, but passing a converted checkpoint via `--checkpoint` silently produced garbage output from all 36+ norm layers, regardless of quantization recipe.
1 parent 42581f1 commit bae4e37

1 file changed

Lines changed: 13 additions & 0 deletions

File tree

examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,19 @@ def _prepare_model(self): # noqa: C901
192192
k.replace("_orig_mod.", ""): v for k, v in state_dict.items()
193193
}
194194

195+
if self.control_args.decoder_model in {
196+
"gemma-2b",
197+
"gemma2-2b",
198+
"gemma3-1b",
199+
}:
200+
for k, v in state_dict.items():
201+
if "norm" not in k:
202+
continue
203+
# Gemma RMSNorm uses (1 + w) * x, so converted checkpoints
204+
# that haven't been offset need +1 applied here.
205+
# See https://github.com/huggingface/transformers/pull/29402
206+
state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32)
207+
195208
# change to HF weight to improve the performance of RoPE in HTP backend.
196209
if self.config.transform_weight:
197210

0 commit comments

Comments
 (0)