Skip to content

Commit f0386df

Browse files
committed
Consolidate Gemma RMSNorm +1 offset into single block after load
Address review feedback (shewu-quic): move the Gemma norm-weight adjustment out of the two if/else branches into a single block that runs regardless of checkpoint source. Also fix the misleading comment and replace the `torch.ones(v.shape, dtype=torch.float32)` addend with the scalar `1.0`, which broadcasts to the same result.
1 parent dce6eca commit f0386df

1 file changed

Lines changed: 12 additions & 23 deletions

File tree

examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,6 @@ def _prepare_model(self): # noqa: C901
166166
state_dict = torch.load(
167167
checkpoint, weights_only=True, map_location="cpu", mmap=True
168168
)
169-
if self.control_args.decoder_model in {
170-
"gemma-2b",
171-
"gemma2-2b",
172-
"gemma3-1b",
173-
}:
174-
for k, v in state_dict.items():
175-
if "norm" not in k:
176-
continue
177-
# Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
178-
# See https://github.com/huggingface/transformers/pull/29402
179-
state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32)
180169
else:
181170
state_dict = torch.load(
182171
self.control_args.checkpoint,
@@ -192,18 +181,18 @@ def _prepare_model(self): # noqa: C901
192181
k.replace("_orig_mod.", ""): v for k, v in state_dict.items()
193182
}
194183

195-
if self.control_args.decoder_model in {
196-
"gemma-2b",
197-
"gemma2-2b",
198-
"gemma3-1b",
199-
}:
200-
for k, v in state_dict.items():
201-
if "norm" not in k:
202-
continue
203-
# Gemma RMSNorm uses (1 + w) * x, so converted checkpoints
204-
# that haven't been offset need +1 applied here.
205-
# See https://github.com/huggingface/transformers/pull/29402
206-
state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32)
184+
# Gemma RMSNorm computes (1 + w) * x but ExecuTorch's RMSNorm computes
185+
# w * x, so add the +1 offset to norm weights regardless of load path.
186+
# See https://github.com/huggingface/transformers/pull/29402
187+
if self.control_args.decoder_model in {
188+
"gemma-2b",
189+
"gemma2-2b",
190+
"gemma3-1b",
191+
}:
192+
for k, v in state_dict.items():
193+
if "norm" not in k:
194+
continue
195+
state_dict[k] = v.float() + 1.0
207196

208197
# Change to HF weights to improve the performance of RoPE on the HTP backend.
209198
if self.config.transform_weight:

0 commit comments

Comments
 (0)