diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py
index 8ad51d4594a..f18cf824b9c 100644
--- a/examples/models/llama/norm.py
+++ b/examples/models/llama/norm.py
@@ -42,16 +42,23 @@ def forward(self, x):
 
 
 class ScalelessRMSNorm(torch.nn.Module):
+    """RMSNorm without learnable scaling.
+
+    Calls F.rms_norm with weight=None so the op composes/decomposes cleanly for
+    backends like QNN instead of being expressed as a hand-rolled decomposition
+    of mean / rsqrt / mul. Semantically equivalent to
+    torch.nn.RMSNorm(elementwise_affine=False), but implemented as a plain
+    Module to preserve the previous parameterless state_dict signature (no
+    `weight` attribute / parameter).
+    """
+
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
         self.eps = eps
 
     def forward(self, x):
-        x_float = x.float()
-        return (
-            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
-        ).type_as(x)
+        return F.rms_norm(x, (self.dim,), None, self.eps)
 
 
 class RMSNormWithInputScale(torch.nn.Module):
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index b10f684ccc0..0285f3562cb 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -69,7 +69,7 @@ def forward(
             0,  # dropout probability. Ignored by the code
             True,  # is_causal
         )
-        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
+        return output.reshape(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
 def _replace_sdpa_with_custom_op(
@@ -198,7 +198,7 @@ def forward(
             v_scale_fp32,
         )
 
-        return output.view(bsz, seqlen, self.dim)
+        return output.reshape(bsz, seqlen, self.dim)
 
 
 def _update_attention_module_with_quantized_sdpa(
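
Sanity check for the ScalelessRMSNorm change: a minimal sketch (not part of the diff) comparing the new composite op against the previous hand-rolled decomposition. It assumes PyTorch >= 2.4, where F.rms_norm is available; the helper names old_forward and new_forward are illustrative only.

import torch
import torch.nn.functional as F


def old_forward(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Previous decomposition: upcast to fp32, mean of squares, rsqrt, cast back.
    x_float = x.float()
    return (
        x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + eps)
    ).type_as(x)


def new_forward(x: torch.Tensor, dim: int, eps: float = 1e-6) -> torch.Tensor:
    # New path: a single op; weight=None means no learnable elementwise scale.
    return F.rms_norm(x, (dim,), None, eps)


x = torch.randn(2, 8, 64)
torch.testing.assert_close(new_forward(x, 64), old_forward(x))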