Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class BenchmarkConfig:
total_token: int = 60
depth: int = 5
top_k: int = 10
top_p: float = 1.0

# Hardware settings
num_gpus_per_model: int = 1
Expand All @@ -73,6 +74,9 @@ class BenchmarkConfig:
# SpecExit
early_stop_method: Optional[str] = None

# Batch settings
batch_size: int = 1


class BenchmarkEngine:
"""Core benchmark engine for speculative decoding evaluation"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"adam_w_mode": true,
"betas": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"stage3_gather_16bit_weights_on_model_save": true,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import torch

from angelslim.utils import decide_device_for_distributed, print_with_rank


class BaseBackend(ABC):
"""Base class for model backends"""
Expand Down Expand Up @@ -49,13 +51,21 @@ class TransformersBackend(BaseBackend):
def load_model(self):
from transformers import AutoModelForCausalLM, AutoTokenizer

# Get device based on environment
device = decide_device_for_distributed()

# Print device information with rank details
print_with_rank(f"Loading model to device: {device}")

# Update kwargs with default values
default_kwargs = {
"dtype": torch.bfloat16,
"device_map": "auto",
"device_map": device,
"trust_remote_code": True,
}
default_kwargs.update(self.kwargs)

# Load model to specific device based on rank
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path, **default_kwargs
)
Expand Down
2 changes: 2 additions & 0 deletions angelslim/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .default_compress_config import * # noqa: F401 F403
from .lazy_imports import * # noqa: F401 F403
from .utils import common_prefix # noqa: F401
from .utils import decide_device_for_distributed # noqa: F401
from .utils import find_layers # noqa: F401
from .utils import find_parent_layer_and_sub_name # noqa: F401
from .utils import get_best_device # noqa: F401
Expand All @@ -25,5 +26,6 @@
from .utils import get_tensor_item # noqa: F401
from .utils import get_yaml_prefix_simple # noqa: F401
from .utils import print_info # noqa: F401
from .utils import print_with_rank # noqa: F401
from .utils import rank0_print # noqa: F401
from .utils import set_op_by_name # noqa: F401
92 changes: 92 additions & 0 deletions angelslim/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,95 @@ def rank0_print(*args, **kwargs):

if rank == 0:
print(*args, **kwargs)


def _get_distributed_info():
    """
    Collect rank details describing the current process.

    Returns:
        Tuple of (rank, world_size, local_rank):
            - rank: ``dist.get_rank()`` when torch.distributed is available
              and initialized, otherwise -1
            - world_size: ``dist.get_world_size()`` under the same condition,
              otherwise 1
            - local_rank: integer value of the LOCAL_RANK environment
              variable, or -1 when the variable is not set
    """
    # LOCAL_RANK is parsed first so a malformed value fails before any
    # torch.distributed query is made.
    env_local_rank = os.environ.get("LOCAL_RANK")
    local_rank = -1 if env_local_rank is None else int(env_local_rank)

    # Without an initialized process group only the defaults are reported.
    if not (dist.is_available() and dist.is_initialized()):
        return -1, 1, local_rank

    return dist.get_rank(), dist.get_world_size(), local_rank


def print_with_rank(*args, **kwargs):
    """
    Print a message prefixed with the rank of the calling process.

    The prefix is built from the values reported by
    ``_get_distributed_info()``:

    - ``[Single Process]`` when no global rank is reported (rank < 0)
    - ``[Rank R/W]`` when a global rank is reported and the local rank is
      either unset or equal to it
    - ``[Rank R/W, Local L]`` when the local rank is set and differs from
      the global rank

    Args:
        *args: Positional arguments forwarded to ``print`` after the prefix
        **kwargs: Keyword arguments forwarded to ``print``

    Example:
        print_with_rank("Model loaded successfully")
        # [Rank 0/4] Model loaded successfully
        # [Rank 5/8, Local 1] Model loaded successfully
    """
    rank, world_size, local_rank = _get_distributed_info()

    if rank < 0:
        tag = "[Single Process]"
    elif local_rank < 0 or local_rank == rank:
        # The local rank adds no information here, so it is left out.
        tag = f"[Rank {rank}/{world_size}]"
    else:
        tag = f"[Rank {rank}/{world_size}, Local {local_rank}]"

    print(tag, *args, **kwargs)


def decide_device_for_distributed():
    """
    Decide the device a model should be placed on for the current process.

    Device selection priority:
        1. "cpu" whenever CUDA is not available, regardless of rank
        2. LOCAL_RANK reported by ``_get_distributed_info()`` (torchrun)
        3. Global distributed rank, wrapped to the number of visible GPUs
        4. "cuda:0" (single process fallback)

    Returns:
        str: Device string like 'cuda:0' or 'cpu'

    Example:
        device = decide_device_for_distributed()
        model.to(device)
    """
    rank, _, local_rank = _get_distributed_info()

    # A "cuda:N" string is unusable on a host without CUDA, so check this
    # before looking at any rank information.
    if not torch.cuda.is_available():
        return "cpu"

    if local_rank >= 0:
        # torchrun with LOCAL_RANK: already an index local to this node.
        return f"cuda:{local_rank}"

    if rank >= 0:
        # Only the global rank is known. On multi-node jobs it can exceed
        # the number of GPUs on this node, so wrap it into the valid range.
        visible_gpus = max(torch.cuda.device_count(), 1)
        return f"cuda:{rank % visible_gpus}"

    # Single process fallback.
    return "cuda:0"