@@ -36,37 +36,56 @@ def __init__(
3636 dtype : DType ,
3737 eps : float = 1e-5 ,
3838 use_bias : bool = True ,
39+ keep_dtype : bool = False ,
40+ elementwise_affine : bool = True ,
3941 ) -> None :
4042 super ().__init__ ()
4143 self .devices = devices
42- self .weight = Weight ("weight" , dtype , (dims ,), device = self .devices [0 ])
43- self .bias = (
44- Weight ("bias" , dtype , (dims ,), device = self .devices [0 ])
45- if use_bias
46- else None
47- )
44+ if elementwise_affine :
45+ self .weight = Weight (
46+ "weight" , dtype , (dims ,), device = self .devices [0 ]
47+ )
48+ self .bias = (
49+ Weight ("bias" , dtype , (dims ,), device = self .devices [0 ])
50+ if use_bias
51+ else None
52+ )
53+ else :
54+ self .weight = None
55+ self .bias = None
4856 self .eps = eps
4957 self .dim = dims
5058 self .dtype = dtype
59+ self .keep_dtype = keep_dtype
5160 self ._sharding_strategy : ShardingStrategy | None = None
5261
5362 def __call__ (self , input : TensorValue ):
5463 # TODO: AIPIPE-95 Replace with a broadcasting rmo.layer_norm
5564 bias = (
56- ops . cast ( self .bias , DType . float32 )
65+ self .bias
5766 if self .bias
5867 # If bias wasn't passed then use bias-less layer norm (beta = 0).
5968 else ops .broadcast_to (
60- ops .constant (0.0 , DType .float32 , self .weight .device ),
69+ ops .constant (0.0 , self .dtype , input .device ),
70+ shape = (input .shape [- 1 ],),
71+ )
72+ )
73+ gamma = (
74+ self .weight
75+ if self .weight
76+ else ops .broadcast_to (
77+ ops .constant (1.0 , self .dtype , input .device ),
6178 shape = (input .shape [- 1 ],),
6279 )
6380 )
64- return ops .layer_norm (
65- input .cast (DType .float32 ),
66- gamma = ops .cast (self .weight , DType .float32 ),
67- beta = bias ,
81+
82+ output = ops .layer_norm (
83+ input = input if self .keep_dtype else input .cast (DType .float32 ),
84+ gamma = gamma if self .keep_dtype else ops .cast (gamma , DType .float32 ),
85+ beta = bias if self .keep_dtype else ops .cast (bias , DType .float32 ),
6886 epsilon = self .eps ,
69- ).cast (input .dtype )
87+ )
88+ return output if self .keep_dtype else output .cast (input .dtype )
7089
7190 @property
7291 def sharding_strategy (self ) -> ShardingStrategy | None :
0 commit comments