15 changes: 11 additions & 4 deletions examples/models/llama/norm.py
@@ -42,16 +42,23 @@ def forward(self, x):
 
 
 class ScalelessRMSNorm(torch.nn.Module):
+    """RMSNorm without learnable scaling.
+
+    Calls F.rms_norm with weight=None so the op composes/decomposes cleanly for
+    backends like QNN instead of being expressed as a hand-rolled decomposition
+    of mean / rsqrt / mul. Semantically equivalent to
+    torch.nn.RMSNorm(elementwise_affine=False), but implemented as a plain
+    Module to preserve the previous parameterless state_dict signature (no
+    `weight` attribute / parameter).
+    """
 
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
+        self.dim = dim
         self.eps = eps
 
     def forward(self, x):
-        x_float = x.float()
-        return (
-            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
-        ).type_as(x)
+        return F.rms_norm(x, (self.dim,), None, self.eps)
 
 
 class RMSNormWithInputScale(torch.nn.Module):
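For reference, a minimal sketch of the equivalence the docstring describes, assuming PyTorch >= 2.4 (where F.rms_norm is available); the input shape, eps, and tolerances are illustrative assumptions, not taken from the PR:

```python
import torch
import torch.nn.functional as F

def rms_norm_manual(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Previous hand-rolled decomposition: mean of squares, rsqrt, multiply.
    x_float = x.float()
    return (
        x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + eps)
    ).type_as(x)

x = torch.randn(2, 8, 64)  # illustrative (batch, seq, dim) shape
out_manual = rms_norm_manual(x)
out_fused = F.rms_norm(x, (x.size(-1),), None, 1e-6)  # weight=None: no learnable scale
torch.testing.assert_close(out_manual, out_fused, rtol=1e-5, atol=1e-5)
```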
4 changes: 2 additions & 2 deletions examples/models/llama/source_transformation/sdpa.py
@@ -69,7 +69,7 @@ def forward(
             0, # dropout probability. Ignored by the code
             True, # is_causal
         )
-        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
+        return output.reshape(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
 def _replace_sdpa_with_custom_op(
@@ -198,7 +198,7 @@ def forward(
             v_scale_fp32,
         )
 
-        return output.view(bsz, seqlen, self.dim)
+        return output.reshape(bsz, seqlen, self.dim)
 
 
 def _update_attention_module_with_quantized_sdpa(
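One plausible reading of the view → reshape change above is that the output of the custom SDPA op may be non-contiguous, in which case .view would raise while .reshape falls back to a copy; this is an assumption, not stated in the PR. A minimal sketch of the behavioral difference, using a transposed tensor as a hypothetical stand-in for such an output:

```python
import torch

# Hypothetical stand-in for a non-contiguous op output, e.g. after swapping
# the head and sequence dimensions of a (bsz, n_heads, seqlen, head_dim) tensor.
out = torch.randn(2, 4, 8, 16).transpose(1, 2)  # shape (2, 8, 4, 16), non-contiguous

try:
    out.view(2, 8, 64)  # .view needs compatible strides; this raises RuntimeError
except RuntimeError as err:
    print("view failed:", err)

flat = out.reshape(2, 8, 64)  # .reshape copies when a zero-copy view is impossible
print(flat.is_contiguous(), flat.shape)
```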