File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -41,20 +41,17 @@ def forward(self, x):
4141 return output * self .weight .type_as (x )
4242
4343
44- class ScalelessRMSNorm (torch .nn .RMSNorm ):
45- """RMSNorm with weight hardcoded to ones and not trainable.
46-
47- Equivalent to a scaleless RMSNorm (no learnable scaling) but implemented as a
48- torch.nn.RMSNorm so the op composes/decomposes cleanly for backends like QNN
49- instead of being expressed as a hand-rolled decomposition.
50- """
51-
44+ class ScalelessRMSNorm (torch .nn .Module ):
5245 def __init__ (self , dim : int , eps : float = 1e-6 ):
53- super ().__init__ (dim , eps )
46+ super ().__init__ ()
5447 self .dim = dim
55- with torch .no_grad ():
56- self .weight .fill_ (1.0 )
57- self .weight .requires_grad = False
48+ self .eps = eps
49+
50+ def forward (self , x ):
51+ x_float = x .float ()
52+ return (
53+ x_float * torch .rsqrt ((x_float * x_float ).mean (- 1 , keepdim = True ) + self .eps )
54+ ).type_as (x )
5855
5956
6057class RMSNormCoreML (torch .nn .Module ):
Original file line number Diff line number Diff line change @@ -69,7 +69,7 @@ def forward(
6969 0 , # dropout probability. Ignored by the code
7070 True , # is_causal
7171 )
72- return output .reshape (bsz , seqlen , self .dim ).to (dtype = input_dtype )
72+ return output .view (bsz , seqlen , self .dim ).to (dtype = input_dtype )
7373
7474
7575def _replace_sdpa_with_custom_op (
@@ -198,7 +198,7 @@ def forward(
198198 v_scale_fp32 ,
199199 )
200200
201- return output .reshape (bsz , seqlen , self .dim )
201+ return output .view (bsz , seqlen , self .dim )
202202
203203
204204def _update_attention_module_with_quantized_sdpa (
You can’t perform that action at this time.
0 commit comments