Skip to content

Commit 82776bc

Browse files
authored
refactor(distributed): deduplicate TE module class lookups with caching (#2992)
- Extract common get_te_classes() with @lru_cache for reuse
- Refactor has_te_modules() and _is_te_module() to use tuple isinstance check
- Remove duplicated import lists across multiple functions

Signed-off-by: Muu <koimuu@163.com>
1 parent d95b34c commit 82776bc

1 file changed

Lines changed: 18 additions & 30 deletions

File tree

transformer_engine/pytorch/distributed.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -613,31 +613,39 @@ def get_activation_recompute_contexts():
613613
return forward_ctx, recompute_ctx
614614

615615

616-
def has_te_modules(network):
616+
@lru_cache
617+
def get_te_classes():
617618
"""
618-
Check if there are any Transformer Engine modules in the network.
619+
Return a tuple of all Transformer Engine module classes.
619620
"""
620621
from .module import LayerNorm, RMSNorm
621622
from .module.base import TransformerEngineBaseModule
623+
from .attention.dot_product_attention.dot_product_attention import (
624+
DotProductAttention,
625+
)
622626
from .attention.dot_product_attention.backends import UnfusedDotProductAttention
623-
from .attention.dot_product_attention.dot_product_attention import DotProductAttention
624627
from .attention.multi_head_attention import MultiheadAttention
625628
from .transformer import TransformerLayer
626629

627-
te_classes_list = [
630+
return (
628631
LayerNorm,
629632
RMSNorm,
630633
TransformerEngineBaseModule,
631634
UnfusedDotProductAttention,
632635
DotProductAttention,
633636
MultiheadAttention,
634637
TransformerLayer,
635-
]
638+
)
636639

640+
641+
def has_te_modules(network):
642+
"""
643+
Check if there are any Transformer Engine modules in the network.
644+
"""
645+
te_classes = get_te_classes()
637646
if isinstance(network, torch.nn.Module):
638-
for module in network.modules():
639-
if any(isinstance(module, te_class) for te_class in te_classes_list):
640-
return True
647+
if any(isinstance(module, te_classes) for module in network.modules()):
648+
return True
641649
return False
642650

643651
# Cannot check for TE modules inside a custom class/callable that's not a torch.nn.Module,
@@ -2040,28 +2048,8 @@ def _is_te_module(module):
20402048
Check if given module is a Transformer Engine module that requires the TE checkpoint
20412049
implementation for activation recompute.
20422050
"""
2043-
from .module import LayerNorm, RMSNorm
2044-
from .module.base import TransformerEngineBaseModule
2045-
from .attention.dot_product_attention.dot_product_attention import DotProductAttention
2046-
from .attention.dot_product_attention.backends import UnfusedDotProductAttention
2047-
from .attention.multi_head_attention import MultiheadAttention
2048-
from .transformer import TransformerLayer
2049-
2050-
te_classes_list = [
2051-
LayerNorm,
2052-
RMSNorm,
2053-
TransformerEngineBaseModule,
2054-
UnfusedDotProductAttention,
2055-
DotProductAttention,
2056-
MultiheadAttention,
2057-
TransformerLayer,
2058-
]
2059-
is_te_module = False
2060-
for te_class in te_classes_list:
2061-
if isinstance(module, te_class):
2062-
is_te_module = True
2063-
break
2064-
return is_te_module
2051+
te_classes = get_te_classes()
2052+
return isinstance(module, te_classes)
20652053

20662054

20672055
def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:

0 commit comments

Comments
 (0)