Skip to content

Commit 257253a

Browse files
committed
update the fusion path
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
1 parent 1e9cdac commit 257253a

4 files changed

Lines changed: 96 additions & 25 deletions

File tree

tests/pytorch/test_fusible_ops.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3578,7 +3578,14 @@ def test_layernorm_mlp(
35783578
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
35793579
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
35803580
@pytest.mark.parametrize("hidden_size", (128, 256))
3581-
@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu"))
3581+
@pytest.mark.parametrize(
3582+
"activation",
3583+
(
3584+
"scaled_swiglu",
3585+
"scaled_clamped_qgeglu",
3586+
"scaled_clamped_qgeglu_custom",
3587+
),
3588+
)
35823589
def test_grouped_mlp(
35833590
self,
35843591
*,
@@ -3623,10 +3630,20 @@ def test_grouped_mlp(
36233630
pytest.skip("single_grouped_bias requires bias=True")
36243631
if with_quantization and dtype not in (torch.bfloat16, torch.float16):
36253632
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
3626-
if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias:
3633+
if quantization == "nvfp4" and activation.startswith("scaled_clamped_qgeglu") and bias:
36273634
# TODO: ksivaman: Need to debug numerics for this case.
36283635
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
36293636

3637+
# Activation parameters for clamped QGeGLU variants
3638+
if activation == "scaled_clamped_qgeglu_custom":
3639+
geglu_limit = 5.0
3640+
geglu_alpha = 1.5
3641+
geglu_offset = 0.5
3642+
else:
3643+
geglu_limit = 7.0
3644+
geglu_alpha = 1.702
3645+
geglu_offset = 1.0
3646+
36303647
# Random data
36313648
x_ref, x_test = make_reference_and_test_tensors(
36323649
in_shape,
@@ -3717,11 +3734,10 @@ def test_grouped_mlp(
37173734
if activation == "scaled_swiglu":
37183735
x = torch.nn.functional.silu(x1) * x2
37193736
else:
3720-
lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype)
3721-
geglu_alpha = 1.702
3737+
lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype)
37223738
x1c = torch.minimum(x1, lim)
37233739
x2c = torch.clamp(x2, -lim, lim)
3724-
x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c))
3740+
x = (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c))
37253741
x = x * probs[group_idx].unsqueeze(-1)
37263742
x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx])
37273743
if bias:
@@ -3732,11 +3748,15 @@ def test_grouped_mlp(
37323748

37333749
# Construct operations
37343750
recipe = make_recipe(quantization)
3735-
scaled_act = (
3736-
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
3737-
if activation == "scaled_swiglu"
3738-
else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
3739-
)
3751+
if activation == "scaled_swiglu":
3752+
scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
3753+
else:
3754+
scaled_act = te_ops.ScaledClampedQGeGLU(
3755+
glu_interleave_size=glu_interleave_size,
3756+
limit=geglu_limit,
3757+
alpha=geglu_alpha,
3758+
glu_linear_offset=geglu_offset,
3759+
)
37403760
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
37413761
fc1 = te_ops.GroupedLinear(
37423762
group_size,

transformer_engine/pytorch/ops/_common.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,34 @@
2121
from ..utils import canonicalize_dtype
2222

2323

24+
@functools.lru_cache(maxsize=1)
25+
def _cudnn_frontend_version() -> Optional[PkgVersion]:
26+
"""Return the installed cuDNN-frontend version, or ``None``."""
27+
try:
28+
return PkgVersion(get_pkg_version("nvidia-cudnn-frontend"))
29+
except PackageNotFoundError:
30+
return None
31+
32+
2433
@functools.lru_cache(maxsize=1)
2534
def _cudnn_frontend_version_supported() -> bool:
2635
"""Check cuDNN frontend is at least 1.23.0.
2736
28-
All grouped MLP fused-kernel features require cuDNN frontend 1.23.0.
37+
All grouped MLP fused-kernel features require cuDNN frontend >= 1.23.0.
2938
"""
30-
try:
31-
return PkgVersion(get_pkg_version("nvidia-cudnn-frontend")) >= PkgVersion("1.23.0")
32-
except PackageNotFoundError:
33-
return False
39+
ver = _cudnn_frontend_version()
40+
return ver is not None and ver >= PkgVersion("1.23.0")
41+
42+
43+
@functools.lru_cache(maxsize=1)
44+
def _cudnn_frontend_geglu_runtime_params() -> bool:
45+
"""Check cuDNN frontend is at least 1.24.0.
46+
47+
Runtime-configurable GeGLU parameters (linear_offset, geglu_alpha,
48+
glu_clamp_max, glu_clamp_min) require cuDNN frontend >= 1.24.0.
49+
"""
50+
ver = _cudnn_frontend_version()
51+
return ver is not None and ver >= PkgVersion("1.24.0")
3452

3553

3654
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool:
@@ -256,9 +274,14 @@ def fuse_grouped_mlp_ops(
256274
and isinstance(window[2], GroupedLinear)
257275
):
258276
matches_pattern = False
259-
elif isinstance(window[1], ScaledClampedQGeGLU) and (
260-
abs(window[1]._clamped.alpha - 1.702) > 0.001
261-
or abs(window[1]._clamped.glu_linear_offset - 1.0) > 0.001
277+
elif (
278+
isinstance(window[1], ScaledClampedQGeGLU)
279+
and not _cudnn_frontend_geglu_runtime_params()
280+
and (
281+
abs(window[1]._clamped.alpha - 1.702) > 0.001
282+
or abs(window[1]._clamped.glu_linear_offset - 1.0) > 0.001
283+
or abs(window[1]._clamped.limit - 7.0) > 0.001
284+
)
262285
):
263286
matches_pattern = False
264287
elif window[0].num_groups != window[2].num_groups:

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..fuser import register_backward_fusion
2323
from ..op import FusedOperation, FusibleOperation, OperationContext
2424
from .._common import (
25+
_cudnn_frontend_geglu_runtime_params,
2526
_cudnn_frontend_version_supported,
2627
fuse_grouped_mlp_ops,
2728
get_accumulate_flag_in_param,
@@ -314,11 +315,16 @@ def __init__(
314315
self.grouped_gemm_dglu_kernel() # Try triggering import error
315316
raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.")
316317
validate_grouped_mlp_dims(fc1, swiglu, fc2)
317-
# The cuDNN dgeglu implementation corresponds to ScaledClampedQGeGLU.
318-
# The act_func string should be fixed on the cuDNN FE side.
319-
self._cudnn_dact_func: str = (
320-
"dgeglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "dswiglu"
318+
is_geglu = isinstance(swiglu, ScaledClampedQGeGLU)
319+
self._cudnn_dact_func: str = "dgeglu" if is_geglu else "dswiglu"
320+
self._pass_geglu_runtime_params: bool = (
321+
is_geglu and _cudnn_frontend_geglu_runtime_params()
321322
)
323+
if self._pass_geglu_runtime_params:
324+
self._cudnn_linear_offset: float = swiglu._clamped.glu_linear_offset
325+
self._cudnn_geglu_alpha: float = swiglu._clamped.alpha
326+
self._cudnn_glu_clamp_max: float = swiglu._clamped.limit
327+
self._cudnn_glu_clamp_min: float = -swiglu._clamped.limit
322328

323329
def fuser_backward(
324330
self,
@@ -472,6 +478,13 @@ def fuser_backward(
472478
"act_func": self._cudnn_dact_func,
473479
"use_dynamic_sched": True,
474480
}
481+
if self._pass_geglu_runtime_params:
482+
fc2_dglu_kwargs.update(
483+
linear_offset=self._cudnn_linear_offset,
484+
geglu_alpha=self._cudnn_geglu_alpha,
485+
glu_clamp_max=self._cudnn_glu_clamp_max,
486+
glu_clamp_min=self._cudnn_glu_clamp_min,
487+
)
475488

476489
if fc2_op.single_grouped_weight:
477490
# Clone and swizzle scales for GEMM

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ..fuser import register_forward_fusion
2424
from ..op import FusedOperation, FusibleOperation, OperationContext
2525
from .._common import (
26+
_cudnn_frontend_geglu_runtime_params,
2627
_cudnn_frontend_version_supported,
2728
fuse_grouped_mlp_ops,
2829
is_quantized_tensor,
@@ -97,9 +98,16 @@ def __init__(
9798
self.grouped_gemm_glu_kernel() # Try triggering import error
9899
raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.")
99100
validate_grouped_mlp_dims(fc1, swiglu, fc2)
100-
# The cuDNN geglu implementation corresponds to ScaledClampedQGeGLU.
101-
# The act_func string should be fixed on the cuDNN FE side.
102-
self._cudnn_act_func: str = "geglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "swiglu"
101+
is_geglu = isinstance(swiglu, ScaledClampedQGeGLU)
102+
self._cudnn_act_func: str = "geglu" if is_geglu else "swiglu"
103+
self._pass_geglu_runtime_params: bool = (
104+
is_geglu and _cudnn_frontend_geglu_runtime_params()
105+
)
106+
if self._pass_geglu_runtime_params:
107+
self._cudnn_linear_offset: float = swiglu._clamped.glu_linear_offset
108+
self._cudnn_geglu_alpha: float = swiglu._clamped.alpha
109+
self._cudnn_glu_clamp_max: float = swiglu._clamped.limit
110+
self._cudnn_glu_clamp_min: float = -swiglu._clamped.limit
103111

104112
def fuser_forward(
105113
self,
@@ -305,6 +313,13 @@ def fuser_forward(
305313
"act_func": self._cudnn_act_func,
306314
"use_dynamic_sched": True,
307315
}
316+
if self._pass_geglu_runtime_params:
317+
fc1_glu_kwargs.update(
318+
linear_offset=self._cudnn_linear_offset,
319+
geglu_alpha=self._cudnn_geglu_alpha,
320+
glu_clamp_max=self._cudnn_glu_clamp_max,
321+
glu_clamp_min=self._cudnn_glu_clamp_min,
322+
)
308323

309324
if fc1_op.single_grouped_weight:
310325
# Clone and swizzle scales for GEMM.

0 commit comments

Comments
 (0)