@@ -1795,6 +1795,7 @@ def test_interleaved_swiglu(self):
17951795 @pytest .mark .parametrize ("quantization" , _quantization_list )
17961796 @pytest .mark .parametrize ("quantize_forward" , (False , True ))
17971797 @pytest .mark .parametrize ("quantize_backward" , (False , True ))
1798+ @pytest .mark .parametrize ("glu_linear_offset" , (1.0 , 0.0 ))
17981799 def test_clamped_swiglu (
17991800 self ,
18001801 * ,
@@ -1805,6 +1806,7 @@ def test_clamped_swiglu(
18051806 quantization : Optional [str ],
18061807 quantize_forward : bool ,
18071808 quantize_backward : bool ,
1809+ glu_linear_offset : float ,
18081810 limit : float = 0.75 ,
18091811 alpha : float = 1.702 ,
18101812 ):
@@ -1847,7 +1849,7 @@ def test_clamped_swiglu(
18471849 x_glu = x_glu .clamp (min = None , max = limit )
18481850 x_linear = x_linear .clamp (min = - limit , max = limit )
18491851 out_glu = x_glu * torch .sigmoid (alpha * x_glu )
1850- y_ref = out_glu * (x_linear + 1 )
1852+ y_ref = out_glu * (x_linear + glu_linear_offset )
18511853 y_ref .backward (dy_ref )
18521854
18531855 # Implementation with fusible operation
@@ -1858,6 +1860,7 @@ def test_clamped_swiglu(
18581860 te_ops .ClampedSwiGLU (
18591861 limit = limit ,
18601862 alpha = alpha ,
1863+ glu_linear_offset = glu_linear_offset ,
18611864 glu_interleave_size = glu_interleave_size ,
18621865 ),
18631866 te_ops .Quantize (forward = quantize_forward , backward = False ),
@@ -2240,6 +2243,7 @@ def test_interleaved_scaled_swiglu(self):
22402243 @pytest .mark .parametrize ("in_shape" , ((71 , 192 ), (5 , 7 , 128 )))
22412244 @pytest .mark .parametrize ("input_requires_grad" , (False , True ))
22422245 @pytest .mark .parametrize ("scales_requires_grad" , (False , True ))
2246+ @pytest .mark .parametrize ("glu_linear_offset" , (1.0 , 0.0 ))
22432247 def test_scaled_clamped_qgeglu (
22442248 self ,
22452249 * ,
@@ -2249,6 +2253,7 @@ def test_scaled_clamped_qgeglu(
22492253 device : torch .device = "cuda" ,
22502254 input_requires_grad : bool ,
22512255 scales_requires_grad : bool ,
2256+ glu_linear_offset : float ,
22522257 limit : float = 7.0 ,
22532258 alpha : float = 1.702 ,
22542259 ) -> None :
@@ -2293,7 +2298,7 @@ def test_scaled_clamped_qgeglu(
22932298 x_glu = x_glu .clamp (min = None , max = limit )
22942299 x_linear = x_linear .clamp (min = - limit , max = limit )
22952300 out_glu = x_glu * torch .sigmoid (alpha * x_glu )
2296- y = out_glu * (x_linear + 1 )
2301+ y = out_glu * (x_linear + glu_linear_offset )
22972302 y_ref = scales_ref .unsqueeze (- 1 ) * y
22982303 if input_requires_grad or scales_requires_grad :
22992304 y_ref .backward (dy_ref )
@@ -2302,6 +2307,7 @@ def test_scaled_clamped_qgeglu(
23022307 glu_interleave_size = glu_interleave_size ,
23032308 limit = limit ,
23042309 alpha = alpha ,
2310+ glu_linear_offset = glu_linear_offset ,
23052311 )
23062312 y_test = op (x_test , scales_test )
23072313 if input_requires_grad or scales_requires_grad :
0 commit comments