class LazyModule:
    """
    Proxy that postpones importing a module until an attribute is requested.

    Attributes:
        _module_name (str): The full name of the module to import
        _extra_group (str): The extra dependency group required for this module
        _module (ModuleType): The actual imported module (None until first access)
        _submodules (dict): Cache for LazyModule instances of submodules

    Example:
        >>> ray = LazyModule('ray', 'speculative')
    """

    # Names stored on the proxy itself. ``__getattr__`` only runs when normal
    # lookup fails, so being asked for one of these means the instance was
    # created without ``__init__`` (copy/pickle do that); delegating would
    # re-enter ``__getattr__`` forever.
    _OWN_ATTRIBUTES = frozenset(
        {"_module_name", "_extra_group", "_module", "_submodules"}
    )

    def __init__(self, module_name: str, extra_group: str = None):
        """
        Record what to import; nothing is imported here.

        Args:
            module_name: Full dotted name of the module to import lazily
            extra_group: Name of the ``angelslim`` extra that provides the
                module; used only to build the installation hint
        """
        self._module_name = module_name
        self._extra_group = extra_group
        self._module = None
        self._submodules = {}

    def _import_module(self):
        """
        Import the target module if not already imported.

        Raises:
            ImportError: If the module cannot be imported. When an
                extra_group was given, the message contains the
                ``pip install`` command for that extra.
        """
        if self._module is not None:
            return
        try:
            self._module = importlib.import_module(self._module_name)
        except ImportError as e:
            if self._extra_group:
                raise ImportError(
                    f"Module '{self._module_name}' requires "
                    f"additional dependencies. Please install: "
                    f"pip install 'angelslim[{self._extra_group}]'"
                ) from e
            raise

    def __getattr__(self, name: str) -> Any:
        """
        Delegate attribute access to the actual module.

        On first access, this method imports the target module and then
        delegates the attribute lookup to the actual module. For submodules,
        it returns a LazyModule instance to support multi-level access.

        Args:
            name: Name of the attribute to access

        Returns:
            The requested attribute from the target module, or a LazyModule
            instance for submodules

        Raises:
            ImportError: If the module cannot be imported and an extra_group
                is specified, provides installation instructions
            AttributeError: If the module has no such attribute and no
                importable submodule of that name
        """
        # Local import: ModuleType is not otherwise needed by this module.
        from types import ModuleType

        if name in self._OWN_ATTRIBUTES:
            # Half-constructed proxy (see _OWN_ATTRIBUTES): fail fast instead
            # of recursing through self._submodules / self._module.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )

        # Return cached submodule if exists
        cached = self._submodules.get(name)
        if cached is not None:
            return cached

        self._import_module()

        # Keep the try body to the single lookup that is expected to fail.
        try:
            attr = getattr(self._module, name)
        except AttributeError:
            # Not bound on the parent yet: it may be a submodule that has not
            # been imported so far.
            dotted_name = f"{self._module_name}.{name}"
            try:
                attr = importlib.import_module(dotted_name)
            except ImportError as import_error:
                # Same exception type as a plain attribute miss, but keep the
                # import failure attached so the real cause stays visible.
                raise AttributeError(
                    f"module '{self._module_name}' has no attribute '{name}'"
                ) from import_error
            # Also cache in parent module for consistency
            setattr(self._module, name, attr)

        if not isinstance(attr, ModuleType):
            return attr

        # Wrap every module object, whatever concrete class the parent module
        # uses. Prefer the module's real name so that later submodule imports
        # resolve correctly when the parent exposes it under an alias.
        real_name = getattr(attr, "__name__", f"{self._module_name}.{name}")
        lazy_submodule = LazyModule(real_name, self._extra_group)
        lazy_submodule._module = attr  # Cache the already imported module
        self._submodules[name] = lazy_submodule
        return lazy_submodule


# Lazily importable deepspeed; the eagle3 trainer imports this name from here
# instead of doing a top-level ``import deepspeed``.
deepspeed = LazyModule("deepspeed", "speculative")