@@ -3578,7 +3578,14 @@ def test_layernorm_mlp(
35783578 @pytest .mark .parametrize ("glu_interleave_size" , (None , 32 ))
35793579 @pytest .mark .parametrize ("delay_wgrad_compute" , (False , True ))
35803580 @pytest .mark .parametrize ("hidden_size" , (128 , 256 ))
3581- @pytest .mark .parametrize ("activation" , ("scaled_swiglu" , "scaled_clamped_qgeglu" ))
3581+ @pytest .mark .parametrize (
3582+ "activation" ,
3583+ (
3584+ "scaled_swiglu" ,
3585+ "scaled_clamped_qgeglu" ,
3586+ "scaled_clamped_qgeglu_custom" ,
3587+ ),
3588+ )
35823589 def test_grouped_mlp (
35833590 self ,
35843591 * ,
@@ -3623,10 +3630,20 @@ def test_grouped_mlp(
36233630 pytest .skip ("single_grouped_bias requires bias=True" )
36243631 if with_quantization and dtype not in (torch .bfloat16 , torch .float16 ):
36253632 pytest .skip ("Quantized group GEMM is only supported with BF16/FP16" )
3626- if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias :
3633+ if quantization == "nvfp4" and activation . startswith ( "scaled_clamped_qgeglu" ) and bias :
36273634 # TODO: ksivaman: Need to debug numerics for this case.
36283635 pytest .skip ("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU" )
36293636
3637+ # Activation parameters for clamped QGeGLU variants
3638+ if activation == "scaled_clamped_qgeglu_custom" :
3639+ geglu_limit = 5.0
3640+ geglu_alpha = 1.5
3641+ geglu_offset = 0.5
3642+ else :
3643+ geglu_limit = 7.0
3644+ geglu_alpha = 1.702
3645+ geglu_offset = 1.0
3646+
36303647 # Random data
36313648 x_ref , x_test = make_reference_and_test_tensors (
36323649 in_shape ,
@@ -3717,11 +3734,10 @@ def test_grouped_mlp(
37173734 if activation == "scaled_swiglu" :
37183735 x = torch .nn .functional .silu (x1 ) * x2
37193736 else :
3720- lim = torch .tensor (7.0 , device = x1 .device , dtype = x1 .dtype )
3721- geglu_alpha = 1.702
3737+ lim = torch .tensor (geglu_limit , device = x1 .device , dtype = x1 .dtype )
37223738 x1c = torch .minimum (x1 , lim )
37233739 x2c = torch .clamp (x2 , - lim , lim )
3724- x = (x2c + 1 ) * (x1c * torch .sigmoid (geglu_alpha * x1c ))
3740+ x = (x2c + geglu_offset ) * (x1c * torch .sigmoid (geglu_alpha * x1c ))
37253741 x = x * probs [group_idx ].unsqueeze (- 1 )
37263742 x = torch .nn .functional .linear (x , fc2_ws_ref [group_idx ])
37273743 if bias :
@@ -3732,11 +3748,15 @@ def test_grouped_mlp(
37323748
37333749 # Construct operations
37343750 recipe = make_recipe (quantization )
3735- scaled_act = (
3736- te_ops .ScaledSwiGLU (glu_interleave_size = glu_interleave_size )
3737- if activation == "scaled_swiglu"
3738- else te_ops .ScaledClampedQGeGLU (glu_interleave_size = glu_interleave_size )
3739- )
3751+ if activation == "scaled_swiglu" :
3752+ scaled_act = te_ops .ScaledSwiGLU (glu_interleave_size = glu_interleave_size )
3753+ else :
3754+ scaled_act = te_ops .ScaledClampedQGeGLU (
3755+ glu_interleave_size = glu_interleave_size ,
3756+ limit = geglu_limit ,
3757+ alpha = geglu_alpha ,
3758+ glu_linear_offset = geglu_offset ,
3759+ )
37403760 with te .quantized_model_init (enabled = with_quantization , recipe = recipe ):
37413761 fc1 = te_ops .GroupedLinear (
37423762 group_size ,
0 commit comments