@@ -1846,6 +1846,7 @@ def test_interleaved_swiglu(self):
18461846 @pytest .mark .parametrize ("quantization" , _quantization_list )
18471847 @pytest .mark .parametrize ("quantize_forward" , (False , True ))
18481848 @pytest .mark .parametrize ("quantize_backward" , (False , True ))
1849+ @pytest .mark .parametrize ("glu_linear_offset" , (1.0 , 0.0 ))
18491850 def test_clamped_swiglu (
18501851 self ,
18511852 * ,
@@ -1856,6 +1857,7 @@ def test_clamped_swiglu(
18561857 quantization : Optional [str ],
18571858 quantize_forward : bool ,
18581859 quantize_backward : bool ,
1860+ glu_linear_offset : float ,
18591861 limit : float = 0.75 ,
18601862 alpha : float = 1.702 ,
18611863 ):
@@ -1898,7 +1900,7 @@ def test_clamped_swiglu(
18981900 x_glu = x_glu .clamp (min = None , max = limit )
18991901 x_linear = x_linear .clamp (min = - limit , max = limit )
19001902 out_glu = x_glu * torch .sigmoid (alpha * x_glu )
1901- y_ref = out_glu * (x_linear + 1 )
1903+ y_ref = out_glu * (x_linear + glu_linear_offset )
19021904 y_ref .backward (dy_ref )
19031905
19041906 # Implementation with fusible operation
@@ -1909,6 +1911,7 @@ def test_clamped_swiglu(
19091911 te_ops .ClampedSwiGLU (
19101912 limit = limit ,
19111913 alpha = alpha ,
1914+ glu_linear_offset = glu_linear_offset ,
19121915 glu_interleave_size = glu_interleave_size ,
19131916 ),
19141917 te_ops .Quantize (forward = quantize_forward , backward = False ),
@@ -1938,6 +1941,7 @@ def test_interleaved_clamped_swiglu(self):
19381941 quantize_forward = False ,
19391942 quantize_backward = False ,
19401943 glu_interleave_size = 32 ,
1944+ glu_linear_offset = 1.0 ,
19411945 )
19421946
19431947 @pytest .mark .parametrize ("scale" , (1 , 0 , - 2.5 , 3.5 ))
@@ -2594,6 +2598,7 @@ def test_scaled_activation_recompute_in_mlp_config(self, op_cls) -> None:
25942598 @pytest .mark .parametrize ("in_shape" , ((71 , 192 ), (5 , 7 , 128 )))
25952599 @pytest .mark .parametrize ("input_requires_grad" , (False , True ))
25962600 @pytest .mark .parametrize ("scales_requires_grad" , (False , True ))
2601+ @pytest .mark .parametrize ("glu_linear_offset" , (1.0 , 0.0 ))
25972602 def test_scaled_clamped_qgeglu (
25982603 self ,
25992604 * ,
@@ -2603,6 +2608,7 @@ def test_scaled_clamped_qgeglu(
26032608 device : torch .device = "cuda" ,
26042609 input_requires_grad : bool ,
26052610 scales_requires_grad : bool ,
2611+ glu_linear_offset : float ,
26062612 limit : float = 7.0 ,
26072613 alpha : float = 1.702 ,
26082614 ) -> None :
@@ -2647,7 +2653,7 @@ def test_scaled_clamped_qgeglu(
26472653 x_glu = x_glu .clamp (min = None , max = limit )
26482654 x_linear = x_linear .clamp (min = - limit , max = limit )
26492655 out_glu = x_glu * torch .sigmoid (alpha * x_glu )
2650- y = out_glu * (x_linear + 1 )
2656+ y = out_glu * (x_linear + glu_linear_offset )
26512657 y_ref = scales_ref .unsqueeze (- 1 ) * y
26522658 if input_requires_grad or scales_requires_grad :
26532659 y_ref .backward (dy_ref )
@@ -2656,6 +2662,7 @@ def test_scaled_clamped_qgeglu(
26562662 glu_interleave_size = glu_interleave_size ,
26572663 limit = limit ,
26582664 alpha = alpha ,
2665+ glu_linear_offset = glu_linear_offset ,
26592666 )
26602667 y_test = op (x_test , scales_test )
26612668 if input_requires_grad or scales_requires_grad :
@@ -2674,6 +2681,7 @@ def test_interleaved_scaled_clamped_qgeglu(self):
26742681 glu_interleave_size = 32 ,
26752682 input_requires_grad = True ,
26762683 scales_requires_grad = True ,
2684+ glu_linear_offset = 1.0 ,
26772685 )
26782686
26792687
@@ -3685,7 +3693,13 @@ def test_layernorm_mlp(
36853693 @pytest .mark .parametrize ("delay_wgrad_compute" , (False , True ))
36863694 @pytest .mark .parametrize ("hidden_size" , (128 , 256 ))
36873695 @pytest .mark .parametrize (
3688- "activation" , ("scaled_swiglu" , "scaled_clamped_qgeglu" , "scaled_srelu" )
3696+ "activation" ,
3697+ (
3698+ "scaled_swiglu" ,
3699+ "scaled_clamped_qgeglu" ,
3700+ "scaled_clamped_qgeglu_custom" ,
3701+ "scaled_srelu" ,
3702+ ),
36893703 )
36903704 def test_grouped_mlp (
36913705 self ,
@@ -3719,7 +3733,7 @@ def test_grouped_mlp(
37193733 with_quantization = quantization is not None
37203734 if activation == "scaled_swiglu" :
37213735 scaled_act = te_ops .ScaledSwiGLU (glu_interleave_size = glu_interleave_size )
3722- elif activation == "scaled_clamped_qgeglu" :
3736+ elif activation . startswith ( "scaled_clamped_qgeglu" ) :
37233737 scaled_act = te_ops .ScaledClampedQGeGLU (glu_interleave_size = glu_interleave_size )
37243738 elif activation == "scaled_srelu" :
37253739 scaled_act = te_ops .ScaledSReLU ()
@@ -3742,13 +3756,23 @@ def test_grouped_mlp(
37423756 if (
37433757 with_quantization
37443758 and quantization in ("nvfp4" , "nvfp4_row_scaled" , "nvfp4_4over6" )
3745- and activation == "scaled_clamped_qgeglu"
3759+ and activation . startswith ( "scaled_clamped_qgeglu" )
37463760 and bias
37473761 ):
37483762 # TODO: ksivaman: Need to debug numerics for this case.
37493763 pytest .skip ("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU" )
37503764 fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size
37513765
3766+ # Activation parameters for clamped QGeGLU variants
3767+ if activation == "scaled_clamped_qgeglu_custom" :
3768+ geglu_limit = 5.0
3769+ geglu_alpha = 1.5
3770+ geglu_offset = 0.5
3771+ else :
3772+ geglu_limit = 7.0
3773+ geglu_alpha = 1.702
3774+ geglu_offset = 1.0
3775+
37523776 # Random data
37533777 x_ref , x_test = make_reference_and_test_tensors (
37543778 in_shape ,
@@ -3840,13 +3864,12 @@ def test_grouped_mlp(
38403864 if activation == "scaled_swiglu" :
38413865 x1 , x2 = x .chunk (2 , dim = - 1 )
38423866 x = torch .nn .functional .silu (x1 ) * x2
3843- elif activation == "scaled_clamped_qgeglu" :
3867+ elif activation . startswith ( "scaled_clamped_qgeglu" ) :
38443868 x1 , x2 = x .chunk (2 , dim = - 1 )
3845- lim = torch .tensor (7.0 , device = x1 .device , dtype = x1 .dtype )
3846- geglu_alpha = 1.702
3869+ lim = torch .tensor (geglu_limit , device = x1 .device , dtype = x1 .dtype )
38473870 x1c = torch .minimum (x1 , lim )
38483871 x2c = torch .clamp (x2 , - lim , lim )
3849- x = (x2c + 1 ) * (x1c * torch .sigmoid (geglu_alpha * x1c ))
3872+ x = (x2c + geglu_offset ) * (x1c * torch .sigmoid (geglu_alpha * x1c ))
38503873 elif activation == "scaled_srelu" :
38513874 x = torch .nn .functional .relu (x ).square ()
38523875 else :
@@ -3861,6 +3884,13 @@ def test_grouped_mlp(
38613884
38623885 # Construct operations
38633886 recipe = make_recipe (quantization )
3887+ if activation == "scaled_clamped_qgeglu_custom" :
3888+ scaled_act = te_ops .ScaledClampedQGeGLU (
3889+ glu_interleave_size = glu_interleave_size ,
3890+ limit = geglu_limit ,
3891+ alpha = geglu_alpha ,
3892+ glu_linear_offset = geglu_offset ,
3893+ )
38643894 with te .quantized_model_init (enabled = with_quantization , recipe = recipe ):
38653895 fc1 = te_ops .GroupedLinear (
38663896 group_size ,
0 commit comments