From 897fe44cb1261cf24bcea25875daacdd72d01a23 Mon Sep 17 00:00:00 2001 From: Saatvik Arya Date: Thu, 10 Jul 2025 16:53:41 +0530 Subject: [PATCH] implement lazy imports for BetterTransformer to handle import errors gracefully --- .../infinity_emb/transformer/acceleration.py | 39 ++++++++++++++++--- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/transformer/acceleration.py b/libs/infinity_emb/infinity_emb/transformer/acceleration.py index 1d7b7c7f3..5b2f23129 100644 --- a/libs/infinity_emb/infinity_emb/transformer/acceleration.py +++ b/libs/infinity_emb/infinity_emb/transformer/acceleration.py @@ -7,11 +7,24 @@ from infinity_emb._optional_imports import CHECK_OPTIMUM, CHECK_TORCH, CHECK_TRANSFORMERS from infinity_emb.primitives import Device -if CHECK_OPTIMUM.is_available: - from optimum.bettertransformer import ( # type: ignore[import-untyped] - BetterTransformer, - BetterTransformerManager, - ) +# lazy imports to avoid issues with deprecated BetterTransformer +BetterTransformer = None +BetterTransformerManager = None + +def _import_bettertransformer(): + """Lazy import BetterTransformer to avoid import errors when it's not needed.""" + global BetterTransformer, BetterTransformerManager + if BetterTransformer is None and CHECK_OPTIMUM.is_available: + try: + from optimum.bettertransformer import ( # type: ignore[import-untyped] + BetterTransformer as _BetterTransformer, + BetterTransformerManager as _BetterTransformerManager, + ) + BetterTransformer = _BetterTransformer + BetterTransformerManager = _BetterTransformerManager + except Exception: + # If import fails, keep them as None + pass if CHECK_TORCH.is_available: import torch @@ -37,6 +50,11 @@ def check_if_bettertransformer_possible(engine_args: "EngineArgs") -> bool: if not engine_args.bettertransformer: return False + _import_bettertransformer() + + if BetterTransformerManager is None: + return False + config = AutoConfig.from_pretrained( pretrained_model_name_or_path=engine_args.model_name_or_path, revision=engine_args.revision, @@ -65,6 +83,15 @@ def to_bettertransformer(model: "PreTrainedModel", engine_args: "EngineArgs", lo "INFINITY_DISABLE_OPTIMUM is no longer supported, please use the CLI / ENV for that." ) + _import_bettertransformer() + + if BetterTransformer is None: + logger.warning( + "BetterTransformer is not available (likely due to transformers version incompatibility). " + "Continue without bettertransformer modeling code." + ) + return model + if ( hasattr(model.config, "_attn_implementation") and model.config._attn_implementation != "eager" @@ -80,7 +107,7 @@ def to_bettertransformer(model: "PreTrainedModel", engine_args: "EngineArgs", lo "Since torch 2.5.0, this combination leads to a segfault. Please report if you find this check to be incorrect." ) try: - model = BetterTransformer.transform(model) + model = BetterTransformer.transform(model) # type: ignore except Exception as ex: # if level is debug then show the exception if logger.level <= 10: