File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -41,17 +41,20 @@ def forward(self, x):
4141 return output * self .weight .type_as (x )
4242
4343
44- class ScalelessRMSNorm (torch .nn .Module ):
44+ class ScalelessRMSNorm (torch .nn .RMSNorm ):
45+ """RMSNorm with weight hardcoded to ones and not trainable.
46+
47+ Equivalent to a scaleless RMSNorm (no learnable scaling) but implemented as a
48+ torch.nn.RMSNorm so the op composes/decomposes cleanly for backends like QNN
49+ instead of being expressed as a hand-rolled decomposition.
50+ """
51+
4552 def __init__ (self , dim : int , eps : float = 1e-6 ):
46- super ().__init__ ()
53+ super ().__init__ (dim , eps )
4754 self .dim = dim
48- self .eps = eps
49-
50- def forward (self , x ):
51- x_float = x .float ()
52- return (
53- x_float * torch .rsqrt ((x_float * x_float ).mean (- 1 , keepdim = True ) + self .eps )
54- ).type_as (x )
55+ with torch .no_grad ():
56+ self .weight .fill_ (1.0 )
57+ self .weight .requires_grad = False
5558
5659
5760class RMSNormWithInputScale (torch .nn .Module ):
Original file line number Diff line number Diff line change @@ -69,7 +69,7 @@ def forward(
6969 0 , # dropout probability. Ignored by the code
7070 True , # is_causal
7171 )
72- return output .view (bsz , seqlen , self .dim ).to (dtype = input_dtype )
72+ return output .reshape (bsz , seqlen , self .dim ).to (dtype = input_dtype )
7373
7474
7575def _replace_sdpa_with_custom_op (
@@ -198,7 +198,7 @@ def forward(
198198 v_scale_fp32 ,
199199 )
200200
201- return output .view (bsz , seqlen , self .dim )
201+ return output .reshape (bsz , seqlen , self .dim )
202202
203203
204204def _update_attention_module_with_quantized_sdpa (
You can’t perform that action at this time.
0 commit comments