@@ -376,77 +376,32 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    @staticmethod
-    def _forward_static_no_residual(
-        weight: torch.Tensor,
-        variance_epsilon: float,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward() without residual."""
-        orig_dtype = x.dtype
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + variance_epsilon)
-        x = x * (1.0 + weight.float())
-        x = x.to(orig_dtype)
-        return x
-
-    @staticmethod
-    def _forward_static_with_residual(
-        weight: torch.Tensor,
-        variance_epsilon: float,
-        x: torch.Tensor,
-        residual: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """PyTorch-native implementation equivalent to forward() with residual."""
-        orig_dtype = x.dtype
-        x = (
-            x.float() + residual.float()
-            if orig_dtype == torch.float16
-            else x + residual
-        )
-        residual = x
-
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + variance_epsilon)
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + weight.float())
-        x = x.to(orig_dtype)
-        return x, residual
-
     def forward_native(
         self,
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
-        if residual is None:
-            return self._forward_static_no_residual(
-                self.weight.data, self.variance_epsilon, x
-            )
-        else:
-            return self._forward_static_with_residual(
-                self.weight.data, self.variance_epsilon, x, residual
+        orig_dtype = x.dtype
+        weight = self.weight.data.float() + 1.0
+        if residual is not None:
+            x = (
+                x.float() + residual.float()
+                if orig_dtype == torch.float16
+                else x + residual
             )
+            residual = x
+        # ir.ops.rms_norm handles fp32 upcast internally
+        out = ir.ops.rms_norm(x, weight, self.variance_epsilon)
+        return (
+            out.to(orig_dtype) if residual is None else (out.to(orig_dtype), residual)
+        )
 
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if torch.compiler.is_compiling():
-            return self.forward_native(x, residual)
-
-        if not getattr(self, "_is_compiled", False):
-            self._forward_static_no_residual = torch.compile(  # type: ignore
-                self._forward_static_no_residual
-            )
-            self._forward_static_with_residual = torch.compile(  # type: ignore
-                self._forward_static_with_residual
-            )
-            self._is_compiled = True
         return self.forward_native(x, residual)
 
 
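For reference, the new forward_native hands the normalization itself to ir.ops.rms_norm, passing a weight that has already been offset by 1.0 (Gemma stores its scale as an offset from one). The sketch below shows, in plain PyTorch, what that call is assumed to compute, based on the removed static helpers and the comment that the op upcasts to fp32 internally; rms_norm_reference is a hypothetical name, not part of the library.

    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Assumed semantics of ir.ops.rms_norm(x, weight, eps): normalize in fp32
        # and scale by the given weight; the caller casts back to the input dtype.
        x_f = x.float()
        variance = x_f.pow(2).mean(dim=-1, keepdim=True)
        return x_f * torch.rsqrt(variance + eps) * weight.float()

    # forward_native passes weight = self.weight.data.float() + 1.0, so the scale
    # applied here is (1.0 + weight), matching the previous _forward_static_* helpers,
    # and the result is cast back with .to(orig_dtype) at the call site.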