diff --git a/angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py b/angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py index ec10c50e..4444996a 100644 --- a/angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py +++ b/angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py @@ -57,6 +57,7 @@ class BenchmarkConfig: total_token: int = 60 depth: int = 5 top_k: int = 10 + top_p: float = 1.0 # Hardware settings num_gpus_per_model: int = 1 @@ -73,6 +74,9 @@ class BenchmarkConfig: # SpecExit early_stop_method: Optional[str] = None + # Batch settings + batch_size: int = 1 + class BenchmarkEngine: """Core benchmark engine for speculative decoding evaluation""" diff --git a/angelslim/compressor/speculative/train/configs/deepspeed_zero2.json b/angelslim/compressor/speculative/train/configs/deepspeed_zero2.json new file mode 100644 index 00000000..8c4b70df --- /dev/null +++ b/angelslim/compressor/speculative/train/configs/deepspeed_zero2.json @@ -0,0 +1,38 @@ +{ + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "weight_decay": "auto", + "adam_w_mode": true, + "betas": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 2, + "stage3_gather_16bit_weights_on_model_save": true, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py 
b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py index 2495f922..3352ca87 100644 --- a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py +++ b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py @@ -17,6 +17,8 @@ import torch +from angelslim.utils import decide_device_for_distributed, print_with_rank + class BaseBackend(ABC): """Base class for model backends""" @@ -49,13 +51,21 @@ class TransformersBackend(BaseBackend): def load_model(self): from transformers import AutoModelForCausalLM, AutoTokenizer + # Get device based on environment + device = decide_device_for_distributed() + + # Print device information with rank details + print_with_rank(f"Loading model to device: {device}") + + # Update kwargs with default values default_kwargs = { "dtype": torch.bfloat16, - "device_map": "auto", + "device_map": device, "trust_remote_code": True, } default_kwargs.update(self.kwargs) + # Load model to specific device based on rank self.model = AutoModelForCausalLM.from_pretrained( self.model_path, **default_kwargs ) diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py index 8f63b4d8..0fa163e6 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -16,6 +16,7 @@ from .default_compress_config import * # noqa: F401 F403 from .lazy_imports import * # noqa: F401 F403 from .utils import common_prefix # noqa: F401 +from .utils import decide_device_for_distributed # noqa: F401 from .utils import find_layers # noqa: F401 from .utils import find_parent_layer_and_sub_name # noqa: F401 from .utils import get_best_device # noqa: F401 @@ -25,5 +26,6 @@ from .utils import get_tensor_item # noqa: F401 from .utils import get_yaml_prefix_simple # noqa: F401 from .utils import print_info # noqa: F401 +from .utils import print_with_rank # noqa: F401 from .utils import rank0_print # noqa: F401 from .utils import set_op_by_name # noqa: F401 diff --git 
# Distributed-training helpers for angelslim/utils/utils.py (they follow rank0_print).


def _get_distributed_info():
    """
    Collect rank information about the current process.

    Returns:
        Tuple of (rank, world_size, local_rank):
        - rank: dist.get_rank() when torch.distributed is available and
          initialized, otherwise -1
        - world_size: dist.get_world_size() under the same condition,
          otherwise 1
        - local_rank: integer value of the LOCAL_RANK environment variable,
          or -1 when the variable is not set

    Raises:
        ValueError: If LOCAL_RANK is set to a string that int() cannot parse.
    """
    rank = -1
    world_size = 1

    # LOCAL_RANK is read independently of whether the process group exists yet.
    local_rank = int(os.environ.get("LOCAL_RANK", -1))

    # Global rank / world size are only known once the process group is up.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()

    return rank, world_size, local_rank


def print_with_rank(*args, **kwargs):
    """
    Print with a prefix identifying the current process rank.

    The prefix is "[Single Process]" when torch.distributed is not
    initialized, "[Rank r/w]" when it is, and "[Rank r/w, Local l]" when
    LOCAL_RANK is also set and differs from the global rank.

    Args:
        *args: Positional arguments forwarded to print after the prefix
        **kwargs: Keyword arguments forwarded to print (sep, end, file, flush)

    Example:
        print_with_rank("Model loaded successfully")
        # Not initialized:       [Single Process] Model loaded successfully
        # Rank 1, LOCAL_RANK=1:  [Rank 1/4] Model loaded successfully
        # Rank 5, LOCAL_RANK=1:  [Rank 5/8, Local 1] Model loaded successfully
    """
    rank, world_size, local_rank = _get_distributed_info()

    if rank < 0:
        prefix = "[Single Process]"
    elif local_rank >= 0 and local_rank != rank:
        # The local rank only adds information when it differs from the
        # global rank.
        prefix = f"[Rank {rank}/{world_size}, Local {local_rank}]"
    else:
        prefix = f"[Rank {rank}/{world_size}]"

    print(prefix, *args, **kwargs)


def decide_device_for_distributed():
    """
    Decide the device string for placing a model on the current process.

    Device selection priority:
    1. 'cpu' whenever CUDA is not available, regardless of rank information
    2. LOCAL_RANK environment variable -> 'cuda:<local_rank>'
    3. Global rank of an initialized torch.distributed process group,
       wrapped by the number of visible CUDA devices
    4. 'cuda:0' (single process fallback)

    Returns:
        str: Device string like 'cuda:0' or 'cpu'

    Example:
        device = decide_device_for_distributed()
        model.to(device)
    """
    rank, _, local_rank = _get_distributed_info()

    # Never hand back a CUDA device string on a machine without CUDA, even
    # when a launcher has set LOCAL_RANK (e.g. a CPU-only distributed run).
    if not torch.cuda.is_available():
        return "cpu"

    if local_rank >= 0:
        return f"cuda:{local_rank}"

    if rank >= 0:
        # The global rank can exceed the per-node device count on multi-node
        # jobs, so wrap it. NOTE(review): assumes ranks are laid out
        # contiguously per node -- confirm for custom launchers.
        return f"cuda:{rank % torch.cuda.device_count()}"

    return "cuda:0"