Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 19 additions & 27 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import (
from transformer_engine.pytorch.ops.fused.grouped_mlp import (
_cudnn_frontend_supports_grouped_gemm_srelu,
_cudnn_frontend_version_supported,
Comment on lines 22 to 25

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: The test imports module-private helpers directly from grouped_mlp. These names (_cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu) carry a leading underscore, signalling that they are internal implementation details. Importing them from tests tightly couples the test suite to internal symbol names; a later rename or move would break the import with no deprecation period. Consider exposing a small public helper, or adding a # noqa: PLC2701 comment with a brief rationale.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For better or for worse, these tests are tightly coupled to the implementation of the grouped MLP fused op, and changing that coupling is beyond the scope of this PR.

is_glu_activation,
Expand Down Expand Up @@ -4028,25 +4028,19 @@ def _make_module():
)
if expected_grouped_mlp_fusion:
if activation_is_glu:
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU
fused_cls = te_ops.fused.GroupedMLP_CuTeGEMMGLU
else:
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary
if forward_cls.is_supported():
fused_cls = te_ops.fused.GroupedMLP_CuTeGEMMUnary
if fused_cls.is_supported():
forward_ops = module._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
forward_cls,
)
if backward_cls is not None and backward_cls.is_supported():
backward_ops = module._module_groups[0]._backward_ops
assert len(forward_ops) == 1
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
backward_cls,
forward_ops[0][0],
fused_cls,
)
assert backward_ops[0][0] is forward_ops[0][0]

# Loose tols for sanity checking
tols = {"rtol": 0.125, "atol": 0.25}
Expand Down Expand Up @@ -4130,10 +4124,8 @@ def test_grouped_mlp_single_weight_numerics(
) -> None:
"""single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP."""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported():
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported():
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")
if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported():
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")

split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
random.shuffle(split_sizes)
Expand Down Expand Up @@ -4234,13 +4226,14 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]:
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU,
te_ops.fused.GroupedMLP_CuTeGEMMGLU,
)
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU,
te_ops.fused.GroupedMLP_CuTeGEMMGLU,
)
assert backward_ops[0][0] is forward_ops[0][0]

if single_grouped_weight:
fc1_dw = fc1.weight.grad.detach().clone()
Expand Down Expand Up @@ -4352,10 +4345,8 @@ def test_grouped_mlp_overwrite_main_grad(
that read ``.grad`` don't see stale bytes from the cached dummy).
"""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported():
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported():
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")
if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported():
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")

recipe = make_recipe("mxfp8")
split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
Expand Down Expand Up @@ -4486,7 +4477,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8(
) -> None:
"""Grouped MLP forward+backward should be CUDA graph capturable (MXFP8)."""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported():
if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported():
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")
if dtype not in (torch.bfloat16, torch.float16):
pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16")
Expand Down Expand Up @@ -4628,13 +4619,14 @@ def train_step(
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU,
te_ops.fused.GroupedMLP_CuTeGEMMGLU,
)
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU,
te_ops.fused.GroupedMLP_CuTeGEMMGLU,
)
assert backward_ops[0][0] is forward_ops[0][0]

fresh_x = torch.randn_like(static_x)
fresh_probs = torch.randn_like(static_probs)
Expand Down
Loading
Loading