@@ -613,31 +613,39 @@ def get_activation_recompute_contexts():
613613 return forward_ctx , recompute_ctx
614614
615615
616- def has_te_modules (network ):
616+ @lru_cache
617+ def get_te_classes ():
617618 """
618- Check if there are any Transformer Engine modules in the network .
619+ Return all Transformer Engine modules.
619620 """
620621 from .module import LayerNorm , RMSNorm
621622 from .module .base import TransformerEngineBaseModule
623+ from .attention .dot_product_attention .dot_product_attention import (
624+ DotProductAttention ,
625+ )
622626 from .attention .dot_product_attention .backends import UnfusedDotProductAttention
623- from .attention .dot_product_attention .dot_product_attention import DotProductAttention
624627 from .attention .multi_head_attention import MultiheadAttention
625628 from .transformer import TransformerLayer
626629
627- te_classes_list = [
630+ return (
628631 LayerNorm ,
629632 RMSNorm ,
630633 TransformerEngineBaseModule ,
631634 UnfusedDotProductAttention ,
632635 DotProductAttention ,
633636 MultiheadAttention ,
634637 TransformerLayer ,
635- ]
638+ )
636639
640+
641+ def has_te_modules (network ):
642+ """
643+ Check if there are any Transformer Engine modules in the network.
644+ """
645+ te_classes = get_te_classes ()
637646 if isinstance (network , torch .nn .Module ):
638- for module in network .modules ():
639- if any (isinstance (module , te_class ) for te_class in te_classes_list ):
640- return True
647+ if any (isinstance (module , te_classes ) for module in network .modules ()):
648+ return True
641649 return False
642650
643651 # Cannot check for TE modules inside a custom class/callable that's not a torch.nn.Module,
@@ -2040,28 +2048,8 @@ def _is_te_module(module):
20402048 Check if given module is a Transformer Engine module that requires the TE checkpoint
20412049 implementation for activation recompute.
20422050 """
2043- from .module import LayerNorm , RMSNorm
2044- from .module .base import TransformerEngineBaseModule
2045- from .attention .dot_product_attention .dot_product_attention import DotProductAttention
2046- from .attention .dot_product_attention .backends import UnfusedDotProductAttention
2047- from .attention .multi_head_attention import MultiheadAttention
2048- from .transformer import TransformerLayer
2049-
2050- te_classes_list = [
2051- LayerNorm ,
2052- RMSNorm ,
2053- TransformerEngineBaseModule ,
2054- UnfusedDotProductAttention ,
2055- DotProductAttention ,
2056- MultiheadAttention ,
2057- TransformerLayer ,
2058- ]
2059- is_te_module = False
2060- for te_class in te_classes_list :
2061- if isinstance (module , te_class ):
2062- is_te_module = True
2063- break
2064- return is_te_module
2051+ te_classes = get_te_classes ()
2052+ return isinstance (module , te_classes )
20652053
20662054
20672055def prepare_te_modules_for_fsdp (fsdp_root : torch .nn .Module ) -> None :
0 commit comments