diff --git a/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml new file mode 100644 index 000000000..b17bb89e1 --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml @@ -0,0 +1,384 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + experiments_root_path: ${modalities_env:experiments_root_path} + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 1 + sequence_length: 256 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + drop_last: true + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: moe_cross_entropy + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + model: + instance_key: model_raw + pass_type: BY_REFERENCE + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: -1 + expert_parallel_degree: 4 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.num_layers} + multi_device_generator_policy: error + +ep_model: + component_key: model + variant_key: ep_wrapped + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [TransformerBlock] + +ac_model: + component_key: model + variant_key: activation_checkpointed + config: + model: + instance_key: ep_model + pass_type: BY_REFERENCE + ac_variant: full_activation_checkpointing + layers_fqn: layers + ac_fun_params: + ac_freq: 1 + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: ac_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + reshard_after_forward: true + block_names: [TransformerBlock] + +model_raw: + component_key: model + variant_key: moe + config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 + max_seq_len: ${settings.step_profile.sequence_length} + d_model: 128 + n_heads: 8 + n_kv_heads: 4 + num_layers: 2 + d_ff: 128 + attn_dropout: 0.0 + ffn_dropout: 0.0 + tie_embeddings: false + norm_eps: 1e-6 + rope_base: 1000000.0 + moe_num_experts: 8 + moe_top_k: 2 + moe_d_ff: 128 + moe_capacity_factor: 1.25 + moe_min_capacity: 4 + moe_overflow_policy: residual + moe_router_noise_std: 0.0 + moe_router_temperature: 1.0 + moe_router_dropout: 0.0 + moe_aux_loss_coef: 0.001 + moe_z_loss_coef: 0.0 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: ep_adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: ep + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +mfu_calculator: + component_key: mfu_calculator + variant_key: gpt2 + config: + n_layer: ${model_raw.config.num_layers} + sequence_length: ${settings.step_profile.sequence_length} + n_embd: ${model_raw.config.d_model} + world_size: ${settings.cuda_env.world_size} + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE diff --git a/scripts/monitor_gpus.sh b/scripts/monitor_gpus.sh new file mode 100755 index 000000000..7c221ffb1 --- /dev/null +++ b/scripts/monitor_gpus.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +# GPU Monitoring Script - saves metrics to CSV with timestamps +# Usage: ./monitor_gpus.sh [interval_seconds] [output_file] + +INTERVAL=${1:-5} # Default: sample every 5 seconds +OUTPUT=${2:-logs/gpu_metrics_$(date +%Y%m%d_%H%M%S).csv} +PIDFILE=/tmp/gpu_monitor_$$.pid + +echo "Starting GPU monitoring..." +echo "Interval: ${INTERVAL}s" +echo "Output: ${OUTPUT}" +echo "PID file: ${PIDFILE}" + +# Create output directory if needed +mkdir -p "$(dirname "$OUTPUT")" + +# Save PID for cleanup +echo $$ > "$PIDFILE" + +# Write CSV header +echo "timestamp,gpu_id,memory_used_mb,memory_total_mb,memory_util_pct,gpu_util_pct,temperature_c,power_draw_w,power_limit_w" > "$OUTPUT" + +# Cleanup function +cleanup() { + echo "" + echo "Stopping GPU monitoring..." + rm -f "$PIDFILE" + echo "Metrics saved to: $OUTPUT" + exit 0 +} + +trap cleanup SIGINT SIGTERM EXIT + +# Monitoring loop +while true; do + TIMESTAMP=$(date +%Y-%m-%d\ %H:%M:%S) + + # Query nvidia-smi for all metrics at once + nvidia-smi --query-gpu=index,memory.used,memory.total,utilization.memory,utilization.gpu,temperature.gpu,power.draw,power.limit \ + --format=csv,noheader,nounits 2>/dev/null | while IFS=',' read -r gpu_id mem_used mem_total mem_util gpu_util temp power power_limit; do + # Trim whitespace + gpu_id=$(echo "$gpu_id" | xargs) + mem_used=$(echo "$mem_used" | xargs) + mem_total=$(echo "$mem_total" | xargs) + mem_util=$(echo "$mem_util" | xargs) + gpu_util=$(echo "$gpu_util" | xargs) + temp=$(echo "$temp" | xargs) + power=$(echo "$power" | xargs) + power_limit=$(echo "$power_limit" | xargs) + + # Write to CSV + echo "$TIMESTAMP,$gpu_id,$mem_used,$mem_total,$mem_util,$gpu_util,$temp,$power,$power_limit" >> "$OUTPUT" + done + + # Live display (optional, comment out if too verbose) + # echo "[$(date +%H:%M:%S)] Logged GPU metrics ($(wc -l < "$OUTPUT") samples)" + + sleep "$INTERVAL" +done diff --git a/src/modalities/checkpointing/stateful/app_state.py b/src/modalities/checkpointing/stateful/app_state.py index 2da3ab236..25c1efc91 100644 --- a/src/modalities/checkpointing/stateful/app_state.py +++ b/src/modalities/checkpointing/stateful/app_state.py @@ -184,6 +184,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: class OptimizerStateRetriever(StateRetrieverIF): + @staticmethod + def _uses_standard_optimizer_state_dict(app_state: AppState) -> bool: + """Checks whether the optimizer state dict follows the standard torch Optimizer schema. + + Standard optimizer state dicts contain top-level "state" and "param_groups" keys, + which are required by distributed optimizer checkpoint utilities. + """ + state_dict = app_state.optimizer.state_dict() + return isinstance(state_dict, dict) and "state" in state_dict and "param_groups" in state_dict + @staticmethod def get_state_dict(app_state: AppState) -> dict[str, Any]: """Returns the state dict of the optimizer in the AppState object. @@ -196,6 +206,10 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]: """ if isinstance(app_state.optimizer, OptimizersList): sd = app_state.optimizer.state_dict() + elif not OptimizerStateRetriever._uses_standard_optimizer_state_dict(app_state): + # Custom optimizers (e.g. EP wrappers) may not expose the standard torch + # optimizer format expected by get_optimizer_state_dict. + sd = app_state.optimizer.state_dict() else: assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." sd = get_optimizer_state_dict( @@ -217,6 +231,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: """ if isinstance(app_state.optimizer, OptimizersList): app_state.optimizer.load_state_dict(state_dict) + elif not OptimizerStateRetriever._uses_standard_optimizer_state_dict(app_state): + app_state.optimizer.load_state_dict(state_dict) else: assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." set_optimizer_state_dict( diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 42a19b99a..fea034546 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -83,6 +83,16 @@ class CLMCrossEntropyLossConfig(BaseModel): prediction_key: str +class MoECrossEntropyLossConfig(BaseModel): + target_key: str + prediction_key: str + model: Any + tag: str = "MoECrossEntropyLoss" + + class Config: + arbitrary_types_allowed = True + + # Checkpointing class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): k: PositiveInt @@ -167,6 +177,19 @@ class AdamWOptimizerConfig(BaseModel): fused: bool | None = None +class EPAdamWConfig(BaseModel): + wrapped_model: PydanticPytorchModuleOrListType + device_mesh: PydanticDeviceMeshIFType + lr: float + betas: tuple[float, float] + eps: float + weight_decay: float + weight_decay_groups_excluded: list[str] + + class Config: + arbitrary_types_allowed = True + + class DummyLRSchedulerConfig(BaseModel): optimizer: PydanticOptimizerIFType @@ -311,6 +334,13 @@ def validate_dp_mesh_existence(self): return self +class EPWrappedModelConfig(BaseModel): + model: PydanticPytorchModuleOrListType + block_names: list[str] + device_mesh: PydanticDeviceMeshIFType + mixed_precision_settings: FSDP2MixedPrecisionSettings + + class DebuggingEnrichedModelConfig(BaseModel): model: PydanticPytorchModuleOrListType logging_dir_path: Path diff --git a/src/modalities/models/components/rotary_embedding.py b/src/modalities/models/components/rotary_embedding.py new file mode 100644 index 000000000..c569a787e --- /dev/null +++ b/src/modalities/models/components/rotary_embedding.py @@ -0,0 +1,126 @@ +import math +from typing import Optional + +import torch + + +def compute_default_inv_freq(dim_model: int, base_freq: float, device: Optional[torch.device] = None) -> torch.Tensor: + return 1.0 / (base_freq ** (torch.arange(0, dim_model, 2, device=device).float() / dim_model)) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seq_length_dim: int) -> torch.Tensor: + cos = cos[:, :, : x.shape[seq_length_dim], :] + sin = sin[:, :, : x.shape[seq_length_dim], :] + return (x * cos) + (rotate_half(x) * sin) + + +def update_cos_sin_tables( + x: torch.Tensor, + inv_freq: torch.Tensor, + attention_scaling: float, + seq_length_dim: int, + seq_len_cached: Optional[int], + cos_cached: Optional[torch.Tensor], + sin_cached: Optional[torch.Tensor], +) -> tuple[int, torch.Tensor, torch.Tensor]: + seq_len = x.shape[seq_length_dim] + + if ( + seq_len != seq_len_cached + or cos_cached is None + or sin_cached is None + or cos_cached.device != x.device + or cos_cached.dtype != x.dtype + ): + t = torch.arange(seq_len, device=x.device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + cos_cached = (emb.cos() * attention_scaling)[None, None, :, :].to(x.dtype) + sin_cached = (emb.sin() * attention_scaling)[None, None, :, :].to(x.dtype) + seq_len_cached = seq_len + + return seq_len_cached, cos_cached, sin_cached + + +def compute_yarn_inv_freq_and_attention_scaling( + dim_model: int, + base_freq: float, + max_position_embeddings: int, + original_max_position_embeddings: int, + factor: Optional[float], + attention_factor: Optional[float], + mscale: Optional[float], + mscale_all_dim: Optional[float], + beta_fast: float, + beta_slow: float, + truncate: bool, + device: Optional[torch.device] = None, +) -> tuple[torch.Tensor, float]: + factor_float = ( + float(factor) if factor is not None else float(max_position_embeddings / original_max_position_embeddings) + ) + + def get_mscale(scale: float, mscale_value: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale_value * math.log(scale) + 1.0 + + if attention_factor is None: + if mscale is not None and mscale_all_dim is not None: + attention_factor = float( + get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) + ) + else: + attention_factor = get_mscale(factor_float) + + def find_correction_dim(num_rotations: float, dim: int, base: float, max_pos_emb: int) -> float: + return (dim * math.log(max_pos_emb / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range( + low_rot: float, + high_rot: float, + dim: int, + base: float, + max_pos_emb: int, + do_truncate: bool, + ) -> tuple[float, float]: + low = find_correction_dim(low_rot, dim, base, max_pos_emb) + high = find_correction_dim(high_rot, dim, base, max_pos_emb) + if do_truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: + if min_value == max_value: + max_value += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) + return torch.clamp(linear_func, 0, 1) + + pos_freqs = base_freq ** (torch.arange(0, dim_model, 2, device=device, dtype=torch.float) / dim_model) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor_float * pos_freqs) + + low, high = find_correction_range( + beta_fast, + beta_slow, + dim_model, + base_freq, + original_max_position_embeddings, + bool(truncate), + ) + + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim_model // 2).to( + device=device, dtype=torch.float + ) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, float(attention_factor) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index f43e6e87b..2e93b0be1 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -3,7 +3,7 @@ from abc import abstractmethod from enum import Enum from numbers import Real -from typing import Annotated, Literal, Optional, overload +from typing import Annotated, Literal, Optional, cast, overload import torch import torch.nn as nn @@ -17,6 +17,13 @@ RMSLayerNorm, RMSLayerNormConfig, ) +from modalities.models.components.rotary_embedding import ( + apply_rotary_pos_emb, + compute_default_inv_freq, + compute_yarn_inv_freq_and_attention_scaling, + rotate_half, + update_cos_sin_tables, +) from modalities.models.model import ActivationType, NNModel, SwiGLU from modalities.util import parse_enum_by_name @@ -221,9 +228,7 @@ def reset_parameters(self): if rope_type == "yarn": inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device) else: - inv_freq = 1.0 / ( - self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) - ) + inv_freq = compute_default_inv_freq(dim_model=self.dim_model, base_freq=self.base_freq, device=device) self.attention_scaling = 1.0 self.register_buffer("inv_freq", inv_freq) @@ -243,8 +248,7 @@ def rotate_half(self, x: torch.Tensor): torch.Tensor: The output tensor. """ - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) + return rotate_half(x) def apply_rotary_pos_emb(self, x, cos, sin): """ @@ -258,16 +262,7 @@ def apply_rotary_pos_emb(self, x, cos, sin): Returns: torch.Tensor: Tensor after applying rotary positional embedding. """ - # NOTE: This could probably be moved to Triton - - # Handle a possible sequence length mismatch in between q and k - cos = cos[:, :, : x.shape[self.seq_length_dim], :] - sin = sin[:, :, : x.shape[self.seq_length_dim], :] - - # the rotation is not really a rotation in higher dimensions, - # It merely swaps and negates certain dimensions to make - # the rotation below work - return (x * cos) + (self.rotate_half(x) * sin) + return apply_rotary_pos_emb(x=x, cos=cos, sin=sin, seq_length_dim=self.seq_length_dim) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor @@ -297,109 +292,31 @@ def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.T if self.max_position_embeddings is None: raise ValueError("YaRN requires max_position_embeddings to be set.") - original_max_position_embeddings = self.rope_scaling.original_max_position_embeddings - factor = self.rope_scaling.factor - if factor is None: - factor = self.max_position_embeddings / original_max_position_embeddings - factor_float = float(factor) - - attention_factor = self.rope_scaling.attention_factor - mscale_pair = None - if self.rope_scaling.mscale is not None and self.rope_scaling.mscale_all_dim is not None: - mscale_pair = (self.rope_scaling.mscale, self.rope_scaling.mscale_all_dim) - - beta_fast = self.rope_scaling.beta_fast - beta_slow = self.rope_scaling.beta_slow - truncate = self.rope_scaling.truncate - - def get_mscale(scale: float, mscale: float = 1.0) -> float: - """Return the YaRN mscale coefficient for a given scaling factor.""" - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - if attention_factor is None: - if mscale_pair is not None: - mscale, mscale_all_dim = mscale_pair - attention_factor = float( - get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) - ) - else: - attention_factor = get_mscale(factor_float) - - def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float: - """Map a target number of rotations to a rotary dimension index.""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range( - low_rot: float, - high_rot: float, - dim: int, - base: int, - max_position_embeddings: int, - truncate: bool, - ) -> tuple[float, float]: - """Compute the lower and upper rotary-dimension correction bounds for YaRN.""" - low = find_correction_dim(low_rot, dim, base, max_position_embeddings) - high = find_correction_dim(high_rot, dim, base, max_position_embeddings) - if truncate: - low = math.floor(low) - high = math.ceil(high) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: - """Create a clamped linear ramp used to blend interpolation and extrapolation.""" - if min_value == max_value: - max_value += 0.001 - linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - dim = self.dim_model - base = self.base_freq - - pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (factor_float * pos_freqs) - - low, high = find_correction_range( - beta_fast, - beta_slow, - dim, - base, - original_max_position_embeddings, - bool(truncate), - ) - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor + return compute_yarn_inv_freq_and_attention_scaling( + dim_model=self.dim_model, + base_freq=self.base_freq, + max_position_embeddings=self.max_position_embeddings, + original_max_position_embeddings=self.rope_scaling.original_max_position_embeddings, + factor=self.rope_scaling.factor, + attention_factor=self.rope_scaling.attention_factor, + mscale=self.rope_scaling.mscale, + mscale_all_dim=self.rope_scaling.mscale_all_dim, + beta_fast=self.rope_scaling.beta_fast, + beta_slow=self.rope_scaling.beta_slow, + truncate=self.rope_scaling.truncate, + device=device, ) - return inv_freq, float(attention_factor) - def _update_cos_sin_tables(self, x): - # Update the cosine and sine tables. - seq_len = x.shape[self.seq_length_dim] - - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seq_len != self._seq_len_cached - or self._cos_cached is None - or self._sin_cached is None - or self._cos_cached.device != x.device - or self._cos_cached.dtype != x.dtype - ): - self._seq_len_cached = seq_len - t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32) - freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) - emb = torch.cat((freqs, freqs), dim=-1).to( - x.device - ) # here, we combine the two matrices (not zipping them). - self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype) - + self._seq_len_cached, self._cos_cached, self._sin_cached = update_cos_sin_tables( + x=x, + inv_freq=cast(torch.Tensor, self.inv_freq), + attention_scaling=self.attention_scaling, + seq_length_dim=self.seq_length_dim, + seq_len_cached=self._seq_len_cached, + cos_cached=self._cos_cached, + sin_cached=self._sin_cached, + ) return self._cos_cached, self._sin_cached diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 62933794d..acef23f71 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -212,6 +212,12 @@ def get_fsdp2_wrapped_model( modules = list(model.modules()) + # Collect EP parameters to exclude from FSDP2 sharding + ep_params = { + p for m in model.modules() if getattr(m, "_ep_enabled", False) for p in m.parameters(recurse=False) + } + ignored_params = ep_params if ep_params else None + # we first shard all the blocks grouped_modules: list[nn.Module] = [] module_id = 0 @@ -226,6 +232,7 @@ def get_fsdp2_wrapped_model( grouped_modules, **fsdp_config, reshard_after_forward=reshard_block_after_forward, + ignored_params=ignored_params, ) grouped_modules = list() @@ -235,10 +242,11 @@ def get_fsdp2_wrapped_model( grouped_modules, **fsdp_config, reshard_after_forward=reshard_block_after_forward, + ignored_params=ignored_params, ) # finally, we shard the entire model - fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward) + fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward, ignored_params=ignored_params) logger.info( f"Rank {dist.get_rank()} sharded number of parameters: " f"{get_local_number_of_trainable_parameters(model)}" diff --git a/src/modalities/models/moe/__init__.py b/src/modalities/models/moe/__init__.py new file mode 100644 index 000000000..5e55327a1 --- /dev/null +++ b/src/modalities/models/moe/__init__.py @@ -0,0 +1,10 @@ +from modalities.models.moe.loss_functions import MoECrossEntropyLoss +from modalities.models.moe.model_factory import get_ep_wrapped_model +from modalities.models.moe.qwen_model import QwenModel, QwenModelConfig + +__all__ = [ + "MoECrossEntropyLoss", + "QwenModel", + "QwenModelConfig", + "get_ep_wrapped_model", +] diff --git a/src/modalities/models/moe/loss_functions.py b/src/modalities/models/moe/loss_functions.py new file mode 100644 index 000000000..57f30da69 --- /dev/null +++ b/src/modalities/models/moe/loss_functions.py @@ -0,0 +1,39 @@ +import torch +from torch.nn import CrossEntropyLoss + +from modalities.batch import InferenceResultBatch +from modalities.loss_functions import Loss + + +class MoECrossEntropyLoss(Loss): + """Cross entropy loss with optional MoE auxiliary losses from model layers.""" + + def __init__( + self, + target_key: str, + prediction_key: str, + model, + tag: str = "MoECrossEntropyLoss", + ): + super().__init__(tag) + self.target_key = target_key + self.prediction_key = prediction_key + self.model = model + self.loss_fun = CrossEntropyLoss(reduction="mean") + + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: + labels = forward_batch.get_targets(self.target_key) + lm_logits = forward_batch.get_predictions(self.prediction_key) + + labels = labels.to(lm_logits.device) + loss = self.loss_fun( + lm_logits.contiguous().view(-1, lm_logits.size(-1)), + labels.contiguous().long().view(-1), + ) + + # Aux loss + for layer in self.model.layers.values(): + if hasattr(layer, "aux_loss") and layer.aux_loss is not None: + loss = loss + layer.aux_loss.to(loss.dtype) + + return loss diff --git a/src/modalities/models/moe/model_factory.py b/src/modalities/models/moe/model_factory.py new file mode 100644 index 000000000..35e9b7e10 --- /dev/null +++ b/src/modalities/models/moe/model_factory.py @@ -0,0 +1,96 @@ +import warnings + +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor + +from modalities.models.parallelism.expert_parallelism import ExpertParallel +from modalities.running_env.env_utils import FSDP2MixedPrecisionSettings +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method +from modalities.util import get_module_class_from_name + + +def get_ep_wrapped_model( + model, + block_names: list[str], + device_mesh: DeviceMesh, + mixed_precision_settings: FSDP2MixedPrecisionSettings, +) -> nn.Module: + block_types = [] + missing_block_names = [] + for name in block_names: + block_type = get_module_class_from_name(model, name) + if block_type is None: + missing_block_names.append(name) + else: + block_types.append(block_type) + + if len(missing_block_names) > 0 and (not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0): + warnings.warn( + f"Could not resolve some requested MoE block names and they will be ignored: {missing_block_names}", + stacklevel=2, + ) + + block_types = tuple(block_types) + + if len(block_types) == 0: + raise ValueError(f"None of the requested MoE block names were found: {block_names}") + + ep_mesh = get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.EP) + target_dtype = mixed_precision_settings.param_dtype.value + + wrapped_blocks = 0 + for module in model.modules(): + if isinstance(module, block_types): + if hasattr(module, "experts"): + ep_target = module + elif (ffn := getattr(module, "ffn", None)) is not None and hasattr(ffn, "experts"): + ep_target = ffn + else: + raise ValueError( + f"Module {type(module).__name__} has no EP-compatible experts location. " + "Expected `experts` or `ffn.experts`." + ) + + if getattr(ep_target, "_ep_enabled", False): + continue + + experts = ep_target.experts + missing = [a for a in ("w1", "w2") if not hasattr(experts, a)] + if missing: + raise ValueError( + f"Module {type(ep_target).__name__}.experts is not grouped-experts compatible. Missing: {missing}" + ) + if experts.w1.ndim != 3 or experts.w2.ndim != 3: + raise ValueError( + f"Expected grouped expert parameters with ndim=3. Got w1.ndim={experts.w1.ndim}, " + f"w2.ndim={experts.w2.ndim}" + ) + + ep_target._ep_mesh = ep_mesh + ep_target._ep_group = ep_mesh.get_group() + ep_target._ep_size = ep_mesh.size() + ep_target._ep_rank = ep_mesh.get_local_rank() + + ep_target.experts = ExpertParallel()._apply(ep_target.experts, ep_mesh) + ep_target.experts._ep_enabled = True + + for pname, p in list(ep_target.experts._parameters.items()): + if isinstance(p, DTensor) and p.dtype != target_dtype: + local = p.to_local().to(target_dtype) + ep_target.experts._parameters[pname] = nn.Parameter( + DTensor.from_local(local, p.device_mesh, p.placements, run_check=False), + requires_grad=p.requires_grad, + ) + + wrapped_blocks += 1 + + if wrapped_blocks == 0: + raise ValueError(f"No blocks matched the requested types: {[t.__name__ for t in block_types]}") + + model._ep_wrapped = True + model._ep_mesh = ep_mesh + model._ep_num_wrapped_blocks = wrapped_blocks + + return model diff --git a/src/modalities/models/moe/qwen_model.py b/src/modalities/models/moe/qwen_model.py new file mode 100644 index 000000000..b20dd05ef --- /dev/null +++ b/src/modalities/models/moe/qwen_model.py @@ -0,0 +1,486 @@ +import math +from typing import Literal, Optional, overload + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pydantic import BaseModel + +from modalities.models.components.rotary_embedding import ( + apply_rotary_pos_emb, + compute_default_inv_freq, + update_cos_sin_tables, +) +from modalities.models.model import NNModel + +try: + from torch.distributed.tensor import DTensor +except Exception: + DTensor = None + + +class QwenModelConfig(BaseModel): + vocab_size: int + max_seq_len: int + d_model: int + n_heads: int + n_kv_heads: int + num_layers: int + d_ff: int + sample_key: str = "input_ids" + prediction_key: str = "logits" + attn_dropout: float = 0.0 + ffn_dropout: float = 0.0 + tie_embeddings: bool = False + norm_eps: float = 1e-6 + rope_base: float = 1000000.0 + + moe_num_experts: int = 128 + moe_top_k: int = 8 + moe_d_ff: int = 768 + moe_capacity_factor: float = 1.25 + moe_min_capacity: int = 4 + moe_overflow_policy: Literal["drop", "residual"] = "residual" + moe_router_noise_std: float = 0.0 + moe_router_temperature: float = 1.0 + moe_router_dropout: float = 0.0 + moe_aux_loss_coef: float = 0.001 + moe_z_loss_coef: float = 0.0 + + +class RMSNorm(nn.Module): + def __init__(self, d_model: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(d_model)) + + def reset_parameters(self): + nn.init.ones_(self.weight) + + def forward(self, x): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim: int, max_seq_len: int, base: float = 1000000.0): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.base = base + self.register_buffer("inv_freq", None, persistent=False) + self.register_buffer("cos_cached", None, persistent=False) + self.register_buffer("sin_cached", None, persistent=False) + self._seq_len_cached: Optional[int] = None + + def _compute_cache(self, device): + self.inv_freq = compute_default_inv_freq(dim_model=self.head_dim, base_freq=self.base, device=device) + self._seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + + def forward(self, x: torch.Tensor, seq_len: int): + if self.inv_freq is None: + self._compute_cache(x.device) + self._seq_len_cached, self.cos_cached, self.sin_cached = update_cos_sin_tables( + x=x, + inv_freq=self.inv_freq, + attention_scaling=1.0, + seq_length_dim=-2, + seq_len_cached=self._seq_len_cached, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + ) + return ( + self.cos_cached[:, :, :seq_len, :].to(x.dtype), + self.sin_cached[:, :, :seq_len, :].to(x.dtype), + ) + + +def apply_rotary_emb(q, k, cos, sin): + return ( + apply_rotary_pos_emb(x=q, cos=cos, sin=sin, seq_length_dim=-2), + apply_rotary_pos_emb(x=k, cos=cos, sin=sin, seq_length_dim=-2), + ) + + +class GroupedQueryAttention(nn.Module): + def __init__(self, d_model, n_heads, n_kv_heads, max_seq_len, rope_base, norm_eps, attn_dropout): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = d_model // n_heads + self.n_rep = n_heads // n_kv_heads + + self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False) + + self.q_norm = RMSNorm(self.head_dim, eps=norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=norm_eps) + + self.rope = RotaryEmbedding(self.head_dim, max_seq_len, base=rope_base) + self.dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity() + + def forward(self, x, mask=None): + B, T, _ = x.shape + + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + + q = self.q_norm(q) + k = self.k_norm(k) + + cos, sin = self.rope(q, seq_len=T) + q, k = apply_rotary_emb(q, k, cos, sin) + + if self.n_rep > 1: + k = ( + k.unsqueeze(2) + .expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim) + .reshape(B, self.n_heads, T, self.head_dim) + ) + v = ( + v.unsqueeze(2) + .expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim) + .reshape(B, self.n_heads, T, self.head_dim) + ) + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None) + return self.o_proj(out.transpose(1, 2).contiguous().view(B, T, -1)) + + +class GroupedExperts(nn.Module): + def __init__( + self, + num_experts, + d_model, + d_ff, + ffn_dropout, + ): + super().__init__() + self.num_experts = num_experts + self.d_model = d_model + self.d_ff = d_ff + self.dropout = nn.Dropout(ffn_dropout) if ffn_dropout > 0 else nn.Identity() + + self.w1 = nn.Parameter(torch.empty(self.num_experts, self.d_ff, self.d_model)) + self.w2 = nn.Parameter(torch.empty(self.num_experts, self.d_ff, self.d_model)) + self.w3 = nn.Parameter(torch.empty(self.num_experts, self.d_model, self.d_ff)) + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5)) + + def _forward_local(self, routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + w1 = self.w1.to_local() if DTensor is not None and isinstance(self.w1, DTensor) else self.w1 + w2 = self.w2.to_local() if DTensor is not None and isinstance(self.w2, DTensor) else self.w2 + w3 = self.w3.to_local() if DTensor is not None and isinstance(self.w3, DTensor) else self.w3 + # F.linear requires matching dtypes between inputs and weights. Under mixed precision, + # routed_input can be BF16 while local expert weights remain FP32. + if routed_input.dtype != w1.dtype: + w1 = w1.to(dtype=routed_input.dtype) + w2 = w2.to(dtype=routed_input.dtype) + w3 = w3.to(dtype=routed_input.dtype) + local_num_tokens = ( + num_tokens_per_expert.to_local() + if DTensor is not None and isinstance(num_tokens_per_expert, DTensor) + else num_tokens_per_expert + ) + + outputs: list[torch.Tensor] = [] + start = 0 + total_rows = routed_input.shape[0] + + for expert_idx, num_tokens in enumerate(local_num_tokens.tolist()): + requested_tokens = int(num_tokens) + end = start + requested_tokens + local_end = min(end, total_rows) + expert_input = routed_input[start:local_end] + real_tokens = int(expert_input.shape[0]) + + out_parts: list[torch.Tensor] = [] + if real_tokens > 0: + x1 = F.linear(expert_input, w1[expert_idx]) + x2 = F.linear(expert_input, w2[expert_idx]) + out_parts.append(self.dropout(F.linear(F.silu(x1) * x2, w3[expert_idx]))) + + pad = requested_tokens - real_tokens + if pad > 0: + out_parts.append(routed_input.new_zeros((pad, self.d_model))) + + if out_parts: + outputs.append(torch.cat(out_parts, dim=0) if len(out_parts) > 1 else out_parts[0]) + + start = end + + if not outputs: + return routed_input.new_zeros((0, self.d_model)) + + out = torch.cat(outputs, dim=0) + if out.shape[0] < total_rows: + out = torch.cat([out, routed_input.new_zeros((total_rows - out.shape[0], self.d_model))], dim=0) + elif out.shape[0] > total_rows: + out = out[:total_rows] + return out + + def forward(self, routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + return self._forward_local(routed_input, num_tokens_per_expert) + + +class MoEBlock(nn.Module): + def __init__( + self, + d_model: int, + moe_d_ff: int, + moe_num_experts: int, + moe_top_k: int, + moe_capacity_factor: float, + moe_min_capacity: int, + moe_overflow_policy: str, + moe_router_noise_std: float, + moe_router_temperature: float, + moe_router_dropout: float, + moe_aux_loss_coef: float, + moe_z_loss_coef: float, + ffn_dropout: float, + ): + super().__init__() + self.num_experts = moe_num_experts + self.top_k = moe_top_k + self.capacity_factor = moe_capacity_factor + self.min_capacity = moe_min_capacity + self.overflow_policy = moe_overflow_policy + self.router_noise_std = moe_router_noise_std + self.router_dropout = nn.Dropout(moe_router_dropout) if moe_router_dropout > 0 else nn.Identity() + self.router_temperature = moe_router_temperature + self.aux_loss_coef = moe_aux_loss_coef + self.z_loss_coef = moe_z_loss_coef + + self.router = nn.Linear(d_model, self.num_experts, bias=False) + self.experts = GroupedExperts( + num_experts=moe_num_experts, d_model=d_model, d_ff=moe_d_ff, ffn_dropout=ffn_dropout + ) + self.last_aux_loss: Optional[torch.Tensor] = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, D = x.shape + E = self.num_experts + K = self.top_k + N = B * T + x_flat = x.view(N, D) + + logits = self.router(self.router_dropout(x_flat).to(self.router.weight.dtype)).float() + if self.router_noise_std > 0 and self.training: + logits = logits + torch.randn_like(logits) * self.router_noise_std + logits = logits / self.router_temperature + + probs = torch.softmax(logits, dim=-1) + topk_val, topk_idx = torch.topk(probs, k=K, dim=-1) + topk_w = (topk_val / (topk_val.sum(dim=-1, keepdim=True) + 1e-9)).to(x_flat.dtype) + + capacity = max(math.ceil(self.capacity_factor * N / E), self.min_capacity) + + dispatch_mask = F.one_hot(topk_idx, num_classes=E).to(x_flat.dtype) + positions = torch.cumsum(dispatch_mask.sum(dim=1), dim=0) + capacity_mask = (positions <= capacity).to(x_flat.dtype) + final_mask = dispatch_mask * capacity_mask.unsqueeze(1) + + load = final_mask.sum(dim=[0, 1]) + importance = probs.sum(dim=0) + + flat_valid = capacity_mask.gather(1, topk_idx).bool().reshape(-1) + flat_token_ids = torch.arange(N, device=x.device).unsqueeze(1).expand(N, K).reshape(-1)[flat_valid] + flat_expert_ids = topk_idx.reshape(-1)[flat_valid] + flat_weights = topk_w.reshape(-1)[flat_valid] + + if flat_expert_ids.numel() > 0: + sort_idx = torch.argsort(flat_expert_ids) + token_ids_sorted = flat_token_ids[sort_idx] + expert_ids_sorted = flat_expert_ids[sort_idx] + weights_sorted = flat_weights[sort_idx] + + routed_output = self.experts(x_flat[token_ids_sorted], torch.bincount(expert_ids_sorted, minlength=E)) + weighted_output = routed_output * weights_sorted.unsqueeze(-1) + + out = x_flat.new_zeros((N, D)) + out.index_add_(0, token_ids_sorted, weighted_output) + assigned = x_flat.new_zeros((N,)) + assigned.index_add_(0, token_ids_sorted, weights_sorted) + else: + out = x_flat.new_zeros((N, D)) + assigned = x_flat.new_zeros((N,)) + + not_assigned = assigned < 1e-6 + if not_assigned.any() and self.overflow_policy == "residual": + out[not_assigned] = x_flat[not_assigned] + + aux = None + if self.aux_loss_coef > 0: + imp = importance / (importance.sum() + 1e-9) + ld = load / (load.sum() + 1e-9) + aux = self.aux_loss_coef * E * torch.sum(imp * ld) + if self.z_loss_coef > 0: + z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) + aux = (aux if aux is not None else torch.tensor(0.0, device=x.device)) + self.z_loss_coef * z_loss + + self.last_aux_loss = aux + return out.view(B, T, D) + + +class TransformerBlock(nn.Module): + def __init__( + self, + d_model: int, + d_ff: int, + n_heads: int, + n_kv_heads: int, + max_seq_len: int, + rope_base: float, + norm_eps: float, + attn_dropout: float, + ffn_dropout: float, + moe_d_ff: int = 768, + moe_num_experts: int = 128, + moe_top_k: int = 8, + moe_capacity_factor: float = 1.25, + moe_min_capacity: int = 4, + moe_overflow_policy: str = "residual", + moe_router_noise_std: float = 0.0, + moe_router_temperature: float = 1.0, + moe_router_dropout: float = 0.0, + moe_aux_loss_coef: float = 0.001, + moe_z_loss_coef: float = 0.0, + ): + super().__init__() + self.pre_attn_norm = RMSNorm(d_model, eps=norm_eps) + self.attn = GroupedQueryAttention( + d_model=d_model, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + max_seq_len=max_seq_len, + rope_base=rope_base, + norm_eps=norm_eps, + attn_dropout=attn_dropout, + ) + self.pre_ffn_norm = RMSNorm(d_model, eps=norm_eps) + self.ffn = MoEBlock( + d_model=d_model, + moe_d_ff=moe_d_ff, + moe_num_experts=moe_num_experts, + moe_top_k=moe_top_k, + moe_capacity_factor=moe_capacity_factor, + moe_min_capacity=moe_min_capacity, + moe_overflow_policy=moe_overflow_policy, + moe_router_noise_std=moe_router_noise_std, + moe_router_temperature=moe_router_temperature, + moe_router_dropout=moe_router_dropout, + moe_aux_loss_coef=moe_aux_loss_coef, + moe_z_loss_coef=moe_z_loss_coef, + ffn_dropout=ffn_dropout, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.pre_attn_norm(x)) + x = x + self.ffn(self.pre_ffn_norm(x)) + return x + + @property + def aux_loss(self) -> Optional[torch.Tensor]: + return getattr(self.ffn, "last_aux_loss", None) + + +class QwenModel(NNModel): + def __init__( + self, + vocab_size: int, + max_seq_len: int, + d_model: int, + n_heads: int, + n_kv_heads: int, + d_ff: int, + num_layers: int, + moe_d_ff: int = 768, + sample_key: str = "input_ids", + prediction_key: str = "logits", + attn_dropout: float = 0.0, + ffn_dropout: float = 0.0, + tie_embeddings: bool = False, + norm_eps: float = 1e-6, + rope_base: float = 1000000.0, + moe_num_experts: int = 128, + moe_top_k: int = 8, + moe_capacity_factor: float = 1.25, + moe_min_capacity: int = 4, + moe_overflow_policy: str = "residual", + moe_router_noise_std: float = 0.0, + moe_router_temperature: float = 1.0, + moe_router_dropout: float = 0.0, + moe_aux_loss_coef: float = 0.001, + moe_z_loss_coef: float = 0.0, + ): + weight_decay_groups = { + "linear": ["q_proj", "k_proj", "v_proj", "o_proj", "lm_head", "router", "w1", "w2", "w3"], + "embedding": ["token_emb"], + "layernorm": ["pre_attn_norm", "pre_ffn_norm", "final_norm", "q_norm", "k_norm"], + } + super().__init__(weight_decay_groups=weight_decay_groups) + self.sample_key = sample_key + self.prediction_key = prediction_key + + self.token_emb = nn.Embedding(vocab_size, d_model) + + self.layers = nn.ModuleDict( + { + str(i): TransformerBlock( + d_model=d_model, + d_ff=d_ff, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + max_seq_len=max_seq_len, + rope_base=rope_base, + norm_eps=norm_eps, + attn_dropout=attn_dropout, + ffn_dropout=ffn_dropout, + moe_d_ff=moe_d_ff, + moe_num_experts=moe_num_experts, + moe_top_k=moe_top_k, + moe_capacity_factor=moe_capacity_factor, + moe_min_capacity=moe_min_capacity, + moe_overflow_policy=moe_overflow_policy, + moe_router_noise_std=moe_router_noise_std, + moe_router_temperature=moe_router_temperature, + moe_router_dropout=moe_router_dropout, + moe_aux_loss_coef=moe_aux_loss_coef, + moe_z_loss_coef=moe_z_loss_coef, + ) + for i in range(num_layers) + } + ) + + self.final_norm = RMSNorm(d_model, eps=norm_eps) + self.lm_head = nn.Linear(d_model, vocab_size, bias=False) + + if tie_embeddings: + self.lm_head.weight = self.token_emb.weight + + @overload + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + ... + + @overload + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + ... + + def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: + if isinstance(inputs, dict): + return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} + return self.forward_impl(inputs) + + def forward_impl(self, input_ids: torch.Tensor) -> torch.Tensor: + x = self.token_emb(input_ids) + for layer in self.layers.values(): + x = layer(x) + return self.lm_head(self.final_norm(x)) diff --git a/src/modalities/models/parallelism/expert_parallelism.py b/src/modalities/models/parallelism/expert_parallelism.py new file mode 100644 index 000000000..eaa33937c --- /dev/null +++ b/src/modalities/models/parallelism/expert_parallelism.py @@ -0,0 +1,139 @@ +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. + +import torch +import torch.nn as nn +from torch import Tensor +from torch.distributed._functional_collectives import all_to_all_single, all_to_all_single_autograd +from torch.distributed.tensor import DeviceMesh, Shard, distribute_module, distribute_tensor + + +def _permute_tokens( + x: Tensor, + num_tokens_per_expert_group: Tensor, + ep_degree: int, + num_local_experts: int, +) -> tuple[tuple, Tensor, Tensor, Tensor]: + """ + Reorder tokens from the post-all-to-all layout to per-local-expert contiguous layout. + + After the all-to-all, received tokens are ordered as: + [e0_from_rank0 tokens, e1_from_rank0 tokens, ..., e0_from_rank1 tokens, ...] + + We reorder to: + [all tokens for local_expert_0, all tokens for local_expert_1, ...] + + Returns (original_shape, x_permuted, permuted_indices, new_num_tokens_per_expert). + """ + counts = num_tokens_per_expert_group.view(ep_degree, num_local_experts) # (ep_degree, num_local_experts) + + flat_counts = counts.flatten() # length = ep_degree * num_local_experts + + offsets = flat_counts.cumsum(0) - flat_counts + + # build permuted_indices + indices_per_expert: list[Tensor] = [] + for e in range(num_local_experts): + for r in range(ep_degree): + count = int(counts[r, e].item()) + if count > 0: + start = int(offsets[r * num_local_experts + e].item()) + indices_per_expert.append(torch.arange(start, start + count, device=x.device, dtype=torch.long)) + + if indices_per_expert: + permuted_indices = torch.cat(indices_per_expert) + else: + permuted_indices = torch.zeros(0, dtype=torch.long, device=x.device) + + new_num_tokens_per_expert = counts.sum(dim=0) # (num_local_experts,) + original_shape = x.shape + x_permuted = x[permuted_indices] if permuted_indices.numel() > 0 else x.new_zeros((0, x.shape[-1])) + return original_shape, x_permuted, permuted_indices, new_num_tokens_per_expert + + +def _unpermute_tokens(out: Tensor, original_shape: tuple, permuted_indices: Tensor) -> Tensor: + """ + Inverse of _permute_tokens: scatter expert outputs back to the all-to-all layout. + """ + out_unpermuted = out.new_zeros(original_shape) + if permuted_indices.numel() > 0: + out_unpermuted[permuted_indices] = out + return out_unpermuted + + +class ExpertParallel: + """ + Expert Parallelism for grouped-expert MoE layers. + + Shards GroupedExperts parameters on the expert dimension (Shard(0)) across EP ranks, + and wraps forward() with all-to-all token dispatch/combine collectives. + + Usage: + module.experts = ExpertParallel()._apply(module.experts, ep_mesh) + """ + + def __init__(self) -> None: + self.input_splits: list[int] | None = None + self.output_splits: list[int] | None = None + self.original_shape: tuple | None = None + self.permuted_indices: Tensor | None = None + + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: + for param_name, param in mod.named_parameters(recurse=False): + mod.register_parameter( + param_name, + nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])), + ) + + def _token_dispatch(self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh) -> tuple[Tensor, Tensor]: + routed_input, num_tokens_per_expert = inputs + ep_degree = device_mesh.shape[0] + num_local_experts = num_tokens_per_expert.shape[0] // ep_degree + + with torch.no_grad(): + num_tokens_per_expert_group = all_to_all_single( + num_tokens_per_expert, None, None, group=device_mesh.get_group() + ) + + num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(num_tokens_per_expert_group) + input_splits = ( + num_tokens_per_expert.view(ep_degree, -1).sum(dim=1).to(torch.device("cpu"), non_blocking=True) + ) + + output_splits = ( + num_tokens_per_expert_group.view(ep_degree, -1).sum(dim=1).to(torch.device("cpu"), non_blocking=False) + ) + self.input_splits = input_splits.tolist() + self.output_splits = output_splits.tolist() + + routed_input = all_to_all_single_autograd( + routed_input, + self.output_splits, + self.input_splits, + device_mesh.get_group(), + ) + + self.original_shape, routed_input, self.permuted_indices, num_tokens_per_expert_group = _permute_tokens( + routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts + ) + return routed_input, num_tokens_per_expert_group + + def _token_combine(self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh) -> Tensor: + routed_output = _unpermute_tokens(routed_output, self.original_shape, self.permuted_indices) + routed_output = all_to_all_single_autograd( + routed_output, + self.input_splits, + self.output_splits, + device_mesh.get_group(), + ) + return routed_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) diff --git a/src/modalities/optimizers/ep_adamw.py b/src/modalities/optimizers/ep_adamw.py new file mode 100644 index 000000000..006f9faf9 --- /dev/null +++ b/src/modalities/optimizers/ep_adamw.py @@ -0,0 +1,154 @@ +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.nn import Module +from torch.optim import AdamW, Optimizer + +from modalities.optimizers.optimizer_factory import _build_optimizer_groups_via_weight_decay_split + + +def _get_ep_param_ids(model: Module) -> set: + return {id(p) for m in model.modules() if getattr(m, "_ep_enabled", False) for p in m.parameters(recurse=False)} + + +def _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_groups_excluded): + weight_decay_groups = model.weight_decay_groups + params = { + name: p + for name, p in model.named_parameters() + if p.requires_grad and id(p) not in ep_param_ids and (not isinstance(p, DTensor) or p.to_local().numel() > 0) + } + return _build_optimizer_groups_via_weight_decay_split( + weight_decay, weight_decay_groups_excluded, weight_decay_groups, params + ) + + +class EPAdamW(Optimizer): + """ + ZeRO stage-1 for EP (DTensor) params + standard AdamW for dense params. + + Each dp_shard rank stores optimizer states for 1/dp_shard of the EP params. + After each step, updated EP param values are broadcast from owner to all ranks. + Dense params are handled by a separate AdamW (FSDP2 shards them independently). + """ + + def __init__( + self, + model: Module, + device_mesh, + lr: float, + betas: tuple[float, float], + eps: float, + weight_decay: float, + weight_decay_groups_excluded: list[str], + ): + self._dp_mesh = device_mesh["dp_shard"] + self._dp_group = self._dp_mesh.get_group() + self._dp_rank = dist.get_rank(self._dp_group) + self._dp_size = dist.get_world_size(self._dp_group) + + ep_param_ids = _get_ep_param_ids(model) + self._all_ep_params = [p for p in model.parameters() if id(p) in ep_param_ids] + + # rank r owns params[r::dp_size] + self._owned_ep_params = self._all_ep_params[self._dp_rank :: self._dp_size] + + dense_groups = _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_groups_excluded) + + if self._owned_ep_params: + self._ep_adamw = AdamW(self._owned_ep_params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + else: + self._ep_adamw = None + self._dense_adamw = AdamW(dense_groups, lr=lr, betas=betas, eps=eps) + + # unified param groups for lr_scheduler compatibility: + # group 0 = all EP params, groups 1+ = dense weight-decay split + ep_group = {"params": self._all_ep_params, "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay} + all_groups = [ep_group] + [{**g, "lr": lr, "betas": betas, "eps": eps} for g in dense_groups] + super().__init__(all_groups, {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # all-reduce + for p in self._all_ep_params: + if p.grad is None: + continue + if isinstance(p.grad, DTensor): + local_g = p.grad.to_local() + dist.all_reduce(local_g, op=dist.ReduceOp.SUM, group=self._dp_group) + local_g.div_(self._dp_size) + else: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=self._dp_group) + p.grad.div_(self._dp_size) + + # Sync lr + if self._ep_adamw is not None: + self._ep_adamw.param_groups[0]["lr"] = self.param_groups[0]["lr"] + for i, group in enumerate(self._dense_adamw.param_groups): + group["lr"] = self.param_groups[i + 1]["lr"] + + # Update ep params + if self._ep_adamw is not None: + self._ep_adamw.step() + + # Update dense params + self._dense_adamw.step() + + # broadcast updated EP param local tensors + for i, p in enumerate(self._all_ep_params): + owner_local_rank = i % self._dp_size + owner_global_rank = dist.get_global_rank(self._dp_group, owner_local_rank) + if isinstance(p, DTensor): + local_tensor = p.to_local() + elif isinstance(p.data, DTensor): + local_tensor = p.data.to_local() + else: + local_tensor = p.data + dist.broadcast(local_tensor, src=owner_global_rank, group=self._dp_group) + + return loss + + def zero_grad(self, set_to_none: bool = True): + for p in self._all_ep_params: + if set_to_none: + p.grad = None + elif p.grad is not None: + p.grad.detach_() + p.grad.zero_() + self._dense_adamw.zero_grad(set_to_none=set_to_none) + + def state_dict(self) -> dict: + return { + "ep_adamw": self._ep_adamw.state_dict() if self._ep_adamw is not None else {}, + "dense_adamw": self._dense_adamw.state_dict(), + } + + def load_state_dict(self, state_dict: dict) -> None: + if self._ep_adamw is not None and state_dict["ep_adamw"]: + self._ep_adamw.load_state_dict(state_dict["ep_adamw"]) + self._dense_adamw.load_state_dict(state_dict["dense_adamw"]) + + +def get_ep_adam_w( + wrapped_model, + device_mesh, + lr: float, + betas: tuple[float, float], + eps: float, + weight_decay: float, + weight_decay_groups_excluded: list[str], +) -> EPAdamW: + return EPAdamW( + model=wrapped_model, + device_mesh=device_mesh, + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + weight_decay_groups_excluded=weight_decay_groups_excluded, + ) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 26df9b432..71eb2c8ad 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -36,6 +36,8 @@ DummyLRSchedulerConfig, DummyProgressSubscriberConfig, DummyResultSubscriberConfig, + EPAdamWConfig, + EPWrappedModelConfig, EvaluationResultToDiscSubscriberConfig, FSDP1ActivationCheckpointedModelConfig, FSDP1CheckpointedModelConfig, @@ -51,6 +53,7 @@ LinearWarmupCosineAnnealingLRSchedulerConfig, LLMDataLoaderConfig, MemMapDatasetConfig, + MoECrossEntropyLossConfig, OneCycleLRSchedulerConfig, PackedMemMapDatasetContinuousConfig, PackedMemMapDatasetMegatronConfig, @@ -96,6 +99,9 @@ from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory +from modalities.models.moe.loss_functions import MoECrossEntropyLoss +from modalities.models.moe.model_factory import get_ep_wrapped_model +from modalities.models.moe.qwen_model import QwenModel, QwenModelConfig from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory from modalities.models.parallelism.pipeline_parallelism_configs import ( ComponentSelectorFromPipelineConfig, @@ -109,12 +115,14 @@ ComposedInitializationRoutines, ComposedModelInitializationConfig, ) +from modalities.optimizers.ep_adamw import get_ep_adam_w from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory from modalities.optimizers.optimizer_factory import OptimizerFactory from modalities.optimizers.optimizer_list import OptimizersList from modalities.optimizers.scheduler_list import SchedulerList from modalities.running_env.fsdp.device_mesh import DeviceMeshConfig, get_device_mesh, get_parallel_degree from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer, PreTrainedSPTokenizer +from modalities.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper from modalities.training.gradient_clipping.fsdp_gradient_clipper import ( FSDP1GradientClipper, FSDP1LoggingOnlyGradientClipper, @@ -187,6 +195,8 @@ class ComponentEntity: COMPONENTS = [ # models ComponentEntity("model", "gpt2", GPT2ModelFactory.get_gpt2_model, GPT2LLMConfig), + ComponentEntity("model", "moe", QwenModel, QwenModelConfig), + ComponentEntity("model", "ep_wrapped", get_ep_wrapped_model, EPWrappedModelConfig), ComponentEntity( "model", "gpt2_tp", maybe_model_list(GPT2ModelFactory.get_gpt2_tensor_parallelized_model), GPT2ModelTPConfig ), @@ -250,6 +260,7 @@ class ComponentEntity: ), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), + ComponentEntity("loss", "moe_cross_entropy", MoECrossEntropyLoss, MoECrossEntropyLossConfig), # optimizers ComponentEntity( "optimizer", "adam", maybe_model_list_for_optimizer(OptimizerFactory.get_adam), AdamOptimizerConfig @@ -257,6 +268,7 @@ class ComponentEntity: ComponentEntity( "optimizer", "adam_w", maybe_model_list_for_optimizer(OptimizerFactory.get_adam_w), AdamWOptimizerConfig ), + ComponentEntity("optimizer", "ep_adam_w", maybe_model_list_for_optimizer(get_ep_adam_w), EPAdamWConfig), ComponentEntity( "optimizer", "fsdp1_checkpointed", @@ -402,6 +414,7 @@ class ComponentEntity: "gradient_clipper", "fsdp1_logging_only", FSDP1LoggingOnlyGradientClipper, FSDP1DummyGradientClipperConfig ), ComponentEntity("gradient_clipper", "fsdp2", FSDP2GradientClipper, FSDP2GradientClipperConfig), + ComponentEntity("gradient_clipper", "ep", EPGradientClipper, FSDP2GradientClipperConfig), ComponentEntity( "gradient_clipper", "fsdp2_logging_only", FSDP2LoggingOnlyGradientClipper, FSDP2DummyGradientClipperConfig ), diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index cd456938c..f4f3b7e26 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -21,6 +21,7 @@ class DeviceMeshConfig(BaseModel): tensor_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1 pipeline_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1 context_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1 + expert_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1 enable_loss_parallel: Optional[bool] = False world_size: Annotated[int, Field(strict=True, gt=0)] @@ -28,6 +29,7 @@ class DeviceMeshConfig(BaseModel): def _validate(self): for d in ( self.context_parallel_degree, + self.expert_parallel_degree, self.tensor_parallel_degree, self.pipeline_parallel_degree, ): @@ -50,6 +52,7 @@ def _validate(self): self.data_parallel_shard_degree = self.world_size // ( self.data_parallel_replicate_degree * self.context_parallel_degree + * self.expert_parallel_degree * self.tensor_parallel_degree * self.pipeline_parallel_degree ) @@ -58,12 +61,14 @@ def _validate(self): self.data_parallel_replicate_degree = self.world_size // ( self.data_parallel_shard_degree * self.context_parallel_degree + * self.expert_parallel_degree * self.tensor_parallel_degree * self.pipeline_parallel_degree ) if ( self.data_parallel_shard_degree * self.data_parallel_replicate_degree + * self.expert_parallel_degree * self.tensor_parallel_degree * self.pipeline_parallel_degree * self.context_parallel_degree @@ -72,6 +77,7 @@ def _validate(self): raise ConfigError( f"Invalid parallel dims: data_parallel_shard_degree({self.data_parallel_shard_degree}) * " f"data_parallel_replicate_degree({self.data_parallel_replicate_degree}) * " + f"expert_parallel_degree({self.expert_parallel_degree}) * " f"tensor_parallel_degree({self.tensor_parallel_degree}) *" f"* pipeline_parallel_degree({self.pipeline_parallel_degree}) *" f"context_parallel_degree({self.context_parallel_degree})!= WORLD_SIZE({self.world_size})" @@ -85,6 +91,7 @@ class ParallelismDegrees(Enum): DP_REPLICATE = "dp_replicate" DP_SHARD = "dp_shard" CP = "cp" + EP = "ep" TP = "tp" PP = "pp" @@ -96,6 +103,7 @@ def get_device_mesh( tensor_parallel_degree: int, pipeline_parallel_degree: int, context_parallel_degree: int, + expert_parallel_degree: int, enable_loss_parallel: bool, world_size: int, ) -> DeviceMesh: @@ -109,6 +117,7 @@ def get_device_mesh( tensor_parallel_degree (int): The tensor parallel degree. pipeline_parallel_degree (int): The pipeline parallel degree. context_parallel_degree (int): The context parallel degree. + expert_parallel_degree (int): The expert parallel degree. enable_loss_parallel (bool): Whether to enable loss parallelism. world_size (int): The world size. @@ -123,6 +132,7 @@ def get_device_mesh( data_parallel_replicate_degree, data_parallel_shard_degree, context_parallel_degree, + expert_parallel_degree, tensor_parallel_degree, ], [ @@ -130,6 +140,7 @@ def get_device_mesh( ParallelismDegrees.DP_REPLICATE.value, ParallelismDegrees.DP_SHARD.value, ParallelismDegrees.CP.value, + ParallelismDegrees.EP.value, ParallelismDegrees.TP.value, ], strict=True, diff --git a/src/modalities/training/gradient_clipping/ep_gradient_clipper.py b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py new file mode 100644 index 000000000..2efc5ed58 --- /dev/null +++ b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py @@ -0,0 +1,89 @@ +import math +from typing import Optional + +import torch +from torch import distributed as dist +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import FSDPModule as FSDP2 +from torch.distributed.tensor import DTensor + +from modalities.running_env.fsdp.device_mesh import ( + ParallelismDegrees, + get_mesh_for_parallelism_method, + has_parallelism_method, +) +from modalities.training.gradient_clipping.fsdp_gradient_clipper import FSDP2GradientClipper, GradientClippingMode + + +class EPGradientClipper(FSDP2GradientClipper): + """FSDP2 clipper wrapper that handles EP DTensor gradients safely.""" + + def __init__( + self, + model_parts: FSDP2 | list[FSDP2], + max_norm: float, + norm_type: GradientClippingMode, + device_mesh: Optional[DeviceMesh] = None, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, + ) -> None: + super().__init__( + model_parts=model_parts, + max_norm=max_norm, + norm_type=norm_type, + device_mesh=device_mesh, + error_if_nonfinite=error_if_nonfinite, + foreach=foreach, + ) + + @torch.no_grad() + def clip_gradients(self) -> torch.Tensor: + grads = [p.grad for model in self.models for p in model.parameters() if p.grad is not None] + + if len(grads) == 0: + device = ( + torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu") + ) + total_norm = torch.tensor(0.0, device=device) + else: + norm_type_val = self.norm_type.value + first_grad = grads[0] + first_device = first_grad.to_local().device if isinstance(first_grad, DTensor) else first_grad.device + norm_scalars: list[torch.Tensor] = [] + + for grad in grads: + grad_norm = torch.linalg.vector_norm(grad, ord=norm_type_val) + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.full_tensor() + norm_scalars.append(grad_norm.to(first_device)) + + if math.isinf(norm_type_val): + total_norm = torch.max(torch.stack(norm_scalars)) + else: + total_norm = torch.linalg.vector_norm(torch.stack(norm_scalars), ord=norm_type_val) + + if self.error_if_nonfinite and (torch.isnan(total_norm) or torch.isinf(total_norm)): + raise RuntimeError( + f"The total norm of order {norm_type_val} for gradients is non-finite: {total_norm.item()}" + ) + + if has_parallelism_method(self.device_mesh, ParallelismDegrees.PP): + pp_mesh = get_mesh_for_parallelism_method( + device_mesh=self.device_mesh, parallelism_method=ParallelismDegrees.PP + ) + if math.isinf(self.norm_type.value): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= self.norm_type.value + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / self.norm_type.value + + # do not use torch.nn.utils.clip_grads_with_norm_ here: it batches grads with + # torch._foreach_mul_, which fails when the list mixes DTensors from different meshes. + clip_coef = self.max_norm / (total_norm + 1e-6) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + + for grad in grads: + grad_device = grad.to_local().device if isinstance(grad, DTensor) else grad.device + grad.mul_(clip_coef_clamped.to(grad_device)) + return total_norm diff --git a/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py b/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py new file mode 100644 index 000000000..70e7203e9 --- /dev/null +++ b/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py @@ -0,0 +1,142 @@ +import logging +import multiprocessing as py_mp +import os +import traceback +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.multiprocessing as mp + +from modalities.__main__ import Main, load_app_config_dict +from modalities.batch import EvaluationResultBatch +from modalities.config.config import ProcessGroupBackendType +from modalities.config.instantiation_models import TrainingComponentsInstantiationModel +from modalities.logging_broker.messages import Message +from tests.end2end_tests.custom_components import ( + MultiProcessingCudaEnv, + SaveAllResultSubscriber, + SaveAllResultSubscriberConfig, +) +from tests.utility import find_free_port, monitor_child_processes + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="This E2E test requires 4 CUDA devices.") +class TestMoEEPFSDP2E2E: + @staticmethod + def _patch_for_short_test_run(config_dict: dict[str, Any], checkpoint_root_path: Path) -> None: + # Keep runtime short while preserving EP + FSDP2 wiring. + config_dict["settings"]["intervals"]["training_log_interval_in_steps"] = 1 + config_dict["settings"]["intervals"]["checkpointing_interval_in_steps"] = 1 + config_dict["settings"]["intervals"]["evaluation_interval_in_steps"] = 1000 + + config_dict["settings"]["step_profile"]["sequence_length"] = 64 + config_dict["settings"]["step_profile"]["local_train_micro_batch_size"] = 1 + config_dict["settings"]["step_profile"]["gradient_accumulation_steps"] = 1 + + config_dict["settings"]["training_target"]["num_target_tokens"] = 512 + config_dict["settings"]["training_target"]["num_target_steps"] = 2 + config_dict["lr_scheduler"]["config"]["total_steps"] = 2 + + config_dict["train_dataset"]["config"]["sequence_length"] = 64 + config_dict["test_dataset"]["config"]["sequence_length"] = 64 + config_dict["train_dataloader"]["config"]["num_workers"] = 0 + config_dict["test_dataloader"]["config"]["num_workers"] = 0 + config_dict["train_dataloader"]["config"]["pin_memory"] = False + config_dict["test_dataloader"]["config"]["pin_memory"] = False + + config_dict["settings"]["paths"]["checkpoint_saving_path"] = checkpoint_root_path + config_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][ + "checkpoint_path" + ] = checkpoint_root_path + + @staticmethod + def _worker_wrapper( + process_id: int, + world_size: int, + rdvz_port: int, + config_file_path: Path, + tmp_path: Path, + error_queue: Any, + ) -> None: + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=rdvz_port, + ): + try: + TestMoEEPFSDP2E2E._worker_impl( + process_id=process_id, + config_file_path=config_file_path, + tmp_path=tmp_path, + ) + except Exception as exc: + tb = traceback.format_exc() + logging.error(f"Process {process_id} failed: {exc}\n{tb}") + try: + error_queue.put((process_id, tb)) + except Exception: + logging.error("Failed to write child exception to queue.") + os._exit(1) + + @staticmethod + def _worker_impl(process_id: int, config_file_path: Path, tmp_path: Path) -> None: + experiment_id = "moe-ep-fsdp2-e2e" + checkpoint_root_path = tmp_path / experiment_id / "checkpoints" + cfg = load_app_config_dict( + config_file_path=config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id + ) + TestMoEEPFSDP2E2E._patch_for_short_test_run(cfg, checkpoint_root_path) + + main_obj = Main(config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id) + main_obj.config_dict = cfg + main_obj.add_custom_component( + component_key="results_subscriber", + variant_key="save_all", + custom_component=SaveAllResultSubscriber, + custom_config=SaveAllResultSubscriberConfig, + ) + main_obj.config_dict["evaluation_subscriber"]["variant_key"] = "save_all" + main_obj.config_dict["evaluation_subscriber"]["config"] = {} + + components: TrainingComponentsInstantiationModel = main_obj.build_components( + components_model_type=TrainingComponentsInstantiationModel + ) + + assert getattr(components.model_raw, "_ep_wrapped", False), "Expected EP wrapping marker on raw model." + first_layer = next(iter(components.model_raw.layers.values())) + assert getattr(first_layer.ffn.experts, "_ep_enabled", False), "Expected experts to be EP-enabled." + + main_obj.run(components) + + result_messages: list[Message[EvaluationResultBatch]] = components.evaluation_subscriber.message_list + assert len(result_messages) > 0, "Expected training messages in evaluation subscriber." + for message in result_messages: + loss_value = message.payload.losses["train loss avg"].value + assert torch.isfinite(loss_value), f"Found non-finite train loss: {loss_value}" + + if process_id == 0: + checkpoint_info_file_path = checkpoint_root_path / "last_checkpoint_info.json" + assert checkpoint_info_file_path.exists(), "Expected checkpoint info file from DCP save." + + @staticmethod + def test_moe_ep_fsdp2_training_and_checkpointing(tmp_path: Path) -> None: + repo_root = Path(__file__).resolve().parents[2] + config_file_path = repo_root / "config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml" + + world_size = 4 + rdvz_port = find_free_port() + + manager = py_mp.Manager() + error_queue = manager.Queue() + proc_ctx = mp.spawn( + TestMoEEPFSDP2E2E._worker_wrapper, + args=(world_size, rdvz_port, config_file_path, tmp_path, error_queue), + nprocs=world_size, + join=False, + ) + + monitor_child_processes(manager, error_queue, proc_ctx) diff --git a/tests/models/moe/__init__.py b/tests/models/moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/moe/test_loss_functions.py b/tests/models/moe/test_loss_functions.py new file mode 100644 index 000000000..346b69818 --- /dev/null +++ b/tests/models/moe/test_loss_functions.py @@ -0,0 +1,59 @@ +import torch +from torch.nn import CrossEntropyLoss + +from modalities.batch import InferenceResultBatch +from modalities.models.moe.loss_functions import MoECrossEntropyLoss + + +class DummyLayer: + def __init__(self, aux_loss): + self.aux_loss = aux_loss + + +class DummyModel: + def __init__(self, aux_losses: list[torch.Tensor | None]): + self.layers = {str(i): DummyLayer(aux) for i, aux in enumerate(aux_losses)} + + +def test_moe_cross_entropy_loss_adds_aux_losses(): + logits = torch.tensor( + [ + [[1.2, 0.3, -0.5], [0.1, 1.8, -0.3]], + [[0.5, -0.4, 1.1], [0.7, 0.2, -0.1]], + ], + dtype=torch.float32, + ) + targets = torch.tensor([[0, 1], [2, 0]], dtype=torch.long) + + batch = InferenceResultBatch( + targets={"targets": targets}, + predictions={"logits": logits}, + ) + + aux_1 = torch.tensor(0.2) + aux_2 = torch.tensor(0.3) + model = DummyModel(aux_losses=[aux_1, None, aux_2]) + loss_fn = MoECrossEntropyLoss(target_key="targets", prediction_key="logits", model=model) + + loss = loss_fn(batch) + base_ce = CrossEntropyLoss(reduction="mean")(logits.view(-1, logits.size(-1)), targets.view(-1)) + + assert torch.allclose(loss, base_ce + aux_1 + aux_2) + + +def test_moe_cross_entropy_loss_without_aux_matches_plain_ce(): + logits = torch.randn(2, 3, 5) + targets = torch.randint(0, 5, (2, 3), dtype=torch.long) + + batch = InferenceResultBatch( + targets={"labels": targets}, + predictions={"pred": logits}, + ) + + model = DummyModel(aux_losses=[None, None]) + loss_fn = MoECrossEntropyLoss(target_key="labels", prediction_key="pred", model=model) + + loss = loss_fn(batch) + expected = CrossEntropyLoss(reduction="mean")(logits.view(-1, logits.size(-1)), targets.view(-1)) + + assert torch.allclose(loss, expected) diff --git a/tests/models/moe/test_qwen_model.py b/tests/models/moe/test_qwen_model.py new file mode 100644 index 000000000..d4d90b592 --- /dev/null +++ b/tests/models/moe/test_qwen_model.py @@ -0,0 +1,60 @@ +import torch + +from modalities.models.moe.qwen_model import GroupedExperts, QwenModel + + +def _build_tiny_qwen_model() -> QwenModel: + return QwenModel( + vocab_size=32, + max_seq_len=16, + d_model=16, + n_heads=4, + n_kv_heads=2, + d_ff=32, + num_layers=1, + moe_d_ff=24, + moe_num_experts=4, + moe_top_k=2, + moe_capacity_factor=1.25, + moe_min_capacity=1, + moe_overflow_policy="residual", + moe_aux_loss_coef=0.01, + moe_z_loss_coef=0.0, + ) + + +def test_qwen_model_forward_dict_output_shape(): + torch.manual_seed(0) + model = _build_tiny_qwen_model() + batch_size, seq_len = 2, 5 + + input_ids = torch.randint(0, 32, (batch_size, seq_len), dtype=torch.long) + output = model({"input_ids": input_ids}) + + assert "logits" in output + assert output["logits"].shape == (batch_size, seq_len, 32) + + +def test_grouped_experts_forward_local_preserves_input_dtype(): + experts = GroupedExperts(num_experts=2, d_model=8, d_ff=12, ffn_dropout=0.0) + experts.reset_parameters() + + # Input in bf16 while expert weights are initialized in fp32. + routed_input = torch.randn(4, 8, dtype=torch.bfloat16) + num_tokens_per_expert = torch.tensor([2, 2], dtype=torch.long) + + out = experts._forward_local(routed_input=routed_input, num_tokens_per_expert=num_tokens_per_expert) + + assert out.shape == routed_input.shape + assert out.dtype == routed_input.dtype + + +def test_transformer_block_exposes_aux_loss_after_forward(): + torch.manual_seed(1) + model = _build_tiny_qwen_model() + input_ids = torch.randint(0, 32, (2, 4), dtype=torch.long) + + _ = model({"input_ids": input_ids}) + + first_layer = next(iter(model.layers.values())) + assert first_layer.aux_loss is not None diff --git a/tests/optimizers/test_ep_adamw.py b/tests/optimizers/test_ep_adamw.py new file mode 100644 index 000000000..bb366627b --- /dev/null +++ b/tests/optimizers/test_ep_adamw.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn + +from modalities.models.model import NNModel +from modalities.optimizers.ep_adamw import EPAdamW + + +class DummyDPShardMesh: + def __init__(self): + self._group = object() + + def get_group(self): + return self._group + + +class EPSubmodule(nn.Module): + def __init__(self): + super().__init__() + self.ep_weight = nn.Parameter(torch.tensor([1.0, -1.0])) + self._ep_enabled = True + + +class TinyModel(NNModel): + def __init__(self): + super().__init__(weight_decay_groups={"linear": ["linear"], "embedding": [], "layernorm": ["norm"]}) + self.linear = nn.Linear(2, 2, bias=False) + self.norm = nn.LayerNorm(2) + self.experts = EPSubmodule() + + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + x = inputs["x"] + return {"y": self.linear(x)} + + +def _patch_distributed_ops(monkeypatch): + from modalities.optimizers import ep_adamw as ep_adamw_module + + monkeypatch.setattr(ep_adamw_module.dist, "get_rank", lambda group=None: 0) + monkeypatch.setattr(ep_adamw_module.dist, "get_world_size", lambda group=None: 1) + monkeypatch.setattr(ep_adamw_module.dist, "all_reduce", lambda tensor, op=None, group=None: tensor) + monkeypatch.setattr(ep_adamw_module.dist, "broadcast", lambda tensor, src=0, group=None: tensor) + monkeypatch.setattr(ep_adamw_module.dist, "get_global_rank", lambda group, group_rank: group_rank) + + +def test_ep_adamw_state_dict_and_load_state_dict(monkeypatch): + _patch_distributed_ops(monkeypatch) + + model = TinyModel() + optimizer = EPAdamW( + model=model, + device_mesh={"dp_shard": DummyDPShardMesh()}, + lr=1e-2, + betas=(0.9, 0.95), + eps=1e-8, + weight_decay=0.1, + weight_decay_groups_excluded=["layernorm"], + ) + + state = optimizer.state_dict() + assert "ep_adamw" in state + assert "dense_adamw" in state + + optimizer.load_state_dict(state) + + +def test_ep_adamw_step_updates_parameters_and_zero_grad(monkeypatch): + _patch_distributed_ops(monkeypatch) + + model = TinyModel() + optimizer = EPAdamW( + model=model, + device_mesh={"dp_shard": DummyDPShardMesh()}, + lr=1e-2, + betas=(0.9, 0.95), + eps=1e-8, + weight_decay=0.1, + weight_decay_groups_excluded=["layernorm"], + ) + + before = [p.detach().clone() for p in model.parameters()] + for p in model.parameters(): + p.grad = torch.ones_like(p) + + optimizer.step() + after = list(model.parameters()) + + for p_before, p_after in zip(before, after): + assert not torch.allclose(p_before, p_after) + + optimizer.zero_grad(set_to_none=True) + for p in model.parameters(): + assert p.grad is None diff --git a/tests/training/gradient_clipping/test_ep_gradient_clipper.py b/tests/training/gradient_clipping/test_ep_gradient_clipper.py new file mode 100644 index 000000000..322ece24f --- /dev/null +++ b/tests/training/gradient_clipping/test_ep_gradient_clipper.py @@ -0,0 +1,50 @@ +import pytest +import torch +import torch.nn as nn + +from modalities.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper +from modalities.training.gradient_clipping.fsdp_gradient_clipper import GradientClippingMode + + +class MockModel(nn.Module): + def __init__(self): + super().__init__() + self.param1 = nn.Parameter(torch.tensor([1.0, 2.0])) + self.param2 = nn.Parameter(torch.tensor([3.0, 4.0])) + + +def test_ep_gradient_clipper_clips_gradients(): + model = MockModel() + model.param1.grad = torch.tensor([1.0, 1.0]) + model.param2.grad = torch.tensor([1.0, 1.0]) + + clipper = EPGradientClipper(model_parts=model, max_norm=1.0, norm_type=GradientClippingMode.P2_NORM) + total_norm = clipper.clip_gradients() + + assert torch.allclose(total_norm, torch.tensor(2.0)) + assert torch.allclose(model.param1.grad, torch.tensor([0.5, 0.5]), atol=1e-6) + assert torch.allclose(model.param2.grad, torch.tensor([0.5, 0.5]), atol=1e-6) + + +def test_ep_gradient_clipper_returns_zero_for_no_gradients(): + model = MockModel() + + clipper = EPGradientClipper(model_parts=model, max_norm=1.0, norm_type=GradientClippingMode.P2_NORM) + total_norm = clipper.clip_gradients() + + assert torch.allclose(total_norm.cpu(), torch.tensor(0.0)) + + +def test_ep_gradient_clipper_raises_for_nonfinite_norm(): + model = MockModel() + model.param1.grad = torch.tensor([float("nan"), 1.0]) + + clipper = EPGradientClipper( + model_parts=model, + max_norm=1.0, + norm_type=GradientClippingMode.P2_NORM, + error_if_nonfinite=True, + ) + + with pytest.raises(RuntimeError, match="non-finite"): + clipper.clip_gradients()