From 8c972a1c902675480a5554f802b40c9ecc260aee Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Jun 2026 09:05:01 +0000 Subject: [PATCH 01/12] feat: Add moe model --- moe/config/moe_ep_config.yaml | 357 ++++++++++++ moe/config/qwen_config.yaml | 365 ++++++++++++ moe/config/tokenization_config.yaml | 18 + moe/modalities_moe/__init__.py | 0 moe/modalities_moe/config/__init__.py | 0 moe/modalities_moe/config/config.py | 22 + moe/modalities_moe/loss_functions.py | 39 ++ moe/modalities_moe/models/__init__.py | 0 moe/modalities_moe/models/model_factory.py | 143 +++++ moe/modalities_moe/models/moe/__init__.py | 0 moe/modalities_moe/models/moe/moe_model.py | 537 ++++++++++++++++++ moe/modalities_moe/models/moe/qwen_model.py | 501 ++++++++++++++++ moe/modalities_moe/optimizers/__init__.py | 0 moe/modalities_moe/optimizers/ep_adamw.py | 169 ++++++ moe/modalities_moe/training/__init__.py | 0 .../training/gradient_clipping/__init__.py | 0 .../gradient_clipping/ep_gradient_clipper.py | 90 +++ moe/scripts/monitor_gpus.sh | 60 ++ moe/scripts/train_ep.py | 195 +++++++ src/modalities/models/model_factory.py | 10 +- 20 files changed, 2505 insertions(+), 1 deletion(-) create mode 100644 moe/config/moe_ep_config.yaml create mode 100644 moe/config/qwen_config.yaml create mode 100644 moe/config/tokenization_config.yaml create mode 100644 moe/modalities_moe/__init__.py create mode 100644 moe/modalities_moe/config/__init__.py create mode 100644 moe/modalities_moe/config/config.py create mode 100644 moe/modalities_moe/loss_functions.py create mode 100644 moe/modalities_moe/models/__init__.py create mode 100644 moe/modalities_moe/models/model_factory.py create mode 100644 moe/modalities_moe/models/moe/__init__.py create mode 100644 moe/modalities_moe/models/moe/moe_model.py create mode 100644 moe/modalities_moe/models/moe/qwen_model.py create mode 100644 moe/modalities_moe/optimizers/__init__.py create mode 100644 moe/modalities_moe/optimizers/ep_adamw.py create mode 100644 
moe/modalities_moe/training/__init__.py create mode 100644 moe/modalities_moe/training/gradient_clipping/__init__.py create mode 100644 moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py create mode 100755 moe/scripts/monitor_gpus.sh create mode 100644 moe/scripts/train_ep.py diff --git a/moe/config/moe_ep_config.yaml b/moe/config/moe_ep_config.yaml new file mode 100644 index 000000000..883e0dffb --- /dev/null +++ b/moe/config/moe_ep_config.yaml @@ -0,0 +1,357 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + experiments_root_path: /leonardo/home/userexternal/gesposit/projects/modalities/moe/experiments + experiment_folder_path: ${settings.paths.experiments_root_path}/${settings.experiment_id} + checkpoint_saving_path: /leonardo_scratch/large/userexternal/gesposit/modalities/checkpoints + train_dataset_path: /leonardo_scratch/large/userexternal/gesposit/modalities/data/processed/fineweb_edu_num_docs_483606.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 1001 + evaluation_interval_in_steps: 1001 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 4 + local_train_micro_batch_size: 1 + sequence_length: 512 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_num_steps + config: + num_steps: ${settings.training_target.num_target_steps} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: 
${settings.step_profile.local_train_micro_batch_size} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: 10 + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + # we set num_workers to 0 so that the the data is loaded in the main process + # this is required to track how often the collator has been called + # in the library tutorials. Otherwise the collator will be copied for each worker + # and the number of call is out of scope. 
+ num_workers: 0 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: [] + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.experiment_folder_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: moe_cross_entropy + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + model: + instance_key: model_raw + pass_type: BY_REFERENCE + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + # Keep FSDP sharding on dp_shard and reserve tp for expert parallel. 
+ data_parallel_shard_degree: -1 + tensor_parallel_degree: 32 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: # get the parallel degree from the device mesh + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.num_layers} + +ep_model: + component_key: model + variant_key: ep_wrapped + config: + model: + instance_key: model_raw # Bypass torch.compile - MoE routing is incompatible + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + ep_mesh_dim_name: tp + block_names: [TransformerBlock] + +ac_model: + component_key: model + variant_key: activation_checkpointed # using modalities fsdp2 ac. 
should do to job also for ep layers + config: + model: + instance_key: ep_model + pass_type: BY_REFERENCE + ac_variant: full_activation_checkpointing + layers_fqn: layers + ac_fun_params: + ac_freq: 1 + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: ac_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + reshard_after_forward: true + block_names: [TransformerBlock] + +compiled_model: + component_key: model + variant_key: compiled + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + block_names: [TransformerBlock] + +model_raw: + component_key: model + variant_key: moe + config: + vocab_size: 32064 # to match a pretrained tokenizer + max_seq_len: 4096 + d_model: 4096 + n_heads: 32 + n_kv_heads: 8 + num_layers: 32 + d_ff: 14336 + moe_every_n_layers: 1 + moe_num_experts: 16 + moe_top_k: 2 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.02 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: ep + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: 
${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: to_disc + config: + output_file_path: ${settings.paths.experiment_folder_path}/evaluation_results.jsonl + +mfu_calculator: + component_key: mfu_calculator + variant_key: gpt2 + config: + n_layer: ${model_raw.config.num_layers} + sequence_length: ${settings.step_profile.sequence_length} + n_embd: ${model_raw.config.d_model} + world_size: ${settings.cuda_env.world_size} + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +# profiler: +# component_key: steppable_profiler +# variant_key: combined +# config: +# profilers: +# - instance_key: kernel_profiler +# pass_type: BY_REFERENCE +# # - instance_key: memory_profiler +# # pass_type: BY_REFERENCE + +kernel_profiler: + component_key: steppable_profiler + variant_key: kernel_tracing + config: + num_wait_steps: 1 + num_warmup_steps: 1 + num_active_steps: 3 + profiler_activities: [CUDA] + profile_memory: true + record_shapes: true + with_stack: true + with_flops: true + with_modules: true + tracked_ranks: [0] + output_folder_path: ${settings.paths.experiment_folder_path}/profiling + +memory_profiler: + component_key: steppable_profiler + variant_key: memory_tracing + config: + memory_snapshot_folder_path: ${settings.paths.experiment_folder_path}/profiling + num_wait_steps: 1 + num_warmup_steps: 1 + num_active_steps: 3 + tracked_ranks: [0] \ No newline at end of file diff --git a/moe/config/qwen_config.yaml b/moe/config/qwen_config.yaml new file mode 100644 index 000000000..46b233dec --- /dev/null +++ b/moe/config/qwen_config.yaml @@ -0,0 +1,365 @@ 
+settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + experiments_root_path: /raid/s3/opengptx/user/richard-rutmann/experiments/modalities/moe_fsdp2 + experiment_folder_path: ${settings.paths.experiments_root_path}/${settings.experiment_id} + checkpoint_saving_path: /raid/s3/opengptx/user/richard-rutmann/experiments/modalities/moe_fsdp2/checkpoints + train_dataset_path: /raid/s3/opengptx/user/richard-rutmann/data/modalities/gpt2_tokenized/000_00000.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 1001 + evaluation_interval_in_steps: 1001 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 4 + local_train_micro_batch_size: 2 + sequence_length: 4096 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_num_steps + config: + num_steps: ${settings.training_target.num_target_steps} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: 10 + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: 
${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + # we set num_workers to 0 so that the the data is loaded in the main process + # this is required to track how often the collator has been called + # in the library tutorials. Otherwise the collator will be copied for each worker + # and the number of call is out of scope. + num_workers: 0 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: [] + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.experiment_folder_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: 
moe_cross_entropy + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + model: + instance_key: model_raw + pass_type: BY_REFERENCE + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + # Keep FSDP sharding on dp_shard and reserve tp for expert parallel. + data_parallel_shard_degree: -1 + tensor_parallel_degree: 4 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: # get the parallel degree from the device mesh + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.num_layers} + +ep_model: + component_key: model + variant_key: ep_wrapped + config: + model: + instance_key: model_raw # Bypass torch.compile - MoE routing is incompatible + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + ep_mesh_dim_name: tp + block_names: [TransformerBlock] + +ac_model: + component_key: model + variant_key: activation_checkpointed # using modalities fsdp2 ac. 
should do to job also for ep layers + config: + model: + instance_key: ep_model + pass_type: BY_REFERENCE + ac_variant: full_activation_checkpointing + layers_fqn: layers + ac_fun_params: + ac_freq: 1 + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: ac_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + reshard_after_forward: true + block_names: [TransformerBlock] + +compiled_model: + component_key: model + variant_key: compiled + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + block_names: [TransformerBlock] + +model_raw: + component_key: model + variant_key: moe + config: + vocab_size: 50257 # to match a pretrained tokenizer, tochange + max_seq_len: 4096 + d_model: 2048 + d_ff: 6144 + n_heads: 32 + n_kv_heads: 8 + num_layers: 8 + attn_dropout: 0.0 + ffn_dropout: 0.0 + tie_embeddings: false + norm_eps: 1e-06 + rope_base: 1000000.0 + moe_num_experts: 128 + moe_d_ff: 768 + moe_top_k: 8 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.02 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: ep_adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: ep + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + 
device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: to_disc + config: + output_file_path: ${settings.paths.experiment_folder_path}/evaluation_results.jsonl + +mfu_calculator: + component_key: mfu_calculator + variant_key: gpt2 + config: + n_layer: ${model_raw.config.num_layers} + sequence_length: ${settings.step_profile.sequence_length} + n_embd: ${model_raw.config.d_model} + world_size: ${settings.cuda_env.world_size} + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +# profiler: +# component_key: steppable_profiler +# variant_key: combined +# config: +# profilers: +# - instance_key: kernel_profiler +# pass_type: BY_REFERENCE +# # - instance_key: memory_profiler +# # pass_type: BY_REFERENCE + +kernel_profiler: + component_key: steppable_profiler + variant_key: kernel_tracing + config: + num_wait_steps: 1 + num_warmup_steps: 1 + num_active_steps: 3 + profiler_activities: [CUDA] + profile_memory: true + record_shapes: true + with_stack: true + with_flops: true + with_modules: true + tracked_ranks: [0] + output_folder_path: ${settings.paths.experiment_folder_path}/profiling + +memory_profiler: + component_key: steppable_profiler + variant_key: memory_tracing + config: + memory_snapshot_folder_path: ${settings.paths.experiment_folder_path}/profiling + num_wait_steps: 1 + num_warmup_steps: 1 + num_active_steps: 3 + tracked_ranks: [0] \ No newline at end of file diff --git 
class MoECrossEntropyLossConfig(BaseModel):
    """Validated constructor arguments for ``MoECrossEntropyLoss``.

    The field names match the keyword arguments of
    ``MoECrossEntropyLoss.__init__`` one-to-one.
    """

    # Key under which the target token ids are looked up in the forward batch.
    target_key: str
    # Key under which the model predictions (logits) are looked up.
    prediction_key: str
    # The model whose ``layers`` are scanned for ``aux_loss`` values by the loss.
    # Typed as ``Any``, so pydantic performs no validation on it.
    model: Any
    # Name handed to the ``Loss`` base class.
    tag: str = "MoECrossEntropyLoss"

    class Config:
        # Accept field values of types pydantic has no built-in validator for.
        arbitrary_types_allowed = True


class EPWrappedModelConfig(BaseModel):
    """Validated arguments for ``get_ep_wrapped_model``.

    The fields correspond to the ``model``, ``block_names``, ``device_mesh``
    and ``ep_mesh_dim_name`` parameters of that function.
    """

    # Model (or list of models, as far as the pydantic type allows) to wrap.
    model: PydanticPytorchModuleOrListType
    # Class names of the blocks whose experts should be expert-parallelised.
    block_names: list[str]
    # Mesh from which the expert-parallel (sub-)mesh is taken.
    device_mesh: PydanticDeviceMeshIFType
    # Name of the mesh dimension used for expert parallelism; ``None`` leaves
    # the choice to ``_resolve_ep_mesh``.
    ep_mesh_dim_name: str | None = None
+import torch +from torch.nn import CrossEntropyLoss + +from modalities.batch import InferenceResultBatch +from modalities.loss_functions import Loss + + +class MoECrossEntropyLoss(Loss): + """Cross Entropy Loss with auxiliary loss support for router balancing""" + + def __init__( + self, + target_key: str, + prediction_key: str, + model, + tag: str = "MoECrossEntropyLoss", + ): + super().__init__(tag) + self.target_key = target_key + self.prediction_key = prediction_key + self.model = model + self.loss_fun = CrossEntropyLoss(reduction="mean") + + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: + labels = forward_batch.get_targets(self.target_key) + lm_logits = forward_batch.get_predictions(self.prediction_key) + + labels = labels.to(lm_logits.device) + loss = self.loss_fun( + lm_logits.contiguous().view(-1, lm_logits.size(-1)), + labels.contiguous().long().view(-1), + ) + + # Aux loss + for layer in self.model.layers.values(): + if hasattr(layer, "aux_loss") and layer.aux_loss is not None: + loss = loss + layer.aux_loss.to(loss.dtype) + + return loss diff --git a/moe/modalities_moe/models/__init__.py b/moe/modalities_moe/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/moe/modalities_moe/models/model_factory.py b/moe/modalities_moe/models/model_factory.py new file mode 100644 index 000000000..65dbe8e3f --- /dev/null +++ b/moe/modalities_moe/models/model_factory.py @@ -0,0 +1,143 @@ +import warnings + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable.fsdp import MixedPrecisionPolicy +from torch.distributed.device_mesh import DeviceMesh +from torchtitan.distributed.expert_parallel import ExpertParallel + +from modalities.util import get_module_class_from_name + + +# TODO refactor these funtions into a utils +def _resolve_ep_mesh( + device_mesh: DeviceMesh, ep_mesh_dim_name: str | None +) -> DeviceMesh: # devicemesh not supporting EP + mesh_dim_names = 
tuple(device_mesh.mesh_dim_names or ()) + + if ep_mesh_dim_name is not None: + if ep_mesh_dim_name not in mesh_dim_names: + raise ValueError(f"ep_mesh_dim_name='{ep_mesh_dim_name}' not in mesh_dim_names={mesh_dim_names}") + return device_mesh[ep_mesh_dim_name] + + if len(mesh_dim_names) <= 1: + return device_mesh + + raise ValueError( + "DeviceMesh has multiple dimensions. Pass ep_mesh_dim_name explicitly. " + f"Available dimensions: {mesh_dim_names}" + ) + + +def _validate_moe_block_for_ep(module) -> None: + if not hasattr(module, "experts"): + raise ValueError(f"Module {type(module).__name__} has no 'experts' attribute") + + experts = module.experts + required_attrs = ["w1", "w2"] + missing = [attr for attr in required_attrs if not hasattr(experts, attr)] + if missing: + raise ValueError( + f"Module {type(module).__name__}.experts is not grouped-experts compatible. Missing: {missing}" + ) + + if experts.w1.ndim != 3 or experts.w2.ndim != 3: + raise ValueError( + f"Expected grouped expert parameters with ndim=3. 
Got w1.ndim={experts.w1.ndim}, " + f"w2.ndim={experts.w2.ndim}" + ) + + +def _get_ep_target_module(module): + if hasattr(module, "experts"): + return module + + ffn = getattr(module, "ffn", None) + if ffn is not None and hasattr(ffn, "experts"): + return ffn + + return None + + +def _attach_ep_metadata(module, ep_mesh) -> None: + setattr(module, "_ep_mesh", ep_mesh) + setattr(module, "_ep_group", ep_mesh.get_group()) + setattr(module, "_ep_size", ep_mesh.size()) + setattr(module, "_ep_rank", ep_mesh.get_local_rank()) + + +def _apply_torchtitan_ep(module, ep_mesh) -> None: + module.experts = ExpertParallel()._apply(module.experts, ep_mesh) + setattr(module.experts, "_ep_enabled", True) + + +def debug_forward_hook(module, input): + for name, param in module.named_parameters(recurse=False): + if hasattr(param, "_local_tensor"): + # still dTensor + print(f"[EP forward] {name}: still DTensor, local={param._local_tensor.shape}") + else: + print(f"[EP forward] {name}: plain tensor shape={param.shape}") + + +def get_ep_wrapped_model( + model, + block_names: list[str], + device_mesh: DeviceMesh, + ep_mesh_dim_name: str | None = None, + mp_param_dtype=torch.bfloat16, + mp_reduce_dtype=torch.bfloat16, +) -> nn.Module: + # Warn for unresolved names, but still wrap any block types that can be resolved. 
+ block_types = [] + missing_block_names = [] + for name in block_names: + block_type = get_module_class_from_name(model, name) + if block_type is None: + missing_block_names.append(name) + else: + block_types.append(block_type) + + if len(missing_block_names) > 0 and (not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0): + warnings.warn( + f"Could not resolve some requested MoE block names and they will be ignored: {missing_block_names}", + stacklevel=2, + ) + + block_types = tuple(block_types) + + if len(block_types) == 0: + raise ValueError(f"None of the requested MoE block names were found: {block_names}") + + ep_mesh = _resolve_ep_mesh(device_mesh, ep_mesh_dim_name) + device_mesh["dp_shard"] + MixedPrecisionPolicy(param_dtype=mp_param_dtype, reduce_dtype=mp_reduce_dtype) + + wrapped_blocks = 0 + for module in model.modules(): + if isinstance(module, block_types): + ep_target_module = _get_ep_target_module(module) + if ep_target_module is None: + raise ValueError( + f"Module {type(module).__name__} has no EP-compatible experts location. " + "Expected `experts` or `ffn.experts`." 
+ ) + + if getattr(ep_target_module, "_ep_enabled", False): + continue + + _validate_moe_block_for_ep(ep_target_module) + _attach_ep_metadata(ep_target_module, ep_mesh) + _apply_torchtitan_ep(ep_target_module, ep_mesh) + + wrapped_blocks += 1 + + if wrapped_blocks == 0: + raise ValueError(f"No blocks matched the requested types: {[t.__name__ for t in block_types]}") + + setattr(model, "_ep_wrapped", True) + setattr(model, "_ep_mesh", ep_mesh) + setattr(model, "_ep_num_wrapped_blocks", wrapped_blocks) + + return model diff --git a/moe/modalities_moe/models/moe/__init__.py b/moe/modalities_moe/models/moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/moe/modalities_moe/models/moe/moe_model.py b/moe/modalities_moe/models/moe/moe_model.py new file mode 100644 index 000000000..a1412d410 --- /dev/null +++ b/moe/modalities_moe/models/moe/moe_model.py @@ -0,0 +1,537 @@ +import math +from dataclasses import dataclass +from typing import Literal, Optional, overload + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pydantic import BaseModel + +# TODO reolve this import +try: + from torch.distributed.tensor import DTensor +except Exception: + DTensor = None + + +class MoEModelConfig(BaseModel): + # model config + vocab_size: int + max_seq_len: int + d_model: int + n_heads: int + n_kv_heads: int + num_layers: int + d_ff: int + sample_key: str = "input_ids" + prediction_key: str = "logits" + attn_dropout: float = 0.0 + ffn_dropout: float = 0.0 + tie_embeddings: bool = False + moe_every_n_layers: int = 1 + moe_num_experts: int = 8 + moe_top_k: int = 2 + moe_capacity_factor: float = 1.25 + moe_aux_loss_coef: float = 0.01 + moe_z_loss_coef: float = 0.0 + moe_router_noise_std: float = 0.0 + + +@dataclass +class MoEArguments: + # Model hyperparameters + d_model: int + d_ff: int + + # MoE hyperparameters + num_experts: int + top_k: int + capacity_factor: float = 1.25 + min_capacity: int = 4 + overflow_policy: Literal["drop", 
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        d_model: Size of the last dimension of the inputs; also the size of
            the ``weight`` vector.
        eps: Constant added to the mean square before taking the square root.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``.

        The statistics are computed in at least float32 and cast back to the
        dtype of ``x`` before the gain is applied. In float16 ``eps=1e-8``
        rounds to zero, so an all-zero row would otherwise produce 0/0 = NaN,
        and in bfloat16 the squared values lose most of their precision.
        float32 and float64 inputs are processed exactly as before.
        """
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_wide = x.to(compute_dtype)
        norm_x = x_wide / torch.sqrt(x_wide.pow(2).mean(-1, keepdim=True) + self.eps)
        return norm_x.to(x.dtype) * self.weight


class Expert(nn.Module):
    """Single gated feed-forward expert: ``w3(silu(w1(x)) * w2(x))`` followed by dropout.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size of the two parallel projections.
        dropout: Dropout probability applied to the output; values ``<= 0``
            disable dropout (an ``nn.Identity`` is used instead).
    """

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        # Creation order is kept so parameter initialisation consumes the RNG
        # in the same sequence and state-dict ordering is unchanged.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_model, d_ff)
        self.w3 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the gated projection and return a tensor with ``d_model`` features."""
        gate = F.silu(self.w1(x))
        hidden = gate * self.w2(x)
        return self.dropout(self.w3(hidden))
bound_w1) + + nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5)) + bound_w2 = 1 / math.sqrt(self.d_model) + nn.init.uniform_(self.b2, -bound_w2, bound_w2) + + nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5)) + bound_w3 = 1 / math.sqrt(self.d_ff) + nn.init.uniform_(self.b3, -bound_w3, bound_w3) + + def _forward_local(self, routed_input, num_tokens_per_expert) -> torch.Tensor: + outputs: list[torch.Tensor] = [] + start = 0 + + # ExpertParallel may convert parameters to DTensor. Local expert compute + # expects plain tensors, so we materialize local shards when needed. + w1 = self.w1.to_local() if DTensor is not None and isinstance(self.w1, DTensor) else self.w1 + b1 = self.b1.to_local() if DTensor is not None and isinstance(self.b1, DTensor) else self.b1 + w2 = self.w2.to_local() if DTensor is not None and isinstance(self.w2, DTensor) else self.w2 + b2 = self.b2.to_local() if DTensor is not None and isinstance(self.b2, DTensor) else self.b2 + w3 = self.w3.to_local() if DTensor is not None and isinstance(self.w3, DTensor) else self.w3 + b3 = self.b3.to_local() if DTensor is not None and isinstance(self.b3, DTensor) else self.b3 + + local_num_tokens = ( + num_tokens_per_expert.to_local() + if DTensor is not None and isinstance(num_tokens_per_expert, DTensor) + else num_tokens_per_expert + ) + + total_rows = routed_input.shape[0] + for expert_idx, num_tokens in enumerate(local_num_tokens.tolist()): + requested_tokens = int(num_tokens) + end = start + requested_tokens + + # EP alignment can request padded tokens; only a subset may exist in routed_input. 
+ local_end = min(end, total_rows) + expert_input = routed_input[start:local_end] + real_tokens = int(expert_input.shape[0]) + + out_parts: list[torch.Tensor] = [] + if real_tokens > 0: + x1 = torch.nn.functional.linear(expert_input, w1[expert_idx], b1[expert_idx]) + x2 = torch.nn.functional.linear(expert_input, w2[expert_idx], b2[expert_idx]) + hidden = torch.nn.functional.silu(x1) * x2 + out_real = torch.nn.functional.linear(hidden, w3[expert_idx], b3[expert_idx]) + out_parts.append(self.dropout(out_real)) + + pad_tokens = requested_tokens - real_tokens + if pad_tokens > 0: + out_parts.append(routed_input.new_zeros((pad_tokens, self.d_model))) + + if len(out_parts) > 0: + outputs.append(torch.cat(out_parts, dim=0) if len(out_parts) > 1 else out_parts[0]) + + start = end + + if len(outputs) == 0: + return routed_input.new_zeros((0, self.d_model)) + + out = torch.cat(outputs, dim=0) + + # EP permute may append extra global padding slots beyond per-expert aligned sizes. + # output_fn(_unpermute) expects the same row count as routed_input. 
+ if out.shape[0] < total_rows: + out = torch.cat( + [out, routed_input.new_zeros((total_rows - out.shape[0], self.d_model))], + dim=0, + ) + elif out.shape[0] > total_rows: + out = out[:total_rows] + + return out + + def forward(self, routed_input, num_tokens_per_expert) -> torch.Tensor: + # routed_input: (M, D), sorted/grouped by expert id + # num_tokens_per_expert: (E_local,) for local compute, or global counts before EP input_fn + return self._forward_local(routed_input, num_tokens_per_expert) + + +class MoEBlock(nn.Module): + def __init__(self, config: MoEArguments): + super(MoEBlock, self).__init__() + self.config = config + self.num_experts = config.num_experts + self.router = nn.Linear(config.d_model, self.num_experts) + self.router_dropout = nn.Dropout(config.router_dropout) if config.router_dropout > 0 else nn.Identity() + self.experts = GroupedExperts(config) + + self.last_aux_loss: Optional[torch.Tensor] = None + + def forward(self, x): + B, T, D = x.size() + E = self.config.num_experts + K = self.config.top_k + N = B * T + + x_flat = x.view(N, D) + + # Router logits + logits = self.router(self.router_dropout(x_flat)) # (N, E) + if self.config.router_noise_std > 0 and self.training: + noise = torch.randn_like(logits) * self.config.router_noise_std + logits = logits + noise + logits = logits / self.config.router_temperature + probs = torch.softmax(logits, dim=-1) # (N, E) + + # top-k + topk_val, topk_idx = torch.topk(probs, k=K, dim=-1) # (N, K) + topk_w = topk_val / (topk_val.sum(dim=-1, keepdim=True) + 1e-9) # (N, K) + + # capacity per expert + capacity = math.ceil(self.config.capacity_factor * N / E) + capacity = max(capacity, self.config.min_capacity) + + # dispatch mask - preserve dtype of input + dispatch_mask = torch.nn.functional.one_hot(topk_idx, num_classes=E).to(x_flat.dtype) # (N, K, E) + + # token assignment + expert_mask = dispatch_mask.sum(dim=1) # (N, E) + positions = torch.cumsum(expert_mask, dim=0) # (N, E) + capacity_mask = (positions 
<= capacity).to(x_flat.dtype) # (N, E) + final_mask = dispatch_mask * capacity_mask.unsqueeze(1) # (N, K, E) + combine_weights = final_mask * topk_w.unsqueeze(-1) # (N, K, E) + + combine_weights.sum(dim=1) # (N, E) + + # count actual assignments per expert + load = final_mask.sum(dim=[0, 1]) # (E,) + importance = probs.sum(dim=0) # (E,) + + # Build routed token stream. + valid_mask = capacity_mask.gather(1, topk_idx).bool() # (N, K) + token_ids = torch.arange(N, device=x.device).unsqueeze(1).expand(N, K) + + flat_valid = valid_mask.reshape(-1) + flat_token_ids = token_ids.reshape(-1)[flat_valid] + flat_expert_ids = topk_idx.reshape(-1)[flat_valid] + flat_weights = topk_w.reshape(-1)[flat_valid] + + if flat_expert_ids.numel() > 0: + sort_idx = torch.argsort(flat_expert_ids) + token_ids_sorted = flat_token_ids[sort_idx] + expert_ids_sorted = flat_expert_ids[sort_idx] + weights_sorted = flat_weights[sort_idx] + + routed_input = x_flat[token_ids_sorted] + num_tokens_per_expert = torch.bincount(expert_ids_sorted, minlength=E) + + routed_output = self.experts(routed_input, num_tokens_per_expert) + weighted_output = routed_output * weights_sorted.unsqueeze(-1) + + out = x_flat.new_zeros((N, D)) + out.index_add_(0, token_ids_sorted, weighted_output) + + assigned = x_flat.new_zeros((N,)) + assigned.index_add_(0, token_ids_sorted, weights_sorted) + else: + out = x_flat.new_zeros((N, D)) + assigned = x_flat.new_zeros((N,)) + + # Overflow handling: tokens not assigned to any expert + not_assigned = assigned < 1e-6 + + if not_assigned.any(): + if self.config.overflow_policy == "residual": + out[not_assigned] = x_flat[not_assigned] + # if 'drop', out is already zero for those positions + + # auxiliary loss + aux = None + if self.config.aux_loss_coef > 0: + imp = importance / (importance.sum() + 1e-9) + ld = load / (load.sum() + 1e-9) + lb = E * torch.sum(imp * ld) + aux = self.config.aux_loss_coef * lb + + if self.config.z_loss_coef > 0: + z = torch.logsumexp(logits, dim=-1) + 
class GroupedQueryAttention(nn.Module):
    """Multi-head attention in which groups of query heads share one key/value head.

    The call signature and the ``(output, None)`` return value mirror the way
    ``nn.MultiheadAttention`` is called with batch-first inputs in this file.
    """

    def __init__(self, d_model, num_heads, num_kv_heads):
        """Create the projections.

        Args:
            d_model: Width of the input and output features.
            num_heads: Number of query heads.
            num_kv_heads: Number of key/value heads; must divide ``num_heads``.

        Raises:
            ValueError: If ``num_heads`` is not a multiple of ``num_kv_heads``.
        """
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
        self.d_model = d_model
        self.n_heads = num_heads
        self.n_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(num_heads * self.head_dim, d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: ``(B, Tq, d_model)``.
            key: ``(B, Tk, d_model)``.
            value: ``(B, Tk, d_model)``.
            mask: Optional additive mask broadcastable to ``(B, n_heads, Tq, Tk)``
                (``0`` keeps a position, ``-inf`` hides it).

        Returns:
            ``(output, None)`` where ``output`` is ``(B, Tq, d_model)``.
        """
        batch = query.size(0)
        # (B, T, heads * head_dim) -> (B, heads, T, head_dim)
        q = self.q_proj(query).view(batch, -1, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(batch, -1, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(batch, -1, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Query head h reads key/value head h // n_rep.
        n_rep = self.n_heads // self.n_kv_heads
        if n_rep > 1:
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        # Scores over positions: (B, heads, Tq, Tk)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
        if mask is not None:
            attn_scores = attn_scores + mask
        attn_weights = F.softmax(attn_scores, dim=-1)

        attn_output = torch.matmul(attn_weights, v)  # (B, heads, Tq, head_dim)
        # Bring heads next to head_dim again before merging them.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, -1, self.n_heads * self.head_dim)
        return self.out_proj(attn_output), None


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention followed by an MoE or dense FFN."""

    def __init__(self, d_model, d_ff, num_heads, num_kv_heads, moe_config: "Optional[MoEArguments]"):
        """Build the block.

        Args:
            d_model: Model width.
            d_ff: Hidden width of the dense feed-forward (used when ``moe_config`` is None).
            num_heads: Number of attention heads.
            num_kv_heads: Stored as ``n_kv_heads``; the attention module used here
                does not consume it.
            moe_config: MoE settings; ``None`` selects a dense ``Expert`` feed-forward.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = num_heads
        self.n_kv_heads = num_kv_heads
        self.attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.pre_attn_norm = RMSNorm(d_model)
        self.pre_ffn_norm = RMSNorm(d_model)

        if moe_config is not None:
            self.ffn = MoEBlock(moe_config)
            self.is_moe = True
        else:
            self.ffn = Expert(d_model, d_ff)
            self.is_moe = False

    def forward(self, x):
        """Apply attention and feed-forward sub-layers with residual connections.

        Args:
            x: ``(B, T, d_model)`` hidden states.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x_norm = self.pre_attn_norm(x)

        # Boolean mask where True means "may not attend": every position is
        # prevented from seeing later positions, as next-token prediction requires.
        seq_len = x.size(1)
        causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(diagonal=1)
        # The averaged attention weights were discarded anyway, so skip computing them.
        attn_output, _ = self.attention(x_norm, x_norm, x_norm, attn_mask=causal_mask, need_weights=False)
        x = x + attn_output

        # Pre-MoE norm
        x_norm = self.pre_ffn_norm(x)
        moe_output = self.ffn(x_norm)
        x = x + moe_output

        return x

    @property  # TODO: AUX LOSS IN FORWARD
    def aux_loss(self):
        """Auxiliary loss recorded by the MoE feed-forward in its last forward, else None."""
        if self.is_moe and hasattr(self.ffn, "last_aux_loss"):
            return self.ffn.last_aux_loss
        return None
num_experts=self.moe_num_experts, + top_k=self.moe_top_k, + capacity_factor=self.moe_capacity_factor, + aux_loss_coef=self.moe_aux_loss_coef, + z_loss_coef=self.moe_z_loss_coef, + router_noise_std=self.moe_router_noise_std, + dropout=self.ffn_dropout, + ) + + self.layers = nn.ModuleDict() + for i in range(self.num_layers): + if i % self.moe_every_n_layers == 0: + block = TransformerBlock(self.d_model, self.d_ff, self.n_heads, self.n_kv_heads, moe_config) + else: + block = TransformerBlock(self.d_model, self.d_ff, self.n_heads, self.n_kv_heads, None) # No MoE + self.layers[str(i)] = block + self.final_norm = RMSNorm(self.d_model) + self.lm_head = nn.Linear(self.d_model, self.vocab_size, bias=False) + if self.tie_embeddings: + self.lm_head.weight = self.token_emb.weight + + @property + def weight_decay_groups(self): + return { + "linear": ["attention", "router", "w1", "w2", "w3", "b1", "b2", "b3", "lm_head"], + "embedding": ["token_emb", "pos_emb"], + "layernorm": ["pre_attn_norm", "pre_ffn_norm", "final_norm"], + } + + @overload + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Forward pass of the MoE module. + + Args: + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. + - sample_key (str): Key for the input tensor containing token ids. + + Returns: + dict[str, torch.Tensor]: A dictionary containing output tensors. + - prediction_key (str): Key for the output tensor containing logits. + """ + ... + + @overload + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the module. + + Args: + inputs (torch.Tensor): A tensor containing input token ids. + + Returns: + torch.Tensor: A tensor containing output logits. + """ + ... + + def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: + """ + Forward pass of the module. + + Args: + inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. 
+ + Returns: + dict[str, torch.Tensor] | torch.Tensor: Model output. + """ + if isinstance(inputs, dict): + return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} + else: + return self.forward_impl(inputs) + + def forward_impl(self, input_ids: torch.Tensor) -> torch.Tensor: + B, T = input_ids.size() + assert T <= self.max_seq_len, f"Sequence length {T} exceeds model's max_seq_len {self.max_seq_len}" + device = input_ids.device + + # Token and position embeddings + token_embeddings = self.token_emb(input_ids) # (B, T, D) + positions = torch.arange(T, device=device).unsqueeze(0).expand(B, T) + pos_embeddings = self.pos_emb(positions) # (B, T, D) + x = token_embeddings + pos_embeddings # (B, T, D) + + # Transformer blocks + for i, layer in enumerate(self.layers.values()): + x = layer(x) + + x = self.final_norm(x) + logits = self.lm_head(x) # (B, T, vocab_size) + + return logits + + +if __name__ == "__main__": # sanity test + torch.manual_seed(0) + + model = MoEModel( + vocab_size=32064, + max_seq_len=32768, + d_model=4096, + n_heads=32, + n_kv_heads=8, + num_layers=32, + d_ff=14336, + moe_every_n_layers=1, + moe_num_experts=8, + moe_top_k=2, + ) + + # Print number of trainable parameters + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Number of trainable parameters: {num_params:,}") + + x = torch.randint(0, model.vocab_size, (2, 64)) + logits = model(x) + + print("logits:", logits.shape) + loss = logits.mean() + loss.backward() + print("backward OK") diff --git a/moe/modalities_moe/models/moe/qwen_model.py b/moe/modalities_moe/models/moe/qwen_model.py new file mode 100644 index 000000000..3a5ec2d61 --- /dev/null +++ b/moe/modalities_moe/models/moe/qwen_model.py @@ -0,0 +1,501 @@ +import math +from typing import Literal, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pydantic import BaseModel + +try: + from torch.distributed.tensor import DTensor +except Exception: + 
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super(RMSNorm, self).__init__()
        self.eps = eps
        # Per-feature scale, starting at the identity.
        self.weight = nn.Parameter(torch.ones(d_model))

    def reset_parameters(self):
        """Restore the scale to all ones."""
        nn.init.constant_(self.weight, 1.0)

    def forward(self, x):
        """Divide ``x`` by the RMS of its last dimension, then apply the scale."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        normalized = x / rms
        return normalized * self.weight
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the (new) first half."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat([-back, front], dim=-1)


def apply_rotary_emb(q, k, cos, sin):
    """Apply the rotary transform ``t * cos + rotate_half(t) * sin`` to ``q`` and ``k``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated
return self.o_proj(out.transpose(1, 2).contiguous().view(B, T, -1)) + + +class GroupedExperts(nn.Module): + def __init__( + self, + num_experts, + d_model, + d_ff, + ffn_dropout, + ): + super().__init__() + self.num_experts = num_experts + self.d_model = d_model + self.d_ff = d_ff + self.dropout = nn.Dropout(ffn_dropout) if ffn_dropout > 0 else nn.Identity() + + self.w1 = nn.Parameter(torch.empty(self.num_experts, self.d_ff, self.d_model)) + self.w2 = nn.Parameter(torch.empty(self.num_experts, self.d_ff, self.d_model)) + self.w3 = nn.Parameter(torch.empty(self.num_experts, self.d_model, self.d_ff)) + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5)) + + def _forward_local(self, routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + w1 = self.w1.to_local() if DTensor is not None and isinstance(self.w1, DTensor) else self.w1 + w2 = self.w2.to_local() if DTensor is not None and isinstance(self.w2, DTensor) else self.w2 + w3 = self.w3.to_local() if DTensor is not None and isinstance(self.w3, DTensor) else self.w3 + local_num_tokens = ( + num_tokens_per_expert.to_local() + if DTensor is not None and isinstance(num_tokens_per_expert, DTensor) + else num_tokens_per_expert + ) + + outputs: list[torch.Tensor] = [] + start = 0 + total_rows = routed_input.shape[0] + + for expert_idx, num_tokens in enumerate(local_num_tokens.tolist()): + requested_tokens = int(num_tokens) + end = start + requested_tokens + local_end = min(end, total_rows) + expert_input = routed_input[start:local_end] + real_tokens = int(expert_input.shape[0]) + + out_parts: list[torch.Tensor] = [] + if real_tokens > 0: + x1 = F.linear(expert_input, w1[expert_idx]) + x2 = F.linear(expert_input, w2[expert_idx]) + out_parts.append(self.dropout(F.linear(F.silu(x1) * x2, w3[expert_idx]))) + + pad = requested_tokens - real_tokens + if pad > 0: + 
out_parts.append(routed_input.new_zeros((pad, self.d_model))) + + if out_parts: + outputs.append(torch.cat(out_parts, dim=0) if len(out_parts) > 1 else out_parts[0]) + + start = end + + if not outputs: + return routed_input.new_zeros((0, self.d_model)) + + out = torch.cat(outputs, dim=0) + if out.shape[0] < total_rows: + out = torch.cat([out, routed_input.new_zeros((total_rows - out.shape[0], self.d_model))], dim=0) + elif out.shape[0] > total_rows: + out = out[:total_rows] + return out + + def forward(self, routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + return self._forward_local(routed_input, num_tokens_per_expert) + + +class MoEBlock(nn.Module): + def __init__( + self, + d_model: int, + moe_d_ff: int, + moe_num_experts: int, + moe_top_k: int, + moe_capacity_factor: float, + moe_min_capacity: int, + moe_overflow_policy: str, + moe_router_noise_std: float, + moe_router_temperature: float, + moe_router_dropout: float, + moe_aux_loss_coef: float, + moe_z_loss_coef: float, + ffn_dropout: float, + ): + super().__init__() + self.num_experts = moe_num_experts + self.top_k = moe_top_k + self.capacity_factor = moe_capacity_factor + self.min_capacity = moe_min_capacity + self.overflow_policy = moe_overflow_policy + self.router_noise_std = moe_router_noise_std + self.router_dropout = nn.Dropout(moe_router_dropout) if moe_router_dropout > 0 else nn.Identity() + self.router_temperature = moe_router_temperature + self.aux_loss_coef = moe_aux_loss_coef + self.z_loss_coef = moe_z_loss_coef + + self.router = nn.Linear(d_model, self.num_experts, bias=False) + self.experts = GroupedExperts( + num_experts=moe_num_experts, d_model=d_model, d_ff=moe_d_ff, ffn_dropout=ffn_dropout + ) + self.last_aux_loss: Optional[torch.Tensor] = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, D = x.shape + E = self.num_experts + K = self.top_k + N = B * T + x_flat = x.view(N, D) + + logits = 
self.router(self.router_dropout(x_flat).to(self.router.weight.dtype)).float() + if self.router_noise_std > 0 and self.training: + logits = logits + torch.randn_like(logits) * self.router_noise_std + logits = logits / self.router_temperature + + probs = torch.softmax(logits, dim=-1) + topk_val, topk_idx = torch.topk(probs, k=K, dim=-1) + topk_w = (topk_val / (topk_val.sum(dim=-1, keepdim=True) + 1e-9)).to(x_flat.dtype) + + capacity = max(math.ceil(self.capacity_factor * N / E), self.min_capacity) + + dispatch_mask = F.one_hot(topk_idx, num_classes=E).to(x_flat.dtype) + positions = torch.cumsum(dispatch_mask.sum(dim=1), dim=0) + capacity_mask = (positions <= capacity).to(x_flat.dtype) + final_mask = dispatch_mask * capacity_mask.unsqueeze(1) + + load = final_mask.sum(dim=[0, 1]) + importance = probs.sum(dim=0) + + flat_valid = capacity_mask.gather(1, topk_idx).bool().reshape(-1) + flat_token_ids = torch.arange(N, device=x.device).unsqueeze(1).expand(N, K).reshape(-1)[flat_valid] + flat_expert_ids = topk_idx.reshape(-1)[flat_valid] + flat_weights = topk_w.reshape(-1)[flat_valid] + + if flat_expert_ids.numel() > 0: + sort_idx = torch.argsort(flat_expert_ids) + token_ids_sorted = flat_token_ids[sort_idx] + expert_ids_sorted = flat_expert_ids[sort_idx] + weights_sorted = flat_weights[sort_idx] + + routed_output = self.experts(x_flat[token_ids_sorted], torch.bincount(expert_ids_sorted, minlength=E)) + weighted_output = routed_output * weights_sorted.unsqueeze(-1) + + out = x_flat.new_zeros((N, D)) + out.index_add_(0, token_ids_sorted, weighted_output) + assigned = x_flat.new_zeros((N,)) + assigned.index_add_(0, token_ids_sorted, weights_sorted) + else: + out = x_flat.new_zeros((N, D)) + assigned = x_flat.new_zeros((N,)) + + not_assigned = assigned < 1e-6 + if not_assigned.any() and self.overflow_policy == "residual": + out[not_assigned] = x_flat[not_assigned] + + aux = None + if self.aux_loss_coef > 0: + imp = importance / (importance.sum() + 1e-9) + ld = load / 
class DenseMLP(nn.Module):
    """Bias-free gated feed-forward: ``w3(silu(w1(x)) * w2(x))`` followed by dropout."""

    def __init__(self, d_model, d_ff, ffn_dropout):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        # A non-positive rate means "no dropout at all".
        if ffn_dropout > 0:
            self.dropout = nn.Dropout(ffn_dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up through the gated pair, project back down, then apply dropout."""
        gate = F.silu(self.w1(x))
        up = self.w2(x)
        down = self.w3(gate * up)
        return self.dropout(down)
class QwenModel(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, a final RMSNorm and an LM head."""

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        d_ff: int,
        num_layers: int,
        moe_d_ff: int = 768,
        sample_key: str = "input_ids",
        prediction_key: str = "logits",
        attn_dropout: float = 0.0,
        ffn_dropout: float = 0.0,
        tie_embeddings: bool = False,
        norm_eps: float = 1e-6,
        rope_base: float = 1000000.0,
        moe_num_experts: int = 128,
        moe_top_k: int = 8,
        moe_capacity_factor: float = 1.25,
        moe_min_capacity: int = 4,
        moe_overflow_policy: str = "residual",
        moe_router_noise_std: float = 0.0,
        moe_router_temperature: float = 1.0,
        moe_router_dropout: float = 0.0,
        moe_aux_loss_coef: float = 0.001,
        moe_z_loss_coef: float = 0.0,
    ):
        super().__init__()
        self.sample_key = sample_key
        self.prediction_key = prediction_key

        self.token_emb = nn.Embedding(vocab_size, d_model)

        # Every layer is built from the same settings.
        block_kwargs = dict(
            d_model=d_model,
            d_ff=d_ff,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=max_seq_len,
            rope_base=rope_base,
            norm_eps=norm_eps,
            attn_dropout=attn_dropout,
            ffn_dropout=ffn_dropout,
            moe_d_ff=moe_d_ff,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
            moe_capacity_factor=moe_capacity_factor,
            moe_min_capacity=moe_min_capacity,
            moe_overflow_policy=moe_overflow_policy,
            moe_router_noise_std=moe_router_noise_std,
            moe_router_temperature=moe_router_temperature,
            moe_router_dropout=moe_router_dropout,
            moe_aux_loss_coef=moe_aux_loss_coef,
            moe_z_loss_coef=moe_z_loss_coef,
        )
        # Layers are keyed by their index as a string: "0", "1", ...
        self.layers = nn.ModuleDict()
        for layer_idx in range(num_layers):
            self.layers[str(layer_idx)] = TransformerBlock(**block_kwargs)

        self.final_norm = RMSNorm(d_model, eps=norm_eps)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        if tie_embeddings:
            # Output projection and input embedding share one weight tensor.
            self.lm_head.weight = self.token_emb.weight

    @property
    def weight_decay_groups(self):
        """Map each weight-decay group name to the parameter-name fragments it covers."""
        linear_names = ["q_proj", "k_proj", "v_proj", "o_proj", "lm_head", "router", "w1", "w2", "w3"]
        embedding_names = ["token_emb"]
        norm_names = ["pre_attn_norm", "pre_ffn_norm", "final_norm", "q_norm", "k_norm"]
        return {
            "linear": linear_names,
            "embedding": embedding_names,
            "layernorm": norm_names,
        }

    def forward(self, inputs):
        """Accept a tensor of token ids, or a dict holding it under ``sample_key``.

        A dict input yields a dict with the logits under ``prediction_key``; a
        tensor input yields the logits tensor directly.
        """
        if not isinstance(inputs, dict):
            return self.forward_impl(inputs)
        logits = self.forward_impl(inputs[self.sample_key])
        return {self.prediction_key: logits}

    def forward_impl(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed the ids, run all layers in order, normalize and project to vocabulary logits."""
        hidden = self.token_emb(input_ids)
        for block in self.layers.values():
            hidden = block(hidden)
        normed = self.final_norm(hidden)
        return self.lm_head(normed)
class EPAdamW(Optimizer):
    """
    ZeRO stage-1 for EP (DTensor) params + standard AdamW for dense params.

    Each dp_shard rank stores optimizer states for 1/dp_shard of the EP params.
    After each step, updated EP param values are broadcast from owner to all ranks.
    Dense params are handled by a separate AdamW (FSDP2 shards them independently).

    Internally two ``AdamW`` instances do the actual updates; the ``param_groups``
    exposed by this object exist so that a learning-rate scheduler has something
    to modify, and their ``lr`` values are copied into the inner optimizers at
    the start of every ``step``.
    """

    def __init__(
        self,
        model: Module,
        device_mesh,
        lr: float,
        betas: tuple[float, float],
        eps: float,
        weight_decay: float,
        weight_decay_groups_excluded: list[str],
    ):
        """Split the model's parameters into EP and dense sets and build the inner optimizers.

        Args:
            model: Module whose parameters are optimized. It must expose
                ``weight_decay_groups`` (read by ``_get_dense_optimizer_groups``).
            device_mesh: Mesh indexable with ``"dp_shard"``.
            lr: Initial learning rate for all groups.
            betas: AdamW beta coefficients.
            eps: AdamW epsilon.
            weight_decay: Weight decay applied to the EP params and forwarded to
                the dense weight-decay split.
            weight_decay_groups_excluded: Group names forwarded to the dense
                weight-decay split.
        """
        self._dp_mesh = device_mesh["dp_shard"]
        self._dp_group = self._dp_mesh.get_group()
        self._dp_rank = dist.get_rank(self._dp_group)
        self._dp_size = dist.get_world_size(self._dp_group)

        # EP params are the direct parameters of modules flagged with `_ep_enabled`.
        ep_param_ids = _get_ep_param_ids(model)
        self._all_ep_params = [p for p in model.parameters() if id(p) in ep_param_ids]

        # rank r owns params[r::dp_size]
        self._owned_ep_params = self._all_ep_params[self._dp_rank :: self._dp_size]

        dense_groups = _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_groups_excluded)

        # A rank that owns no EP param keeps no EP optimizer (and hence no EP state).
        if self._owned_ep_params:
            self._ep_adamw = AdamW(self._owned_ep_params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        else:
            self._ep_adamw = None
        self._dense_adamw = AdamW(dense_groups, lr=lr, betas=betas, eps=eps)

        # unified param groups for lr_scheduler compatibility:
        # group 0 = all EP params, groups 1+ = dense weight-decay split
        ep_group = {"params": self._all_ep_params, "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
        all_groups = [ep_group] + [{**g, "lr": lr, "betas": betas, "eps": eps} for g in dense_groups]
        super().__init__(all_groups, {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay})

    @torch.no_grad()
    def step(self, closure=None):
        """Average EP gradients, update owned EP params and dense params, then sync EP params.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so gradients are re-enabled for the closure.
            with torch.enable_grad():
                loss = closure()

        # Average each EP gradient over the dp_shard group (sum, then divide by group size).
        # NOTE(review): params with `grad is None` are skipped, so every rank of the group
        # must agree on which grads exist for the collectives to line up - confirm.
        for p in self._all_ep_params:
            if p.grad is None:
                continue
            if isinstance(p.grad, DTensor):
                # presumably to_local() exposes the shard's storage so the in-place
                # ops below modify p.grad itself - TODO confirm
                local_g = p.grad.to_local()
                dist.all_reduce(local_g, op=dist.ReduceOp.SUM, group=self._dp_group)
                local_g.div_(self._dp_size)
            else:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=self._dp_group)
                p.grad.div_(self._dp_size)

        # Copy the (possibly scheduler-modified) lr from the unified groups into the
        # inner optimizers: group 0 -> EP optimizer, group i + 1 -> dense group i.
        if self._ep_adamw is not None:
            self._ep_adamw.param_groups[0]["lr"] = self.param_groups[0]["lr"]
        for i, g in enumerate(self._dense_adamw.param_groups):
            g["lr"] = self.param_groups[i + 1]["lr"]

        # Update the EP params owned by this rank
        if self._ep_adamw is not None:
            self._ep_adamw.step()

        # Update dense params
        self._dense_adamw.step()

        # Broadcast each EP param's local tensor from its owner; `i % dp_size` is the
        # same assignment as the `[rank::dp_size]` slicing used in __init__.
        for i, p in enumerate(self._all_ep_params):
            owner_local_rank = i % self._dp_size
            owner_global_rank = dist.get_global_rank(self._dp_group, owner_local_rank)
            if isinstance(p, DTensor):
                local_tensor = p.to_local()
            elif isinstance(p.data, DTensor):
                local_tensor = p.data.to_local()
            else:
                local_tensor = p.data
            dist.broadcast(local_tensor, src=owner_global_rank, group=self._dp_group)

        return loss

    def zero_grad(self, set_to_none: bool = True):
        """Clear gradients of all EP params and of the dense params.

        EP params are handled here directly because the inner EP optimizer only
        holds the params owned by this rank.
        """
        for p in self._all_ep_params:
            if set_to_none:
                p.grad = None
            elif p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
        self._dense_adamw.zero_grad(set_to_none=set_to_none)

    def state_dict(self) -> dict:
        """Return the inner optimizers' states under ``"ep_adamw"`` and ``"dense_adamw"``.

        ``"ep_adamw"`` is an empty dict on ranks that own no EP params.
        """
        return {
            "ep_adamw": self._ep_adamw.state_dict() if self._ep_adamw is not None else {},
            "dense_adamw": self._dense_adamw.state_dict(),
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the inner optimizers from a dict produced by :meth:`state_dict`.

        The EP part is skipped when this rank has no EP optimizer or the stored
        entry is empty.
        """
        if self._ep_adamw is not None and state_dict["ep_adamw"]:
            self._ep_adamw.load_state_dict(state_dict["ep_adamw"])
        self._dense_adamw.load_state_dict(state_dict["dense_adamw"])
b/moe/modalities_moe/training/gradient_clipping/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py b/moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py new file mode 100644 index 000000000..0581e3634 --- /dev/null +++ b/moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py @@ -0,0 +1,90 @@ +import math +from typing import Optional + +import torch +from torch import distributed as dist +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import FSDPModule as FSDP2 +from torch.distributed.tensor import DTensor + +from modalities.running_env.fsdp.device_mesh import ( + ParallelismDegrees, + get_mesh_for_parallelism_method, + has_parallelism_method, +) +from modalities.training.gradient_clipping.fsdp_gradient_clipper import FSDP2GradientClipper, GradientClippingMode + + +class EPGradientClipper(FSDP2GradientClipper): + """FSDP2 clipper wrapper for EP adaptation""" + + def __init__( + self, + model_parts: FSDP2 | list[FSDP2], + max_norm: float, + norm_type: GradientClippingMode, + device_mesh: Optional[DeviceMesh] = None, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, + ) -> None: + super().__init__( + model_parts=model_parts, + max_norm=max_norm, + norm_type=norm_type, + device_mesh=device_mesh, + error_if_nonfinite=error_if_nonfinite, + foreach=foreach, + ) + + @torch.no_grad() + def clip_gradients(self) -> torch.Tensor: + grads = [p.grad for model in self.models for p in model.parameters() if p.grad is not None] + + if len(grads) == 0: + device = ( + torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu") + ) + total_norm = torch.tensor(0.0, device=device) + else: + norm_type_val = self.norm_type.value + first_grad = grads[0] + first_device = first_grad.to_local().device if isinstance(first_grad, DTensor) else first_grad.device + norm_scalars: 
list[torch.Tensor] = [] + + for grad in grads: + grad_norm = torch.linalg.vector_norm(grad, ord=norm_type_val) + if isinstance(grad_norm, DTensor): + # Reduce each partial norm inside its own mesh before aggregation. + grad_norm = grad_norm.full_tensor() + norm_scalars.append(grad_norm.to(first_device)) + + if math.isinf(norm_type_val): + total_norm = torch.max(torch.stack(norm_scalars)) + else: + total_norm = torch.linalg.vector_norm(torch.stack(norm_scalars), ord=norm_type_val) + + if self.error_if_nonfinite and (torch.isnan(total_norm) or torch.isinf(total_norm)): + raise RuntimeError( + f"The total norm of order {norm_type_val} for gradients is non-finite: {total_norm.item()}" + ) + + if has_parallelism_method(self.device_mesh, ParallelismDegrees.PP): + pp_mesh = get_mesh_for_parallelism_method( + device_mesh=self.device_mesh, parallelism_method=ParallelismDegrees.PP + ) + if math.isinf(self.norm_type.value): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= self.norm_type.value + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / self.norm_type.value + + # do not use torch.nn.utils.clip_grads_with_norm_ here: it batches grads with + # torch._foreach_mul_, which fails when the list mixes DTensors from different meshes. 
+ clip_coef = self.max_norm / (total_norm + 1e-6) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + + for grad in grads: + grad_device = grad.to_local().device if isinstance(grad, DTensor) else grad.device + grad.mul_(clip_coef_clamped.to(grad_device)) + return total_norm diff --git a/moe/scripts/monitor_gpus.sh b/moe/scripts/monitor_gpus.sh new file mode 100755 index 000000000..7c221ffb1 --- /dev/null +++ b/moe/scripts/monitor_gpus.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +# GPU Monitoring Script - saves metrics to CSV with timestamps +# Usage: ./monitor_gpus.sh [interval_seconds] [output_file] + +INTERVAL=${1:-5} # Default: sample every 5 seconds +OUTPUT=${2:-logs/gpu_metrics_$(date +%Y%m%d_%H%M%S).csv} +PIDFILE=/tmp/gpu_monitor_$$.pid + +echo "Starting GPU monitoring..." +echo "Interval: ${INTERVAL}s" +echo "Output: ${OUTPUT}" +echo "PID file: ${PIDFILE}" + +# Create output directory if needed +mkdir -p "$(dirname "$OUTPUT")" + +# Save PID for cleanup +echo $$ > "$PIDFILE" + +# Write CSV header +echo "timestamp,gpu_id,memory_used_mb,memory_total_mb,memory_util_pct,gpu_util_pct,temperature_c,power_draw_w,power_limit_w" > "$OUTPUT" + +# Cleanup function +cleanup() { + echo "" + echo "Stopping GPU monitoring..." 
+ rm -f "$PIDFILE" + echo "Metrics saved to: $OUTPUT" + exit 0 +} + +trap cleanup SIGINT SIGTERM EXIT + +# Monitoring loop +while true; do + TIMESTAMP=$(date +%Y-%m-%d\ %H:%M:%S) + + # Query nvidia-smi for all metrics at once + nvidia-smi --query-gpu=index,memory.used,memory.total,utilization.memory,utilization.gpu,temperature.gpu,power.draw,power.limit \ + --format=csv,noheader,nounits 2>/dev/null | while IFS=',' read -r gpu_id mem_used mem_total mem_util gpu_util temp power power_limit; do + # Trim whitespace + gpu_id=$(echo "$gpu_id" | xargs) + mem_used=$(echo "$mem_used" | xargs) + mem_total=$(echo "$mem_total" | xargs) + mem_util=$(echo "$mem_util" | xargs) + gpu_util=$(echo "$gpu_util" | xargs) + temp=$(echo "$temp" | xargs) + power=$(echo "$power" | xargs) + power_limit=$(echo "$power_limit" | xargs) + + # Write to CSV + echo "$TIMESTAMP,$gpu_id,$mem_used,$mem_total,$mem_util,$gpu_util,$temp,$power,$power_limit" >> "$OUTPUT" + done + + # Live display (optional, comment out if too verbose) + # echo "[$(date +%H:%M:%S)] Logged GPU metrics ($(wc -l < "$OUTPUT") samples)" + + sleep "$INTERVAL" +done diff --git a/moe/scripts/train_ep.py b/moe/scripts/train_ep.py new file mode 100644 index 000000000..232ef287f --- /dev/null +++ b/moe/scripts/train_ep.py @@ -0,0 +1,195 @@ +# ruff: noqa: E402 + +import os +import sys +from pathlib import Path + +MOE_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(MOE_ROOT)) + +import torch +import torch.distributed as dist +from modalities_moe.config.config import EPWrappedModelConfig, MoECrossEntropyLossConfig +from modalities_moe.loss_functions import MoECrossEntropyLoss +from modalities_moe.models.model_factory import get_ep_wrapped_model +from modalities_moe.models.moe.qwen_model import QwenModel, QwenModelConfig +from modalities_moe.optimizers.ep_adamw import EPAdamWConfig, get_ep_adam_w +from modalities_moe.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper +from 
torch.distributed.tensor import DTensor + +from modalities.__main__ import Main +from modalities.config.config import ProcessGroupBackendType +from modalities.config.instantiation_models import TrainingComponentsInstantiationModel +from modalities.running_env.cuda_env import CudaEnv +from modalities.training.gradient_clipping.fsdp_gradient_clipper_config import FSDP2GradientClipperConfig + +cwd = Path(__file__).resolve().parent.parent +os.chdir(cwd) +CONFIG_FILE_PATH = cwd / "config" / "qwen_config.yaml" +EXPERIMENTS_ROOT_PATH = cwd / "results" / "debug" + + +# TODO solve this +def _enable_torchtitan_moe_permute_fallback() -> ( + None +): # VIBECODATA because of Triton C error with Python headers don't know what that is + """Avoid Triton JIT build for MoE permute indices on systems without Python dev headers.""" + try: + import torchtitan.models.moe.kernels as kernels + import torchtitan.models.moe.utils as moe_utils + except Exception: + return + + if getattr(kernels, "_modalities_fallback_enabled", False): + return + + def _fill_indices_torch( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + ) -> torch.Tensor: + device = tokens_per_expert_group.device + permuted_indices = torch.full((max_len,), -1, dtype=torch.int32, device=device) + + for e in range(experts_per_rank): + write_start = int(write_offsets[e].item()) + for r in range(num_ranks): + i = r * experts_per_rank + e + start_index = int(start_index_values[i].item()) + length = int(tokens_per_expert_group[i].item()) + if length > 0: + end_idx = min(write_start + length, max_len) + permuted_indices[write_start:end_idx] = torch.arange( + start_index, + start_index + (end_idx - write_start), + dtype=torch.int32, + device=device, + ) + write_start += length + + return permuted_indices + + _orig_generate_permute_indices = kernels.generate_permute_indices + + def 
_generate_permute_indices_no_triton( + tokens_per_expert_group: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + alignment: int, + use_cpu: bool = False, + ): + del use_cpu + start_index_values = torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group + total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) + total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment) + m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(torch.int32) + m_offsets = torch.cumsum(m_sizes, 0) + write_offsets = m_offsets - m_sizes + + permuted_indices = _fill_indices_torch( + tokens_per_expert_group=tokens_per_expert_group, + start_index_values=start_index_values, + write_offsets=write_offsets, + experts_per_rank=experts_per_rank, + num_ranks=num_ranks, + max_len=max_len, + ) + return permuted_indices, m_sizes, m_offsets.to(torch.int32) + + kernels.generate_permute_indices = _generate_permute_indices_no_triton + moe_utils.generate_permute_indices = _generate_permute_indices_no_triton + kernels._modalities_fallback_enabled = True + kernels._modalities_generate_permute_indices_original = _orig_generate_permute_indices + + +def debug_ep(model): + # Stima memoria teorica + total_params = sum(p.numel() for p in model.parameters()) + ep_params = sum( + p.numel() for m in model.modules() if getattr(m, "_ep_enabled", False) for p in m.parameters(recurse=False) + ) + dense_params = total_params - ep_params + + print(f"Params totali: {total_params/1e6:.0f}M") + print(f"Params EP (non shardati): {ep_params/1e6:.0f}M") + print(f"Params densi (shardati su dp_shard): {dense_params/1e6:.0f}M") + + rank = dist.get_rank() + free, total = torch.cuda.mem_get_info() + print(f"[rank{rank}] Memoria dopo init: {(total-free)/1e9:.1f} GB usati") + + +def main(): + _enable_torchtitan_moe_permute_fallback() + EXPERIMENTS_ROOT_PATH.mkdir(parents=True, exist_ok=True) + + with 
CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + modalities_main = Main( + config_path=CONFIG_FILE_PATH, + experiments_root_path=EXPERIMENTS_ROOT_PATH, + ) + modalities_main.add_custom_component( + component_key="model", + variant_key="ep_wrapped", + custom_component=get_ep_wrapped_model, + custom_config=EPWrappedModelConfig, + ) + + modalities_main.add_custom_component( + component_key="model", variant_key="moe", custom_component=QwenModel, custom_config=QwenModelConfig + ) + + modalities_main.add_custom_component( + component_key="gradient_clipper", + variant_key="ep", + custom_component=EPGradientClipper, + custom_config=FSDP2GradientClipperConfig, + ) + + modalities_main.add_custom_component( + component_key="loss", + variant_key="moe_cross_entropy", + custom_component=MoECrossEntropyLoss, + custom_config=MoECrossEntropyLossConfig, + ) + + modalities_main.add_custom_component( + component_key="optimizer", + variant_key="ep_adam_w", + custom_component=get_ep_adam_w, + custom_config=EPAdamWConfig, + ) + + components: TrainingComponentsInstantiationModel = modalities_main.build_components( + components_model_type=TrainingComponentsInstantiationModel + ) + + # WORKAROUNDS (wip) + # TODO implement those into moe code + # 1. some parameters remain on cpu + device = torch.device(f"cuda:{torch.cuda.current_device()}") + for name, param in components.model_raw.named_parameters(): + if param.device.type == "cpu": + param.data = param.data.to(device) + + # 2. cast EP params to bf16 — FSDP2 skips them via ignored_params, so they stay + # fp32 from model init. Cast here to match the MixedPrecisionPolicy applied to + # dense params (param_dtype=BF_16). Halves EP memory: 29 GB → 14.5 GB at tp=4. 
+ for mod in components.model_raw.modules(): + if getattr(mod, "_ep_enabled", False): + for pname, p in list(mod._parameters.items()): + if isinstance(p, DTensor) and p.dtype != torch.bfloat16: + bf16_local = p.to_local().to(torch.bfloat16) + bf16_p = DTensor.from_local(bf16_local, p.device_mesh, p.placements, run_check=False) + mod._parameters[pname] = torch.nn.Parameter(bf16_p, requires_grad=p.requires_grad) + + debug_ep(components.model_raw) + modalities_main.run(components) + + +if __name__ == "__main__": + main() diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 62933794d..acef23f71 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -212,6 +212,12 @@ def get_fsdp2_wrapped_model( modules = list(model.modules()) + # Collect EP parameters to exclude from FSDP2 sharding + ep_params = { + p for m in model.modules() if getattr(m, "_ep_enabled", False) for p in m.parameters(recurse=False) + } + ignored_params = ep_params if ep_params else None + # we first shard all the blocks grouped_modules: list[nn.Module] = [] module_id = 0 @@ -226,6 +232,7 @@ def get_fsdp2_wrapped_model( grouped_modules, **fsdp_config, reshard_after_forward=reshard_block_after_forward, + ignored_params=ignored_params, ) grouped_modules = list() @@ -235,10 +242,11 @@ def get_fsdp2_wrapped_model( grouped_modules, **fsdp_config, reshard_after_forward=reshard_block_after_forward, + ignored_params=ignored_params, ) # finally, we shard the entire model - fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward) + fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward, ignored_params=ignored_params) logger.info( f"Rank {dist.get_rank()} sharded number of parameters: " f"{get_local_number_of_trainable_parameters(model)}" From 7dbd86d0fe0722e357c4e10c42fb6dbed043a000 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Jun 2026 11:08:03 +0000 Subject: [PATCH 
02/12] refactor(moe): Merge MoE implementation into modalities core --- .../config_lorem_ipsum_long_moe_ep_fsdp2.yaml | 382 +++++++++++++ moe/modalities_moe/__init__.py | 0 moe/modalities_moe/config/__init__.py | 0 moe/modalities_moe/config/config.py | 22 - moe/modalities_moe/models/__init__.py | 0 moe/modalities_moe/models/moe/__init__.py | 0 moe/modalities_moe/models/moe/moe_model.py | 537 ------------------ moe/modalities_moe/optimizers/__init__.py | 0 moe/modalities_moe/training/__init__.py | 0 .../training/gradient_clipping/__init__.py | 0 moe/scripts/train_ep.py | 52 +- src/modalities/config/config.py | 30 + src/modalities/models/moe/__init__.py | 10 + .../modalities/models/moe}/loss_functions.py | 3 +- .../modalities/models/moe}/model_factory.py | 16 +- .../modalities}/models/moe/qwen_model.py | 69 +-- .../modalities}/optimizers/ep_adamw.py | 35 +- src/modalities/registry/components.py | 13 + .../gradient_clipping/ep_gradient_clipper.py | 5 +- 19 files changed, 464 insertions(+), 710 deletions(-) create mode 100644 config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml delete mode 100644 moe/modalities_moe/__init__.py delete mode 100644 moe/modalities_moe/config/__init__.py delete mode 100644 moe/modalities_moe/config/config.py delete mode 100644 moe/modalities_moe/models/__init__.py delete mode 100644 moe/modalities_moe/models/moe/__init__.py delete mode 100644 moe/modalities_moe/models/moe/moe_model.py delete mode 100644 moe/modalities_moe/optimizers/__init__.py delete mode 100644 moe/modalities_moe/training/__init__.py delete mode 100644 moe/modalities_moe/training/gradient_clipping/__init__.py create mode 100644 src/modalities/models/moe/__init__.py rename {moe/modalities_moe => src/modalities/models/moe}/loss_functions.py (92%) rename {moe/modalities_moe/models => src/modalities/models/moe}/model_factory.py (86%) rename {moe/modalities_moe => src/modalities}/models/moe/qwen_model.py (92%) rename {moe/modalities_moe => 
src/modalities}/optimizers/ep_adamw.py (80%) rename {moe/modalities_moe => src/modalities}/training/gradient_clipping/ep_gradient_clipper.py (91%) diff --git a/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml new file mode 100644 index 000000000..577656b3b --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml @@ -0,0 +1,382 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + experiments_root_path: ${modalities_env:experiments_root_path} + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 1 + sequence_length: 256 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: + component_key: 
number_conversion + variant_key: num_steps_from_num_tokens + config: + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: 
${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + drop_last: true + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: moe_cross_entropy + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + model: + instance_key: model_raw + pass_type: BY_REFERENCE + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: -1 + tensor_parallel_degree: 4 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: + device_mesh: + instance_key: 
device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.num_layers} + multi_device_generator_policy: error + +ep_model: + component_key: model + variant_key: ep_wrapped + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + ep_mesh_dim_name: tp + block_names: [TransformerBlock] + +ac_model: + component_key: model + variant_key: activation_checkpointed + config: + model: + instance_key: ep_model + pass_type: BY_REFERENCE + ac_variant: full_activation_checkpointing + layers_fqn: layers + ac_fun_params: + ac_freq: 1 + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: ac_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + reshard_after_forward: true + block_names: [TransformerBlock] + +model_raw: + component_key: model + variant_key: moe + config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 + max_seq_len: ${settings.step_profile.sequence_length} + d_model: 128 + n_heads: 8 + n_kv_heads: 4 + num_layers: 2 + d_ff: 128 + attn_dropout: 0.0 + ffn_dropout: 0.0 + tie_embeddings: false + norm_eps: 1e-6 + 
rope_base: 1000000.0 + moe_num_experts: 8 + moe_top_k: 2 + moe_d_ff: 128 + moe_capacity_factor: 1.25 + moe_min_capacity: 4 + moe_overflow_policy: residual + moe_router_noise_std: 0.0 + moe_router_temperature: 1.0 + moe_router_dropout: 0.0 + moe_aux_loss_coef: 0.001 + moe_z_loss_coef: 0.0 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: ep_adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: ep + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +mfu_calculator: + 
component_key: mfu_calculator + variant_key: gpt2 + config: + n_layer: ${model_raw.config.num_layers} + sequence_length: ${settings.step_profile.sequence_length} + n_embd: ${model_raw.config.d_model} + world_size: ${settings.cuda_env.world_size} + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE diff --git a/moe/modalities_moe/__init__.py b/moe/modalities_moe/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/config/__init__.py b/moe/modalities_moe/config/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/config/config.py b/moe/modalities_moe/config/config.py deleted file mode 100644 index 0ab14372a..000000000 --- a/moe/modalities_moe/config/config.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any - -from pydantic import BaseModel - -from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticPytorchModuleOrListType - - -class MoECrossEntropyLossConfig(BaseModel): - target_key: str - prediction_key: str - model: Any - tag: str = "MoECrossEntropyLoss" - - class Config: - arbitrary_types_allowed = True - - -class EPWrappedModelConfig(BaseModel): - model: PydanticPytorchModuleOrListType - block_names: list[str] - device_mesh: PydanticDeviceMeshIFType - ep_mesh_dim_name: str | None = None diff --git a/moe/modalities_moe/models/__init__.py b/moe/modalities_moe/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/models/moe/__init__.py b/moe/modalities_moe/models/moe/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/models/moe/moe_model.py b/moe/modalities_moe/models/moe/moe_model.py deleted file mode 100644 index a1412d410..000000000 --- a/moe/modalities_moe/models/moe/moe_model.py +++ /dev/null @@ -1,537 +0,0 @@ -import math -from dataclasses import dataclass 
-from typing import Literal, Optional, overload - -import torch -import torch.nn as nn -import torch.nn.functional as F -from pydantic import BaseModel - -# TODO reolve this import -try: - from torch.distributed.tensor import DTensor -except Exception: - DTensor = None - - -class MoEModelConfig(BaseModel): - # model config - vocab_size: int - max_seq_len: int - d_model: int - n_heads: int - n_kv_heads: int - num_layers: int - d_ff: int - sample_key: str = "input_ids" - prediction_key: str = "logits" - attn_dropout: float = 0.0 - ffn_dropout: float = 0.0 - tie_embeddings: bool = False - moe_every_n_layers: int = 1 - moe_num_experts: int = 8 - moe_top_k: int = 2 - moe_capacity_factor: float = 1.25 - moe_aux_loss_coef: float = 0.01 - moe_z_loss_coef: float = 0.0 - moe_router_noise_std: float = 0.0 - - -@dataclass -class MoEArguments: - # Model hyperparameters - d_model: int - d_ff: int - - # MoE hyperparameters - num_experts: int - top_k: int - capacity_factor: float = 1.25 - min_capacity: int = 4 - overflow_policy: Literal["drop", "residual"] = "residual" - - # Router configuration - router_noise_std: float = 0.0 - router_temperature: float = 1.0 - router_dropout: float = 0.0 - - # Auxiliary loss coefficients - aux_loss_coef: float = 0.01 - z_loss_coef: float = 0.0 - - # Training configuration - dropout: float = 0.0 - - -class RMSNorm(nn.Module): - def __init__(self, d_model, eps=1e-8): - super(RMSNorm, self).__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(d_model)) - - def forward(self, x): - norm_x = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - return norm_x * self.weight - - -class Expert(nn.Module): - def __init__(self, d_model, d_ff, dropout=0.0): - super(Expert, self).__init__() - self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() - self.w1 = nn.Linear(d_model, d_ff) - self.w2 = nn.Linear(d_model, d_ff) - self.w3 = nn.Linear(d_ff, d_model) - - def forward(self, x): - x1 = self.w1(x) - x2 = self.w2(x) - x = 
torch.nn.functional.silu(x1) * x2 - x = self.w3(x) - return self.dropout(x) - - -class GroupedExperts(nn.Module): - """Grouped experts for torchtitan compatibility.""" - - def __init__(self, config: MoEArguments): - super().__init__() - self.num_experts = config.num_experts - self.d_model = config.d_model - self.d_ff = config.d_ff - self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else nn.Identity() - - self.w1 = nn.Parameter(torch.empty(self.num_experts, self.d_ff, self.d_model)) - self.b1 = nn.Parameter(torch.empty(self.num_experts, self.d_ff)) - self.w2 = nn.Parameter(torch.empty(self.num_experts, self.d_ff, self.d_model)) - self.b2 = nn.Parameter(torch.empty(self.num_experts, self.d_ff)) - self.w3 = nn.Parameter(torch.empty(self.num_experts, self.d_model, self.d_ff)) - self.b3 = nn.Parameter(torch.empty(self.num_experts, self.d_model)) - - self.initialize() - - def initialize(self): - nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) - bound_w1 = 1 / math.sqrt(self.d_model) - nn.init.uniform_(self.b1, -bound_w1, bound_w1) - - nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5)) - bound_w2 = 1 / math.sqrt(self.d_model) - nn.init.uniform_(self.b2, -bound_w2, bound_w2) - - nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5)) - bound_w3 = 1 / math.sqrt(self.d_ff) - nn.init.uniform_(self.b3, -bound_w3, bound_w3) - - def _forward_local(self, routed_input, num_tokens_per_expert) -> torch.Tensor: - outputs: list[torch.Tensor] = [] - start = 0 - - # ExpertParallel may convert parameters to DTensor. Local expert compute - # expects plain tensors, so we materialize local shards when needed. 
- w1 = self.w1.to_local() if DTensor is not None and isinstance(self.w1, DTensor) else self.w1 - b1 = self.b1.to_local() if DTensor is not None and isinstance(self.b1, DTensor) else self.b1 - w2 = self.w2.to_local() if DTensor is not None and isinstance(self.w2, DTensor) else self.w2 - b2 = self.b2.to_local() if DTensor is not None and isinstance(self.b2, DTensor) else self.b2 - w3 = self.w3.to_local() if DTensor is not None and isinstance(self.w3, DTensor) else self.w3 - b3 = self.b3.to_local() if DTensor is not None and isinstance(self.b3, DTensor) else self.b3 - - local_num_tokens = ( - num_tokens_per_expert.to_local() - if DTensor is not None and isinstance(num_tokens_per_expert, DTensor) - else num_tokens_per_expert - ) - - total_rows = routed_input.shape[0] - for expert_idx, num_tokens in enumerate(local_num_tokens.tolist()): - requested_tokens = int(num_tokens) - end = start + requested_tokens - - # EP alignment can request padded tokens; only a subset may exist in routed_input. - local_end = min(end, total_rows) - expert_input = routed_input[start:local_end] - real_tokens = int(expert_input.shape[0]) - - out_parts: list[torch.Tensor] = [] - if real_tokens > 0: - x1 = torch.nn.functional.linear(expert_input, w1[expert_idx], b1[expert_idx]) - x2 = torch.nn.functional.linear(expert_input, w2[expert_idx], b2[expert_idx]) - hidden = torch.nn.functional.silu(x1) * x2 - out_real = torch.nn.functional.linear(hidden, w3[expert_idx], b3[expert_idx]) - out_parts.append(self.dropout(out_real)) - - pad_tokens = requested_tokens - real_tokens - if pad_tokens > 0: - out_parts.append(routed_input.new_zeros((pad_tokens, self.d_model))) - - if len(out_parts) > 0: - outputs.append(torch.cat(out_parts, dim=0) if len(out_parts) > 1 else out_parts[0]) - - start = end - - if len(outputs) == 0: - return routed_input.new_zeros((0, self.d_model)) - - out = torch.cat(outputs, dim=0) - - # EP permute may append extra global padding slots beyond per-expert aligned sizes. 
- # output_fn(_unpermute) expects the same row count as routed_input. - if out.shape[0] < total_rows: - out = torch.cat( - [out, routed_input.new_zeros((total_rows - out.shape[0], self.d_model))], - dim=0, - ) - elif out.shape[0] > total_rows: - out = out[:total_rows] - - return out - - def forward(self, routed_input, num_tokens_per_expert) -> torch.Tensor: - # routed_input: (M, D), sorted/grouped by expert id - # num_tokens_per_expert: (E_local,) for local compute, or global counts before EP input_fn - return self._forward_local(routed_input, num_tokens_per_expert) - - -class MoEBlock(nn.Module): - def __init__(self, config: MoEArguments): - super(MoEBlock, self).__init__() - self.config = config - self.num_experts = config.num_experts - self.router = nn.Linear(config.d_model, self.num_experts) - self.router_dropout = nn.Dropout(config.router_dropout) if config.router_dropout > 0 else nn.Identity() - self.experts = GroupedExperts(config) - - self.last_aux_loss: Optional[torch.Tensor] = None - - def forward(self, x): - B, T, D = x.size() - E = self.config.num_experts - K = self.config.top_k - N = B * T - - x_flat = x.view(N, D) - - # Router logits - logits = self.router(self.router_dropout(x_flat)) # (N, E) - if self.config.router_noise_std > 0 and self.training: - noise = torch.randn_like(logits) * self.config.router_noise_std - logits = logits + noise - logits = logits / self.config.router_temperature - probs = torch.softmax(logits, dim=-1) # (N, E) - - # top-k - topk_val, topk_idx = torch.topk(probs, k=K, dim=-1) # (N, K) - topk_w = topk_val / (topk_val.sum(dim=-1, keepdim=True) + 1e-9) # (N, K) - - # capacity per expert - capacity = math.ceil(self.config.capacity_factor * N / E) - capacity = max(capacity, self.config.min_capacity) - - # dispatch mask - preserve dtype of input - dispatch_mask = torch.nn.functional.one_hot(topk_idx, num_classes=E).to(x_flat.dtype) # (N, K, E) - - # token assignment - expert_mask = dispatch_mask.sum(dim=1) # (N, E) - positions = 
torch.cumsum(expert_mask, dim=0) # (N, E) - capacity_mask = (positions <= capacity).to(x_flat.dtype) # (N, E) - final_mask = dispatch_mask * capacity_mask.unsqueeze(1) # (N, K, E) - combine_weights = final_mask * topk_w.unsqueeze(-1) # (N, K, E) - - combine_weights.sum(dim=1) # (N, E) - - # count actual assignments per expert - load = final_mask.sum(dim=[0, 1]) # (E,) - importance = probs.sum(dim=0) # (E,) - - # Build routed token stream. - valid_mask = capacity_mask.gather(1, topk_idx).bool() # (N, K) - token_ids = torch.arange(N, device=x.device).unsqueeze(1).expand(N, K) - - flat_valid = valid_mask.reshape(-1) - flat_token_ids = token_ids.reshape(-1)[flat_valid] - flat_expert_ids = topk_idx.reshape(-1)[flat_valid] - flat_weights = topk_w.reshape(-1)[flat_valid] - - if flat_expert_ids.numel() > 0: - sort_idx = torch.argsort(flat_expert_ids) - token_ids_sorted = flat_token_ids[sort_idx] - expert_ids_sorted = flat_expert_ids[sort_idx] - weights_sorted = flat_weights[sort_idx] - - routed_input = x_flat[token_ids_sorted] - num_tokens_per_expert = torch.bincount(expert_ids_sorted, minlength=E) - - routed_output = self.experts(routed_input, num_tokens_per_expert) - weighted_output = routed_output * weights_sorted.unsqueeze(-1) - - out = x_flat.new_zeros((N, D)) - out.index_add_(0, token_ids_sorted, weighted_output) - - assigned = x_flat.new_zeros((N,)) - assigned.index_add_(0, token_ids_sorted, weights_sorted) - else: - out = x_flat.new_zeros((N, D)) - assigned = x_flat.new_zeros((N,)) - - # Overflow handling: tokens not assigned to any expert - not_assigned = assigned < 1e-6 - - if not_assigned.any(): - if self.config.overflow_policy == "residual": - out[not_assigned] = x_flat[not_assigned] - # if 'drop', out is already zero for those positions - - # auxiliary loss - aux = None - if self.config.aux_loss_coef > 0: - imp = importance / (importance.sum() + 1e-9) - ld = load / (load.sum() + 1e-9) - lb = E * torch.sum(imp * ld) - aux = self.config.aux_loss_coef * lb - - if 
self.config.z_loss_coef > 0: - z = torch.logsumexp(logits, dim=-1) - z_loss = torch.mean(z**2) - aux = (aux if aux is not None else 0.0) + self.config.z_loss_coef * z_loss - - self.last_aux_loss = aux - return out.view(B, T, D) - - -class GroupedQueryAttention(nn.Module): - def __init__(self, d_model, num_heads, num_kv_heads): - super(GroupedQueryAttention, self).__init__() - self.d_model = d_model - self.n_heads = num_heads - self.n_kv_heads = num_kv_heads - self.head_dim = d_model // num_heads - self.q_proj = nn.Linear(d_model, num_heads * self.head_dim) - self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim) - self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim) - self.out_proj = nn.Linear(num_heads * self.head_dim, d_model) - - def forward(self, query, key, value, mask=None): - Q = self.q_proj(query).view(query.size(0), -1, self.n_heads, self.head_dim) - K = self.k_proj(key).view(key.size(0), -1, self.n_kv_heads, self.head_dim) - V = self.v_proj(value).view(value.size(0), -1, self.n_kv_heads, self.head_dim) - Q = Q.permute(0, 2, 1, 3) - K = K.permute(0, 2, 1, 3) - V = V.permute(0, 2, 1, 3) - # Compute attention scores - attn_scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) / (self.head_dim**0.5) - if mask is not None: - attn_scores += mask - attn_weights = F.softmax(attn_scores, dim=-1) - attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights, V) - attn_output = attn_output.contiguous().view(query.size(0), -1, self.n_heads * self.head_dim) - return self.out_proj(attn_output), None - - -class TransformerBlock(nn.Module): - """Transformer block with MoE""" - - def __init__(self, d_model, d_ff, num_heads, num_kv_heads, moe_config: MoEArguments): - super(TransformerBlock, self).__init__() - self.d_model = d_model - self.d_ff = d_ff - self.n_heads = num_heads - self.n_kv_heads = num_kv_heads - self.attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True) - self.pre_attn_norm = RMSNorm(d_model) - self.pre_ffn_norm = 
RMSNorm(d_model) - - if moe_config is not None: - self.ffn = MoEBlock(moe_config) - self.is_moe = True - else: - self.ffn = Expert(d_model, d_ff) - self.is_moe = False - - def forward(self, x): - x_norm = self.pre_attn_norm(x) - attn_output, _ = self.attention(x_norm, x_norm, x_norm) - x = x + attn_output - - # Pre-MoE norm - x_norm = self.pre_ffn_norm(x) - moe_output = self.ffn(x_norm) - x = x + moe_output - - return x - - @property # TODO: AUX LOSS IN FORWARD - def aux_loss(self): - if self.is_moe and hasattr(self.ffn, "last_aux_loss"): - return self.ffn.last_aux_loss - return None - - -class MoEModel(nn.Module): - def __init__( - self, - vocab_size: int, - max_seq_len: int, - d_model: int, - n_heads: int, - n_kv_heads: int, - d_ff: int, - num_layers: int, - sample_key: str = "input_ids", - prediction_key: str = "logits", - attn_dropout: float = 0.0, - ffn_dropout: float = 0.0, - tie_embeddings: bool = True, - moe_every_n_layers: int = 1, - moe_num_experts: int = 8, - moe_top_k: int = 2, - moe_capacity_factor: float = 1.25, - moe_aux_loss_coef: float = 0.01, - moe_z_loss_coef: float = 0.0, - moe_router_noise_std: float = 0.0, - ): - super(MoEModel, self).__init__() - self.sample_key = sample_key - self.prediction_key = prediction_key - self.vocab_size = vocab_size - self.max_seq_len = max_seq_len - self.d_model = d_model - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads - self.num_layers = num_layers - self.d_ff = d_ff - self.attn_dropout = attn_dropout - self.ffn_dropout = ffn_dropout - self.tie_embeddings = tie_embeddings - self.moe_every_n_layers = moe_every_n_layers - self.moe_num_experts = moe_num_experts - self.moe_top_k = moe_top_k - self.moe_capacity_factor = moe_capacity_factor - self.moe_aux_loss_coef = moe_aux_loss_coef - self.moe_z_loss_coef = moe_z_loss_coef - self.moe_router_noise_std = moe_router_noise_std - - self.token_emb = nn.Embedding(self.vocab_size, self.d_model) - self.pos_emb = nn.Embedding(self.max_seq_len, self.d_model) - - 
moe_config = MoEArguments( - d_model=self.d_model, - d_ff=self.d_ff, - num_experts=self.moe_num_experts, - top_k=self.moe_top_k, - capacity_factor=self.moe_capacity_factor, - aux_loss_coef=self.moe_aux_loss_coef, - z_loss_coef=self.moe_z_loss_coef, - router_noise_std=self.moe_router_noise_std, - dropout=self.ffn_dropout, - ) - - self.layers = nn.ModuleDict() - for i in range(self.num_layers): - if i % self.moe_every_n_layers == 0: - block = TransformerBlock(self.d_model, self.d_ff, self.n_heads, self.n_kv_heads, moe_config) - else: - block = TransformerBlock(self.d_model, self.d_ff, self.n_heads, self.n_kv_heads, None) # No MoE - self.layers[str(i)] = block - self.final_norm = RMSNorm(self.d_model) - self.lm_head = nn.Linear(self.d_model, self.vocab_size, bias=False) - if self.tie_embeddings: - self.lm_head.weight = self.token_emb.weight - - @property - def weight_decay_groups(self): - return { - "linear": ["attention", "router", "w1", "w2", "w3", "b1", "b2", "b3", "lm_head"], - "embedding": ["token_emb", "pos_emb"], - "layernorm": ["pre_attn_norm", "pre_ffn_norm", "final_norm"], - } - - @overload - def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """ - Forward pass of the MoE module. - - Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - - sample_key (str): Key for the input tensor containing token ids. - - Returns: - dict[str, torch.Tensor]: A dictionary containing output tensors. - - prediction_key (str): Key for the output tensor containing logits. - """ - ... - - @overload - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the module. - - Args: - inputs (torch.Tensor): A tensor containing input token ids. - - Returns: - torch.Tensor: A tensor containing output logits. - """ - ... - - def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: - """ - Forward pass of the module. 
- - Args: - inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. - - Returns: - dict[str, torch.Tensor] | torch.Tensor: Model output. - """ - if isinstance(inputs, dict): - return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} - else: - return self.forward_impl(inputs) - - def forward_impl(self, input_ids: torch.Tensor) -> torch.Tensor: - B, T = input_ids.size() - assert T <= self.max_seq_len, f"Sequence length {T} exceeds model's max_seq_len {self.max_seq_len}" - device = input_ids.device - - # Token and position embeddings - token_embeddings = self.token_emb(input_ids) # (B, T, D) - positions = torch.arange(T, device=device).unsqueeze(0).expand(B, T) - pos_embeddings = self.pos_emb(positions) # (B, T, D) - x = token_embeddings + pos_embeddings # (B, T, D) - - # Transformer blocks - for i, layer in enumerate(self.layers.values()): - x = layer(x) - - x = self.final_norm(x) - logits = self.lm_head(x) # (B, T, vocab_size) - - return logits - - -if __name__ == "__main__": # sanity test - torch.manual_seed(0) - - model = MoEModel( - vocab_size=32064, - max_seq_len=32768, - d_model=4096, - n_heads=32, - n_kv_heads=8, - num_layers=32, - d_ff=14336, - moe_every_n_layers=1, - moe_num_experts=8, - moe_top_k=2, - ) - - # Print number of trainable parameters - num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"Number of trainable parameters: {num_params:,}") - - x = torch.randint(0, model.vocab_size, (2, 64)) - logits = model(x) - - print("logits:", logits.shape) - loss = logits.mean() - loss.backward() - print("backward OK") diff --git a/moe/modalities_moe/optimizers/__init__.py b/moe/modalities_moe/optimizers/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/training/__init__.py b/moe/modalities_moe/training/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/training/gradient_clipping/__init__.py 
b/moe/modalities_moe/training/gradient_clipping/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/scripts/train_ep.py b/moe/scripts/train_ep.py index 232ef287f..7c99eee03 100644 --- a/moe/scripts/train_ep.py +++ b/moe/scripts/train_ep.py @@ -1,27 +1,17 @@ # ruff: noqa: E402 import os -import sys from pathlib import Path - -MOE_ROOT = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(MOE_ROOT)) +from typing import cast import torch import torch.distributed as dist -from modalities_moe.config.config import EPWrappedModelConfig, MoECrossEntropyLossConfig -from modalities_moe.loss_functions import MoECrossEntropyLoss -from modalities_moe.models.model_factory import get_ep_wrapped_model -from modalities_moe.models.moe.qwen_model import QwenModel, QwenModelConfig -from modalities_moe.optimizers.ep_adamw import EPAdamWConfig, get_ep_adam_w -from modalities_moe.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper from torch.distributed.tensor import DTensor from modalities.__main__ import Main from modalities.config.config import ProcessGroupBackendType from modalities.config.instantiation_models import TrainingComponentsInstantiationModel from modalities.running_env.cuda_env import CudaEnv -from modalities.training.gradient_clipping.fsdp_gradient_clipper_config import FSDP2GradientClipperConfig cwd = Path(__file__).resolve().parent.parent os.chdir(cwd) @@ -102,8 +92,8 @@ def _generate_permute_indices_no_triton( kernels.generate_permute_indices = _generate_permute_indices_no_triton moe_utils.generate_permute_indices = _generate_permute_indices_no_triton - kernels._modalities_fallback_enabled = True - kernels._modalities_generate_permute_indices_original = _orig_generate_permute_indices + setattr(kernels, "_modalities_fallback_enabled", True) + setattr(kernels, "_modalities_generate_permute_indices_original", _orig_generate_permute_indices) def debug_ep(model): @@ -132,40 +122,10 @@ def main(): 
config_path=CONFIG_FILE_PATH, experiments_root_path=EXPERIMENTS_ROOT_PATH, ) - modalities_main.add_custom_component( - component_key="model", - variant_key="ep_wrapped", - custom_component=get_ep_wrapped_model, - custom_config=EPWrappedModelConfig, - ) - - modalities_main.add_custom_component( - component_key="model", variant_key="moe", custom_component=QwenModel, custom_config=QwenModelConfig - ) - - modalities_main.add_custom_component( - component_key="gradient_clipper", - variant_key="ep", - custom_component=EPGradientClipper, - custom_config=FSDP2GradientClipperConfig, - ) - - modalities_main.add_custom_component( - component_key="loss", - variant_key="moe_cross_entropy", - custom_component=MoECrossEntropyLoss, - custom_config=MoECrossEntropyLossConfig, - ) - - modalities_main.add_custom_component( - component_key="optimizer", - variant_key="ep_adam_w", - custom_component=get_ep_adam_w, - custom_config=EPAdamWConfig, - ) - components: TrainingComponentsInstantiationModel = modalities_main.build_components( - components_model_type=TrainingComponentsInstantiationModel + components = cast( + TrainingComponentsInstantiationModel, + modalities_main.build_components(components_model_type=TrainingComponentsInstantiationModel), ) # WORKAROUNDS (wip) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 42a19b99a..af280a0a4 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -83,6 +83,16 @@ class CLMCrossEntropyLossConfig(BaseModel): prediction_key: str +class MoECrossEntropyLossConfig(BaseModel): + target_key: str + prediction_key: str + model: Any + tag: str = "MoECrossEntropyLoss" + + class Config: + arbitrary_types_allowed = True + + # Checkpointing class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): k: PositiveInt @@ -167,6 +177,19 @@ class AdamWOptimizerConfig(BaseModel): fused: bool | None = None +class EPAdamWConfig(BaseModel): + wrapped_model: PydanticPytorchModuleOrListType + 
device_mesh: PydanticDeviceMeshIFType + lr: float + betas: tuple[float, float] + eps: float + weight_decay: float + weight_decay_groups_excluded: list[str] + + class Config: + arbitrary_types_allowed = True + + class DummyLRSchedulerConfig(BaseModel): optimizer: PydanticOptimizerIFType @@ -311,6 +334,13 @@ def validate_dp_mesh_existence(self): return self +class EPWrappedModelConfig(BaseModel): + model: PydanticPytorchModuleOrListType + block_names: list[str] + device_mesh: PydanticDeviceMeshIFType + ep_mesh_dim_name: str | None = None + + class DebuggingEnrichedModelConfig(BaseModel): model: PydanticPytorchModuleOrListType logging_dir_path: Path diff --git a/src/modalities/models/moe/__init__.py b/src/modalities/models/moe/__init__.py new file mode 100644 index 000000000..5e55327a1 --- /dev/null +++ b/src/modalities/models/moe/__init__.py @@ -0,0 +1,10 @@ +from modalities.models.moe.loss_functions import MoECrossEntropyLoss +from modalities.models.moe.model_factory import get_ep_wrapped_model +from modalities.models.moe.qwen_model import QwenModel, QwenModelConfig + +__all__ = [ + "MoECrossEntropyLoss", + "QwenModel", + "QwenModelConfig", + "get_ep_wrapped_model", +] diff --git a/moe/modalities_moe/loss_functions.py b/src/modalities/models/moe/loss_functions.py similarity index 92% rename from moe/modalities_moe/loss_functions.py rename to src/modalities/models/moe/loss_functions.py index 654677efb..642efb47a 100644 --- a/moe/modalities_moe/loss_functions.py +++ b/src/modalities/models/moe/loss_functions.py @@ -6,7 +6,7 @@ class MoECrossEntropyLoss(Loss): - """Cross Entropy Loss with auxiliary loss support for router balancing""" + """Cross entropy loss with optional MoE auxiliary losses from model layers.""" def __init__( self, @@ -31,7 +31,6 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: labels.contiguous().long().view(-1), ) - # Aux loss for layer in self.model.layers.values(): if hasattr(layer, "aux_loss") and layer.aux_loss is not 
None: loss = loss + layer.aux_loss.to(loss.dtype) diff --git a/moe/modalities_moe/models/model_factory.py b/src/modalities/models/moe/model_factory.py similarity index 86% rename from moe/modalities_moe/models/model_factory.py rename to src/modalities/models/moe/model_factory.py index 65dbe8e3f..406da1964 100644 --- a/moe/modalities_moe/models/model_factory.py +++ b/src/modalities/models/moe/model_factory.py @@ -10,10 +10,7 @@ from modalities.util import get_module_class_from_name -# TODO refactor these funtions into a utils -def _resolve_ep_mesh( - device_mesh: DeviceMesh, ep_mesh_dim_name: str | None -) -> DeviceMesh: # devicemesh not supporting EP +def _resolve_ep_mesh(device_mesh: DeviceMesh, ep_mesh_dim_name: str | None) -> DeviceMesh: mesh_dim_names = tuple(device_mesh.mesh_dim_names or ()) if ep_mesh_dim_name is not None: @@ -72,15 +69,6 @@ def _apply_torchtitan_ep(module, ep_mesh) -> None: setattr(module.experts, "_ep_enabled", True) -def debug_forward_hook(module, input): - for name, param in module.named_parameters(recurse=False): - if hasattr(param, "_local_tensor"): - # still dTensor - print(f"[EP forward] {name}: still DTensor, local={param._local_tensor.shape}") - else: - print(f"[EP forward] {name}: plain tensor shape={param.shape}") - - def get_ep_wrapped_model( model, block_names: list[str], @@ -89,7 +77,6 @@ def get_ep_wrapped_model( mp_param_dtype=torch.bfloat16, mp_reduce_dtype=torch.bfloat16, ) -> nn.Module: - # Warn for unresolved names, but still wrap any block types that can be resolved. 
block_types = [] missing_block_names = [] for name in block_names: @@ -111,7 +98,6 @@ def get_ep_wrapped_model( raise ValueError(f"None of the requested MoE block names were found: {block_names}") ep_mesh = _resolve_ep_mesh(device_mesh, ep_mesh_dim_name) - device_mesh["dp_shard"] MixedPrecisionPolicy(param_dtype=mp_param_dtype, reduce_dtype=mp_reduce_dtype) wrapped_blocks = 0 diff --git a/moe/modalities_moe/models/moe/qwen_model.py b/src/modalities/models/moe/qwen_model.py similarity index 92% rename from moe/modalities_moe/models/moe/qwen_model.py rename to src/modalities/models/moe/qwen_model.py index 3a5ec2d61..47a810722 100644 --- a/moe/modalities_moe/models/moe/qwen_model.py +++ b/src/modalities/models/moe/qwen_model.py @@ -1,11 +1,13 @@ import math -from typing import Literal, Optional +from typing import Literal, Optional, overload import torch import torch.nn as nn import torch.nn.functional as F from pydantic import BaseModel +from modalities.models.model import NNModel + try: from torch.distributed.tensor import DTensor except Exception: @@ -13,7 +15,6 @@ class QwenModelConfig(BaseModel): - # Model vocab_size: int max_seq_len: int d_model: int @@ -311,18 +312,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out.view(B, T, D) -class DenseMLP(nn.Module): - def __init__(self, d_model, d_ff, ffn_dropout): - super().__init__() - self.w1 = nn.Linear(d_model, d_ff, bias=False) - self.w2 = nn.Linear(d_model, d_ff, bias=False) - self.w3 = nn.Linear(d_ff, d_model, bias=False) - self.dropout = nn.Dropout(ffn_dropout) if ffn_dropout > 0 else nn.Identity() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x))) - - class TransformerBlock(nn.Module): def __init__( self, @@ -385,7 +374,7 @@ def aux_loss(self) -> Optional[torch.Tensor]: return getattr(self.ffn, "last_aux_loss", None) -class QwenModel(nn.Module): +class QwenModel(NNModel): def __init__( self, vocab_size: int, @@ -414,7 +403,12 
@@ def __init__( moe_aux_loss_coef: float = 0.001, moe_z_loss_coef: float = 0.0, ): - super().__init__() + weight_decay_groups = { + "linear": ["q_proj", "k_proj", "v_proj", "o_proj", "lm_head", "router", "w1", "w2", "w3"], + "embedding": ["token_emb"], + "layernorm": ["pre_attn_norm", "pre_ffn_norm", "final_norm", "q_norm", "k_norm"], + } + super().__init__(weight_decay_groups=weight_decay_groups) self.sample_key = sample_key self.prediction_key = prediction_key @@ -454,15 +448,15 @@ def __init__( if tie_embeddings: self.lm_head.weight = self.token_emb.weight - @property - def weight_decay_groups(self): - return { - "linear": ["q_proj", "k_proj", "v_proj", "o_proj", "lm_head", "router", "w1", "w2", "w3"], - "embedding": ["token_emb"], - "layernorm": ["pre_attn_norm", "pre_ffn_norm", "final_norm", "q_norm", "k_norm"], - } + @overload + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + ... + + @overload + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + ... 
- def forward(self, inputs): + def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: if isinstance(inputs, dict): return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} return self.forward_impl(inputs) @@ -472,30 +466,3 @@ def forward_impl(self, input_ids: torch.Tensor) -> torch.Tensor: for layer in self.layers.values(): x = layer(x) return self.lm_head(self.final_norm(x)) - - -if __name__ == "__main__": - torch.manual_seed(0) - - model = QwenModel( - vocab_size=151936, - max_seq_len=4096, - d_model=2048, - n_heads=32, - n_kv_heads=8, - d_ff=6144, - moe_d_ff=768, - num_layers=48, - moe_num_experts=128, - moe_top_k=8, - ) - num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"Parametri: {num_params/1e9:.2f}B") - - x = torch.randint(0, 151936, (2, 64)) - logits = model(x) - print(f"Output: {logits.shape}") - - loss = logits.mean() - loss.backward() - print("Backward OK") diff --git a/moe/modalities_moe/optimizers/ep_adamw.py b/src/modalities/optimizers/ep_adamw.py similarity index 80% rename from moe/modalities_moe/optimizers/ep_adamw.py rename to src/modalities/optimizers/ep_adamw.py index d7d19fe9c..2b5e72aae 100644 --- a/moe/modalities_moe/optimizers/ep_adamw.py +++ b/src/modalities/optimizers/ep_adamw.py @@ -1,27 +1,12 @@ import torch import torch.distributed as dist -from pydantic import BaseModel from torch.distributed.tensor import DTensor from torch.nn import Module from torch.optim import AdamW, Optimizer -from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticPytorchModuleOrListType from modalities.optimizers.optimizer_factory import _build_optimizer_groups_via_weight_decay_split -class EPAdamWConfig(BaseModel): - wrapped_model: PydanticPytorchModuleOrListType - device_mesh: PydanticDeviceMeshIFType - lr: float - betas: tuple[float, float] - eps: float - weight_decay: float - weight_decay_groups_excluded: list[str] - - class 
Config: - arbitrary_types_allowed = True - - def _get_ep_param_ids(model: Module) -> set: return {id(p) for m in model.modules() if getattr(m, "_ep_enabled", False) for p in m.parameters(recurse=False)} @@ -39,14 +24,6 @@ def _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_ class EPAdamW(Optimizer): - """ - ZeRO stage-1 for EP (DTensor) params + standard AdamW for dense params. - - Each dp_shard rank stores optimizer states for 1/dp_shard of the EP params. - After each step, updated EP param values are broadcast from owner to all ranks. - Dense params are handled by a separate AdamW (FSDP2 shards them independently). - """ - def __init__( self, model: Module, @@ -65,7 +42,6 @@ def __init__( ep_param_ids = _get_ep_param_ids(model) self._all_ep_params = [p for p in model.parameters() if id(p) in ep_param_ids] - # rank r owns params[r::dp_size] self._owned_ep_params = self._all_ep_params[self._dp_rank :: self._dp_size] dense_groups = _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_groups_excluded) @@ -76,8 +52,6 @@ def __init__( self._ep_adamw = None self._dense_adamw = AdamW(dense_groups, lr=lr, betas=betas, eps=eps) - # unified param groups for lr_scheduler compatibility: - # group 0 = all EP params, groups 1+ = dense weight-decay split ep_group = {"params": self._all_ep_params, "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay} all_groups = [ep_group] + [{**g, "lr": lr, "betas": betas, "eps": eps} for g in dense_groups] super().__init__(all_groups, {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}) @@ -89,7 +63,6 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() - # all-reduce for p in self._all_ep_params: if p.grad is None: continue @@ -101,20 +74,16 @@ def step(self, closure=None): dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=self._dp_group) p.grad.div_(self._dp_size) - # Sync lr if self._ep_adamw is not None: 
self._ep_adamw.param_groups[0]["lr"] = self.param_groups[0]["lr"] - for i, g in enumerate(self._dense_adamw.param_groups): - g["lr"] = self.param_groups[i + 1]["lr"] + for i, group in enumerate(self._dense_adamw.param_groups): + group["lr"] = self.param_groups[i + 1]["lr"] - # Update ep params if self._ep_adamw is not None: self._ep_adamw.step() - # Update dense params self._dense_adamw.step() - # broadcast updated EP param local tensors for i, p in enumerate(self._all_ep_params): owner_local_rank = i % self._dp_size owner_global_rank = dist.get_global_rank(self._dp_group, owner_local_rank) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 26df9b432..71eb2c8ad 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -36,6 +36,8 @@ DummyLRSchedulerConfig, DummyProgressSubscriberConfig, DummyResultSubscriberConfig, + EPAdamWConfig, + EPWrappedModelConfig, EvaluationResultToDiscSubscriberConfig, FSDP1ActivationCheckpointedModelConfig, FSDP1CheckpointedModelConfig, @@ -51,6 +53,7 @@ LinearWarmupCosineAnnealingLRSchedulerConfig, LLMDataLoaderConfig, MemMapDatasetConfig, + MoECrossEntropyLossConfig, OneCycleLRSchedulerConfig, PackedMemMapDatasetContinuousConfig, PackedMemMapDatasetMegatronConfig, @@ -96,6 +99,9 @@ from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory +from modalities.models.moe.loss_functions import MoECrossEntropyLoss +from modalities.models.moe.model_factory import get_ep_wrapped_model +from modalities.models.moe.qwen_model import QwenModel, QwenModelConfig from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory from 
modalities.models.parallelism.pipeline_parallelism_configs import ( ComponentSelectorFromPipelineConfig, @@ -109,12 +115,14 @@ ComposedInitializationRoutines, ComposedModelInitializationConfig, ) +from modalities.optimizers.ep_adamw import get_ep_adam_w from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory from modalities.optimizers.optimizer_factory import OptimizerFactory from modalities.optimizers.optimizer_list import OptimizersList from modalities.optimizers.scheduler_list import SchedulerList from modalities.running_env.fsdp.device_mesh import DeviceMeshConfig, get_device_mesh, get_parallel_degree from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer, PreTrainedSPTokenizer +from modalities.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper from modalities.training.gradient_clipping.fsdp_gradient_clipper import ( FSDP1GradientClipper, FSDP1LoggingOnlyGradientClipper, @@ -187,6 +195,8 @@ class ComponentEntity: COMPONENTS = [ # models ComponentEntity("model", "gpt2", GPT2ModelFactory.get_gpt2_model, GPT2LLMConfig), + ComponentEntity("model", "moe", QwenModel, QwenModelConfig), + ComponentEntity("model", "ep_wrapped", get_ep_wrapped_model, EPWrappedModelConfig), ComponentEntity( "model", "gpt2_tp", maybe_model_list(GPT2ModelFactory.get_gpt2_tensor_parallelized_model), GPT2ModelTPConfig ), @@ -250,6 +260,7 @@ class ComponentEntity: ), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), + ComponentEntity("loss", "moe_cross_entropy", MoECrossEntropyLoss, MoECrossEntropyLossConfig), # optimizers ComponentEntity( "optimizer", "adam", maybe_model_list_for_optimizer(OptimizerFactory.get_adam), AdamOptimizerConfig @@ -257,6 +268,7 @@ class ComponentEntity: ComponentEntity( "optimizer", "adam_w", maybe_model_list_for_optimizer(OptimizerFactory.get_adam_w), AdamWOptimizerConfig ), + ComponentEntity("optimizer", "ep_adam_w", 
maybe_model_list_for_optimizer(get_ep_adam_w), EPAdamWConfig), ComponentEntity( "optimizer", "fsdp1_checkpointed", @@ -402,6 +414,7 @@ class ComponentEntity: "gradient_clipper", "fsdp1_logging_only", FSDP1LoggingOnlyGradientClipper, FSDP1DummyGradientClipperConfig ), ComponentEntity("gradient_clipper", "fsdp2", FSDP2GradientClipper, FSDP2GradientClipperConfig), + ComponentEntity("gradient_clipper", "ep", EPGradientClipper, FSDP2GradientClipperConfig), ComponentEntity( "gradient_clipper", "fsdp2_logging_only", FSDP2LoggingOnlyGradientClipper, FSDP2DummyGradientClipperConfig ), diff --git a/moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py similarity index 91% rename from moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py rename to src/modalities/training/gradient_clipping/ep_gradient_clipper.py index 0581e3634..a2b6f25b8 100644 --- a/moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py @@ -16,7 +16,7 @@ class EPGradientClipper(FSDP2GradientClipper): - """FSDP2 clipper wrapper for EP adaptation""" + """FSDP2 clipper wrapper that handles EP DTensor gradients safely.""" def __init__( self, @@ -54,7 +54,6 @@ def clip_gradients(self) -> torch.Tensor: for grad in grads: grad_norm = torch.linalg.vector_norm(grad, ord=norm_type_val) if isinstance(grad_norm, DTensor): - # Reduce each partial norm inside its own mesh before aggregation. grad_norm = grad_norm.full_tensor() norm_scalars.append(grad_norm.to(first_device)) @@ -79,8 +78,6 @@ def clip_gradients(self) -> torch.Tensor: dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) total_norm **= 1.0 / self.norm_type.value - # do not use torch.nn.utils.clip_grads_with_norm_ here: it batches grads with - # torch._foreach_mul_, which fails when the list mixes DTensors from different meshes. 
clip_coef = self.max_norm / (total_norm + 1e-6) clip_coef_clamped = torch.clamp(clip_coef, max=1.0) From b70ec51df1aad54c15aaedd6777f2b3f1156f448 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Jun 2026 11:12:08 +0000 Subject: [PATCH 03/12] fix(moe): Fix dtype and state_dict mismatch --- .../checkpointing/stateful/app_state.py | 16 ++++++++++++++++ src/modalities/models/moe/qwen_model.py | 6 ++++++ 2 files changed, 22 insertions(+) diff --git a/src/modalities/checkpointing/stateful/app_state.py b/src/modalities/checkpointing/stateful/app_state.py index 2da3ab236..25c1efc91 100644 --- a/src/modalities/checkpointing/stateful/app_state.py +++ b/src/modalities/checkpointing/stateful/app_state.py @@ -184,6 +184,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: class OptimizerStateRetriever(StateRetrieverIF): + @staticmethod + def _uses_standard_optimizer_state_dict(app_state: AppState) -> bool: + """Checks whether the optimizer state dict follows the standard torch Optimizer schema. + + Standard optimizer state dicts contain top-level "state" and "param_groups" keys, + which are required by distributed optimizer checkpoint utilities. + """ + state_dict = app_state.optimizer.state_dict() + return isinstance(state_dict, dict) and "state" in state_dict and "param_groups" in state_dict + @staticmethod def get_state_dict(app_state: AppState) -> dict[str, Any]: """Returns the state dict of the optimizer in the AppState object. @@ -196,6 +206,10 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]: """ if isinstance(app_state.optimizer, OptimizersList): sd = app_state.optimizer.state_dict() + elif not OptimizerStateRetriever._uses_standard_optimizer_state_dict(app_state): + # Custom optimizers (e.g. EP wrappers) may not expose the standard torch + # optimizer format expected by get_optimizer_state_dict. 
+ sd = app_state.optimizer.state_dict() else: assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." sd = get_optimizer_state_dict( @@ -217,6 +231,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: """ if isinstance(app_state.optimizer, OptimizersList): app_state.optimizer.load_state_dict(state_dict) + elif not OptimizerStateRetriever._uses_standard_optimizer_state_dict(app_state): + app_state.optimizer.load_state_dict(state_dict) else: assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." set_optimizer_state_dict( diff --git a/src/modalities/models/moe/qwen_model.py b/src/modalities/models/moe/qwen_model.py index 47a810722..ac4cab752 100644 --- a/src/modalities/models/moe/qwen_model.py +++ b/src/modalities/models/moe/qwen_model.py @@ -166,6 +166,12 @@ def _forward_local(self, routed_input: torch.Tensor, num_tokens_per_expert: torc w1 = self.w1.to_local() if DTensor is not None and isinstance(self.w1, DTensor) else self.w1 w2 = self.w2.to_local() if DTensor is not None and isinstance(self.w2, DTensor) else self.w2 w3 = self.w3.to_local() if DTensor is not None and isinstance(self.w3, DTensor) else self.w3 + # F.linear requires matching dtypes between inputs and weights. Under mixed precision, + # routed_input can be BF16 while local expert weights remain FP32. 
+ if routed_input.dtype != w1.dtype: + w1 = w1.to(dtype=routed_input.dtype) + w2 = w2.to(dtype=routed_input.dtype) + w3 = w3.to(dtype=routed_input.dtype) local_num_tokens = ( num_tokens_per_expert.to_local() if DTensor is not None and isinstance(num_tokens_per_expert, DTensor) From e5edeeb715c644368ba0096a8985d7a2f3ad219f Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Jun 2026 11:15:46 +0000 Subject: [PATCH 04/12] chore: Remove outdated files --- moe/config/moe_ep_config.yaml | 357 ---------------------- moe/config/qwen_config.yaml | 365 ----------------------- moe/config/tokenization_config.yaml | 18 -- moe/scripts/train_ep.py | 155 ---------- {moe/scripts => scripts}/monitor_gpus.sh | 0 5 files changed, 895 deletions(-) delete mode 100644 moe/config/moe_ep_config.yaml delete mode 100644 moe/config/qwen_config.yaml delete mode 100644 moe/config/tokenization_config.yaml delete mode 100644 moe/scripts/train_ep.py rename {moe/scripts => scripts}/monitor_gpus.sh (100%) diff --git a/moe/config/moe_ep_config.yaml b/moe/config/moe_ep_config.yaml deleted file mode 100644 index 883e0dffb..000000000 --- a/moe/config/moe_ep_config.yaml +++ /dev/null @@ -1,357 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - prediction_key: logits - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - experiments_root_path: /leonardo/home/userexternal/gesposit/projects/modalities/moe/experiments - experiment_folder_path: ${settings.paths.experiments_root_path}/${settings.experiment_id} - checkpoint_saving_path: /leonardo_scratch/large/userexternal/gesposit/modalities/checkpoints - train_dataset_path: /leonardo_scratch/large/userexternal/gesposit/modalities/data/processed/fineweb_edu_num_docs_483606.pbin - intervals: - training_log_interval_in_steps: 1 - 
checkpointing_interval_in_steps: 1001 - evaluation_interval_in_steps: 1001 - consistency_enforcement: - enforce_tokens_per_step_consistency: true - enforce_last_step_logged: false - enforce_last_step_evaluated: false - enforce_last_step_checkpointed: false - step_profile: - gradient_accumulation_steps: 4 - local_train_micro_batch_size: 1 - sequence_length: 512 - dp_degree: - instance_key: dp_degree - pass_type: BY_REFERENCE - training_target: - num_target_tokens: - component_key: number_conversion - variant_key: num_tokens_from_num_steps - config: - num_steps: ${settings.training_target.num_target_steps} - dp_degree: - instance_key: dp_degree - pass_type: BY_REFERENCE - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} - sequence_length: ${settings.step_profile.sequence_length} - gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} - num_target_steps: 10 - training_progress: - global_num_seen_tokens: 0 - num_seen_steps: 0 - num_seen_samples: 0 - last_step: -1 - -collate_fn: - component_key: collate_fn - variant_key: gpt_2_llm_collator - config: - sample_key: ${settings.referencing_keys.sample_key} - target_key: ${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: packed_mem_map_dataset_continuous - config: - raw_data_path: ${settings.paths.train_dataset_path} - sequence_length: ${settings.step_profile.sequence_length} - sample_key: ${settings.referencing_keys.sample_key} - -train_dataloader: - component_key: data_loader - variant_key: default - config: - # we set num_workers to 0 so that the the data is loaded in the main process - # this is required to track how often the collator has been called - # in the library tutorials. Otherwise the collator will be copied for each worker - # and the number of call is out of scope. 
- num_workers: 0 - pin_memory: true - dataloader_tag: train - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.step_profile.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: resumable_distributed_sampler - config: - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: true - seed: 42 - drop_last: true - skip_num_global_samples: ${settings.training_progress.num_seen_samples} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: [] - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: dcp - config: - checkpoint_path: ${settings.paths.experiment_folder_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -loss_fn: - component_key: loss - variant_key: moe_cross_entropy - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${settings.referencing_keys.prediction_key} - model: - instance_key: model_raw - pass_type: BY_REFERENCE - -device_mesh: - component_key: device_mesh - variant_key: default - config: - device_type: cuda - data_parallel_replicate_degree: 1 - # Keep FSDP sharding on dp_shard and reserve tp for expert parallel. 
- data_parallel_shard_degree: -1 - tensor_parallel_degree: 32 - world_size: ${settings.cuda_env.world_size} - -dp_degree: - component_key: number_conversion - variant_key: parallel_degree - config: # get the parallel degree from the device mesh - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - parallelism_methods: [dp_shard, dp_replicate] - -app_state: - component_key: app_state - variant_key: raw - config: - model: - instance_key: initialized_model - pass_type: BY_REFERENCE - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - lr_scheduler: - instance_key: lr_scheduler - pass_type: BY_REFERENCE - -initialized_model: - component_key: model - variant_key: model_initialized - config: - model: - instance_key: fsdp_model - pass_type: BY_REFERENCE - model_initializer: - component_key: model_initialization - variant_key: composed - config: - model_type: gpt2 - weight_init_type: scaled - mean: 0.0 - std: 0.02 - num_layers: ${model_raw.config.num_layers} - -ep_model: - component_key: model - variant_key: ep_wrapped - config: - model: - instance_key: model_raw # Bypass torch.compile - MoE routing is incompatible - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - ep_mesh_dim_name: tp - block_names: [TransformerBlock] - -ac_model: - component_key: model - variant_key: activation_checkpointed # using modalities fsdp2 ac. 
should do to job also for ep layers - config: - model: - instance_key: ep_model - pass_type: BY_REFERENCE - ac_variant: full_activation_checkpointing - layers_fqn: layers - ac_fun_params: - ac_freq: 1 - -fsdp_model: - component_key: model - variant_key: fsdp2_wrapped - config: - model: - instance_key: ac_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - mixed_precision_settings: - param_dtype: BF_16 - reduce_dtype: BF_16 - reshard_after_forward: true - block_names: [TransformerBlock] - -compiled_model: - component_key: model - variant_key: compiled - config: - model: - instance_key: model_raw - pass_type: BY_REFERENCE - block_names: [TransformerBlock] - -model_raw: - component_key: model - variant_key: moe - config: - vocab_size: 32064 # to match a pretrained tokenizer - max_seq_len: 4096 - d_model: 4096 - n_heads: 32 - n_kv_heads: 8 - num_layers: 32 - d_ff: 14336 - moe_every_n_layers: 1 - moe_num_experts: 16 - moe_top_k: 2 - -lr_scheduler: - component_key: scheduler - variant_key: onecycle_lr - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 6e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: ${settings.training_target.num_target_steps} - pct_start: 0.02 - anneal_strategy: cos - last_epoch: ${settings.training_progress.last_step} - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 0.0001 - betas: [0.9, 0.95] - eps: 1e-8 - weight_decay: 1e-1 - weight_decay_groups_excluded: [embedding, layernorm] - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: ep - config: - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - max_norm: 1.0 - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - global_rank: 
${settings.cuda_env.global_rank} - num_seen_steps: ${settings.training_progress.num_seen_steps} - num_target_steps: ${settings.training_target.num_target_steps} - train_dataloader_tag: ${train_dataloader.config.dataloader_tag} - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: to_disc - config: - output_file_path: ${settings.paths.experiment_folder_path}/evaluation_results.jsonl - -mfu_calculator: - component_key: mfu_calculator - variant_key: gpt2 - config: - n_layer: ${model_raw.config.num_layers} - sequence_length: ${settings.step_profile.sequence_length} - n_embd: ${model_raw.config.d_model} - world_size: ${settings.cuda_env.world_size} - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -# profiler: -# component_key: steppable_profiler -# variant_key: combined -# config: -# profilers: -# - instance_key: kernel_profiler -# pass_type: BY_REFERENCE -# # - instance_key: memory_profiler -# # pass_type: BY_REFERENCE - -kernel_profiler: - component_key: steppable_profiler - variant_key: kernel_tracing - config: - num_wait_steps: 1 - num_warmup_steps: 1 - num_active_steps: 3 - profiler_activities: [CUDA] - profile_memory: true - record_shapes: true - with_stack: true - with_flops: true - with_modules: true - tracked_ranks: [0] - output_folder_path: ${settings.paths.experiment_folder_path}/profiling - -memory_profiler: - component_key: steppable_profiler - variant_key: memory_tracing - config: - memory_snapshot_folder_path: ${settings.paths.experiment_folder_path}/profiling - num_wait_steps: 1 - num_warmup_steps: 1 - num_active_steps: 3 - tracked_ranks: [0] \ No newline at end of file diff --git a/moe/config/qwen_config.yaml b/moe/config/qwen_config.yaml deleted file mode 100644 index 46b233dec..000000000 --- a/moe/config/qwen_config.yaml +++ /dev/null @@ -1,365 +0,0 @@ 
-settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - prediction_key: logits - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - experiments_root_path: /raid/s3/opengptx/user/richard-rutmann/experiments/modalities/moe_fsdp2 - experiment_folder_path: ${settings.paths.experiments_root_path}/${settings.experiment_id} - checkpoint_saving_path: /raid/s3/opengptx/user/richard-rutmann/experiments/modalities/moe_fsdp2/checkpoints - train_dataset_path: /raid/s3/opengptx/user/richard-rutmann/data/modalities/gpt2_tokenized/000_00000.pbin - intervals: - training_log_interval_in_steps: 1 - checkpointing_interval_in_steps: 1001 - evaluation_interval_in_steps: 1001 - consistency_enforcement: - enforce_tokens_per_step_consistency: true - enforce_last_step_logged: false - enforce_last_step_evaluated: false - enforce_last_step_checkpointed: false - step_profile: - gradient_accumulation_steps: 4 - local_train_micro_batch_size: 2 - sequence_length: 4096 - dp_degree: - instance_key: dp_degree - pass_type: BY_REFERENCE - training_target: - num_target_tokens: - component_key: number_conversion - variant_key: num_tokens_from_num_steps - config: - num_steps: ${settings.training_target.num_target_steps} - dp_degree: - instance_key: dp_degree - pass_type: BY_REFERENCE - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} - sequence_length: ${settings.step_profile.sequence_length} - gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} - num_target_steps: 10 - training_progress: - global_num_seen_tokens: 0 - num_seen_steps: 0 - num_seen_samples: 0 - last_step: -1 - -collate_fn: - component_key: collate_fn - variant_key: gpt_2_llm_collator - config: - sample_key: ${settings.referencing_keys.sample_key} - target_key: 
${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: packed_mem_map_dataset_continuous - config: - raw_data_path: ${settings.paths.train_dataset_path} - sequence_length: ${settings.step_profile.sequence_length} - sample_key: ${settings.referencing_keys.sample_key} - -train_dataloader: - component_key: data_loader - variant_key: default - config: - # we set num_workers to 0 so that the the data is loaded in the main process - # this is required to track how often the collator has been called - # in the library tutorials. Otherwise the collator will be copied for each worker - # and the number of call is out of scope. - num_workers: 0 - pin_memory: true - dataloader_tag: train - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.step_profile.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: resumable_distributed_sampler - config: - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: true - seed: 42 - drop_last: true - skip_num_global_samples: ${settings.training_progress.num_seen_samples} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: [] - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: dcp - config: - checkpoint_path: ${settings.paths.experiment_folder_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -loss_fn: - component_key: loss - variant_key: 
moe_cross_entropy - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${settings.referencing_keys.prediction_key} - model: - instance_key: model_raw - pass_type: BY_REFERENCE - -device_mesh: - component_key: device_mesh - variant_key: default - config: - device_type: cuda - data_parallel_replicate_degree: 1 - # Keep FSDP sharding on dp_shard and reserve tp for expert parallel. - data_parallel_shard_degree: -1 - tensor_parallel_degree: 4 - world_size: ${settings.cuda_env.world_size} - -dp_degree: - component_key: number_conversion - variant_key: parallel_degree - config: # get the parallel degree from the device mesh - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - parallelism_methods: [dp_shard, dp_replicate] - -app_state: - component_key: app_state - variant_key: raw - config: - model: - instance_key: initialized_model - pass_type: BY_REFERENCE - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - lr_scheduler: - instance_key: lr_scheduler - pass_type: BY_REFERENCE - -initialized_model: - component_key: model - variant_key: model_initialized - config: - model: - instance_key: fsdp_model - pass_type: BY_REFERENCE - model_initializer: - component_key: model_initialization - variant_key: composed - config: - model_type: gpt2 - weight_init_type: scaled - mean: 0.0 - std: 0.02 - num_layers: ${model_raw.config.num_layers} - -ep_model: - component_key: model - variant_key: ep_wrapped - config: - model: - instance_key: model_raw # Bypass torch.compile - MoE routing is incompatible - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - ep_mesh_dim_name: tp - block_names: [TransformerBlock] - -ac_model: - component_key: model - variant_key: activation_checkpointed # using modalities fsdp2 ac. 
should do to job also for ep layers - config: - model: - instance_key: ep_model - pass_type: BY_REFERENCE - ac_variant: full_activation_checkpointing - layers_fqn: layers - ac_fun_params: - ac_freq: 1 - -fsdp_model: - component_key: model - variant_key: fsdp2_wrapped - config: - model: - instance_key: ac_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - mixed_precision_settings: - param_dtype: BF_16 - reduce_dtype: BF_16 - reshard_after_forward: true - block_names: [TransformerBlock] - -compiled_model: - component_key: model - variant_key: compiled - config: - model: - instance_key: model_raw - pass_type: BY_REFERENCE - block_names: [TransformerBlock] - -model_raw: - component_key: model - variant_key: moe - config: - vocab_size: 50257 # to match a pretrained tokenizer, tochange - max_seq_len: 4096 - d_model: 2048 - d_ff: 6144 - n_heads: 32 - n_kv_heads: 8 - num_layers: 8 - attn_dropout: 0.0 - ffn_dropout: 0.0 - tie_embeddings: false - norm_eps: 1e-06 - rope_base: 1000000.0 - moe_num_experts: 128 - moe_d_ff: 768 - moe_top_k: 8 - -lr_scheduler: - component_key: scheduler - variant_key: onecycle_lr - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 6e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: ${settings.training_target.num_target_steps} - pct_start: 0.02 - anneal_strategy: cos - last_epoch: ${settings.training_progress.last_step} - -optimizer: - component_key: optimizer - variant_key: ep_adam_w - config: - lr: 0.0001 - betas: [0.9, 0.95] - eps: 1e-8 - weight_decay: 1e-1 - weight_decay_groups_excluded: [embedding, layernorm] - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: ep - config: - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - max_norm: 1.0 - 
device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - global_rank: ${settings.cuda_env.global_rank} - num_seen_steps: ${settings.training_progress.num_seen_steps} - num_target_steps: ${settings.training_target.num_target_steps} - train_dataloader_tag: ${train_dataloader.config.dataloader_tag} - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: to_disc - config: - output_file_path: ${settings.paths.experiment_folder_path}/evaluation_results.jsonl - -mfu_calculator: - component_key: mfu_calculator - variant_key: gpt2 - config: - n_layer: ${model_raw.config.num_layers} - sequence_length: ${settings.step_profile.sequence_length} - n_embd: ${model_raw.config.d_model} - world_size: ${settings.cuda_env.world_size} - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -# profiler: -# component_key: steppable_profiler -# variant_key: combined -# config: -# profilers: -# - instance_key: kernel_profiler -# pass_type: BY_REFERENCE -# # - instance_key: memory_profiler -# # pass_type: BY_REFERENCE - -kernel_profiler: - component_key: steppable_profiler - variant_key: kernel_tracing - config: - num_wait_steps: 1 - num_warmup_steps: 1 - num_active_steps: 3 - profiler_activities: [CUDA] - profile_memory: true - record_shapes: true - with_stack: true - with_flops: true - with_modules: true - tracked_ranks: [0] - output_folder_path: ${settings.paths.experiment_folder_path}/profiling - -memory_profiler: - component_key: steppable_profiler - variant_key: memory_tracing - config: - memory_snapshot_folder_path: ${settings.paths.experiment_folder_path}/profiling - num_wait_steps: 1 - num_warmup_steps: 1 - num_active_steps: 3 - tracked_ranks: [0] \ No newline at end of file diff --git 
a/moe/config/tokenization_config.yaml b/moe/config/tokenization_config.yaml deleted file mode 100644 index 5a4b8b781..000000000 --- a/moe/config/tokenization_config.yaml +++ /dev/null @@ -1,18 +0,0 @@ -settings: - src_path: data/raw/fineweb_edu_num_docs_483606.jsonl - dst_path: data/preprocessed/fineweb_edu_num_docs_483606.pbin - index_path: data/preprocessed/fineweb_edu_num_docs_483606.idx - jq_pattern: .text - num_cpus: ${node_env:num_cpus} - eod_token: <|endoftext|> - processing_batch_size: 10 - raw_samples_queue_size: 300 - processed_samples_queue_size: 300 - -tokenizer: - component_key: tokenizer - variant_key: pretrained_hf_tokenizer - config: - pretrained_model_name_or_path: data/tokenizer - padding: false - truncation: false \ No newline at end of file diff --git a/moe/scripts/train_ep.py b/moe/scripts/train_ep.py deleted file mode 100644 index 7c99eee03..000000000 --- a/moe/scripts/train_ep.py +++ /dev/null @@ -1,155 +0,0 @@ -# ruff: noqa: E402 - -import os -from pathlib import Path -from typing import cast - -import torch -import torch.distributed as dist -from torch.distributed.tensor import DTensor - -from modalities.__main__ import Main -from modalities.config.config import ProcessGroupBackendType -from modalities.config.instantiation_models import TrainingComponentsInstantiationModel -from modalities.running_env.cuda_env import CudaEnv - -cwd = Path(__file__).resolve().parent.parent -os.chdir(cwd) -CONFIG_FILE_PATH = cwd / "config" / "qwen_config.yaml" -EXPERIMENTS_ROOT_PATH = cwd / "results" / "debug" - - -# TODO solve this -def _enable_torchtitan_moe_permute_fallback() -> ( - None -): # VIBECODATA because of Triton C error with Python headers don't know what that is - """Avoid Triton JIT build for MoE permute indices on systems without Python dev headers.""" - try: - import torchtitan.models.moe.kernels as kernels - import torchtitan.models.moe.utils as moe_utils - except Exception: - return - - if getattr(kernels, "_modalities_fallback_enabled", 
False): - return - - def _fill_indices_torch( - tokens_per_expert_group: torch.Tensor, - start_index_values: torch.Tensor, - write_offsets: torch.Tensor, - experts_per_rank: int, - num_ranks: int, - max_len: int, - ) -> torch.Tensor: - device = tokens_per_expert_group.device - permuted_indices = torch.full((max_len,), -1, dtype=torch.int32, device=device) - - for e in range(experts_per_rank): - write_start = int(write_offsets[e].item()) - for r in range(num_ranks): - i = r * experts_per_rank + e - start_index = int(start_index_values[i].item()) - length = int(tokens_per_expert_group[i].item()) - if length > 0: - end_idx = min(write_start + length, max_len) - permuted_indices[write_start:end_idx] = torch.arange( - start_index, - start_index + (end_idx - write_start), - dtype=torch.int32, - device=device, - ) - write_start += length - - return permuted_indices - - _orig_generate_permute_indices = kernels.generate_permute_indices - - def _generate_permute_indices_no_triton( - tokens_per_expert_group: torch.Tensor, - experts_per_rank: int, - num_ranks: int, - max_len: int, - alignment: int, - use_cpu: bool = False, - ): - del use_cpu - start_index_values = torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group - total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) - total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment) - m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(torch.int32) - m_offsets = torch.cumsum(m_sizes, 0) - write_offsets = m_offsets - m_sizes - - permuted_indices = _fill_indices_torch( - tokens_per_expert_group=tokens_per_expert_group, - start_index_values=start_index_values, - write_offsets=write_offsets, - experts_per_rank=experts_per_rank, - num_ranks=num_ranks, - max_len=max_len, - ) - return permuted_indices, m_sizes, m_offsets.to(torch.int32) - - kernels.generate_permute_indices = _generate_permute_indices_no_triton - moe_utils.generate_permute_indices = 
_generate_permute_indices_no_triton - setattr(kernels, "_modalities_fallback_enabled", True) - setattr(kernels, "_modalities_generate_permute_indices_original", _orig_generate_permute_indices) - - -def debug_ep(model): - # Stima memoria teorica - total_params = sum(p.numel() for p in model.parameters()) - ep_params = sum( - p.numel() for m in model.modules() if getattr(m, "_ep_enabled", False) for p in m.parameters(recurse=False) - ) - dense_params = total_params - ep_params - - print(f"Params totali: {total_params/1e6:.0f}M") - print(f"Params EP (non shardati): {ep_params/1e6:.0f}M") - print(f"Params densi (shardati su dp_shard): {dense_params/1e6:.0f}M") - - rank = dist.get_rank() - free, total = torch.cuda.mem_get_info() - print(f"[rank{rank}] Memoria dopo init: {(total-free)/1e9:.1f} GB usati") - - -def main(): - _enable_torchtitan_moe_permute_fallback() - EXPERIMENTS_ROOT_PATH.mkdir(parents=True, exist_ok=True) - - with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): - modalities_main = Main( - config_path=CONFIG_FILE_PATH, - experiments_root_path=EXPERIMENTS_ROOT_PATH, - ) - - components = cast( - TrainingComponentsInstantiationModel, - modalities_main.build_components(components_model_type=TrainingComponentsInstantiationModel), - ) - - # WORKAROUNDS (wip) - # TODO implement those into moe code - # 1. some parameters remain on cpu - device = torch.device(f"cuda:{torch.cuda.current_device()}") - for name, param in components.model_raw.named_parameters(): - if param.device.type == "cpu": - param.data = param.data.to(device) - - # 2. cast EP params to bf16 — FSDP2 skips them via ignored_params, so they stay - # fp32 from model init. Cast here to match the MixedPrecisionPolicy applied to - # dense params (param_dtype=BF_16). Halves EP memory: 29 GB → 14.5 GB at tp=4. 
- for mod in components.model_raw.modules(): - if getattr(mod, "_ep_enabled", False): - for pname, p in list(mod._parameters.items()): - if isinstance(p, DTensor) and p.dtype != torch.bfloat16: - bf16_local = p.to_local().to(torch.bfloat16) - bf16_p = DTensor.from_local(bf16_local, p.device_mesh, p.placements, run_check=False) - mod._parameters[pname] = torch.nn.Parameter(bf16_p, requires_grad=p.requires_grad) - - debug_ep(components.model_raw) - modalities_main.run(components) - - -if __name__ == "__main__": - main() diff --git a/moe/scripts/monitor_gpus.sh b/scripts/monitor_gpus.sh similarity index 100% rename from moe/scripts/monitor_gpus.sh rename to scripts/monitor_gpus.sh From e09aa06f5f514d78c47a6d19751c8cd25a05b1f4 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 11 Jun 2026 09:19:42 +0000 Subject: [PATCH 05/12] docs: Add removed comments --- src/modalities/models/moe/loss_functions.py | 1 + src/modalities/models/moe/model_factory.py | 1 + src/modalities/optimizers/ep_adamw.py | 16 ++++++++++++++++ .../gradient_clipping/ep_gradient_clipper.py | 2 ++ 4 files changed, 20 insertions(+) diff --git a/src/modalities/models/moe/loss_functions.py b/src/modalities/models/moe/loss_functions.py index 642efb47a..57f30da69 100644 --- a/src/modalities/models/moe/loss_functions.py +++ b/src/modalities/models/moe/loss_functions.py @@ -31,6 +31,7 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: labels.contiguous().long().view(-1), ) + # Aux loss for layer in self.model.layers.values(): if hasattr(layer, "aux_loss") and layer.aux_loss is not None: loss = loss + layer.aux_loss.to(loss.dtype) diff --git a/src/modalities/models/moe/model_factory.py b/src/modalities/models/moe/model_factory.py index 406da1964..d5b95d9bb 100644 --- a/src/modalities/models/moe/model_factory.py +++ b/src/modalities/models/moe/model_factory.py @@ -10,6 +10,7 @@ from modalities.util import get_module_class_from_name +# TODO refactor these funtions into a utils def 
_resolve_ep_mesh(device_mesh: DeviceMesh, ep_mesh_dim_name: str | None) -> DeviceMesh: mesh_dim_names = tuple(device_mesh.mesh_dim_names or ()) diff --git a/src/modalities/optimizers/ep_adamw.py b/src/modalities/optimizers/ep_adamw.py index 2b5e72aae..006f9faf9 100644 --- a/src/modalities/optimizers/ep_adamw.py +++ b/src/modalities/optimizers/ep_adamw.py @@ -24,6 +24,14 @@ def _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_ class EPAdamW(Optimizer): + """ + ZeRO stage-1 for EP (DTensor) params + standard AdamW for dense params. + + Each dp_shard rank stores optimizer states for 1/dp_shard of the EP params. + After each step, updated EP param values are broadcast from owner to all ranks. + Dense params are handled by a separate AdamW (FSDP2 shards them independently). + """ + def __init__( self, model: Module, @@ -42,6 +50,7 @@ def __init__( ep_param_ids = _get_ep_param_ids(model) self._all_ep_params = [p for p in model.parameters() if id(p) in ep_param_ids] + # rank r owns params[r::dp_size] self._owned_ep_params = self._all_ep_params[self._dp_rank :: self._dp_size] dense_groups = _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_groups_excluded) @@ -52,6 +61,8 @@ def __init__( self._ep_adamw = None self._dense_adamw = AdamW(dense_groups, lr=lr, betas=betas, eps=eps) + # unified param groups for lr_scheduler compatibility: + # group 0 = all EP params, groups 1+ = dense weight-decay split ep_group = {"params": self._all_ep_params, "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay} all_groups = [ep_group] + [{**g, "lr": lr, "betas": betas, "eps": eps} for g in dense_groups] super().__init__(all_groups, {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}) @@ -63,6 +74,7 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + # all-reduce for p in self._all_ep_params: if p.grad is None: continue @@ -74,16 +86,20 @@ def step(self, closure=None): 
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=self._dp_group) p.grad.div_(self._dp_size) + # Sync lr if self._ep_adamw is not None: self._ep_adamw.param_groups[0]["lr"] = self.param_groups[0]["lr"] for i, group in enumerate(self._dense_adamw.param_groups): group["lr"] = self.param_groups[i + 1]["lr"] + # Update ep params if self._ep_adamw is not None: self._ep_adamw.step() + # Update dense params self._dense_adamw.step() + # broadcast updated EP param local tensors for i, p in enumerate(self._all_ep_params): owner_local_rank = i % self._dp_size owner_global_rank = dist.get_global_rank(self._dp_group, owner_local_rank) diff --git a/src/modalities/training/gradient_clipping/ep_gradient_clipper.py b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py index a2b6f25b8..2efc5ed58 100644 --- a/src/modalities/training/gradient_clipping/ep_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py @@ -78,6 +78,8 @@ def clip_gradients(self) -> torch.Tensor: dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) total_norm **= 1.0 / self.norm_type.value + # do not use torch.nn.utils.clip_grads_with_norm_ here: it batches grads with + # torch._foreach_mul_, which fails when the list mixes DTensors from different meshes. 
clip_coef = self.max_norm / (total_norm + 1e-6) clip_coef_clamped = torch.clamp(clip_coef, max=1.0) From baf94e945e94f69c37e4fb15db9b3203b5a2025c Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 11 Jun 2026 09:33:38 +0000 Subject: [PATCH 06/12] test: Add tests for MoE components --- tests/models/moe/__init__.py | 0 tests/models/moe/test_loss_functions.py | 59 ++++++++++++ tests/models/moe/test_qwen_model.py | 60 ++++++++++++ tests/optimizers/test_ep_adamw.py | 92 +++++++++++++++++++ .../test_ep_gradient_clipper.py | 50 ++++++++++ 5 files changed, 261 insertions(+) create mode 100644 tests/models/moe/__init__.py create mode 100644 tests/models/moe/test_loss_functions.py create mode 100644 tests/models/moe/test_qwen_model.py create mode 100644 tests/optimizers/test_ep_adamw.py create mode 100644 tests/training/gradient_clipping/test_ep_gradient_clipper.py diff --git a/tests/models/moe/__init__.py b/tests/models/moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/moe/test_loss_functions.py b/tests/models/moe/test_loss_functions.py new file mode 100644 index 000000000..346b69818 --- /dev/null +++ b/tests/models/moe/test_loss_functions.py @@ -0,0 +1,59 @@ +import torch +from torch.nn import CrossEntropyLoss + +from modalities.batch import InferenceResultBatch +from modalities.models.moe.loss_functions import MoECrossEntropyLoss + + +class DummyLayer: + def __init__(self, aux_loss): + self.aux_loss = aux_loss + + +class DummyModel: + def __init__(self, aux_losses: list[torch.Tensor | None]): + self.layers = {str(i): DummyLayer(aux) for i, aux in enumerate(aux_losses)} + + +def test_moe_cross_entropy_loss_adds_aux_losses(): + logits = torch.tensor( + [ + [[1.2, 0.3, -0.5], [0.1, 1.8, -0.3]], + [[0.5, -0.4, 1.1], [0.7, 0.2, -0.1]], + ], + dtype=torch.float32, + ) + targets = torch.tensor([[0, 1], [2, 0]], dtype=torch.long) + + batch = InferenceResultBatch( + targets={"targets": targets}, + predictions={"logits": logits}, + ) + + 
aux_1 = torch.tensor(0.2) + aux_2 = torch.tensor(0.3) + model = DummyModel(aux_losses=[aux_1, None, aux_2]) + loss_fn = MoECrossEntropyLoss(target_key="targets", prediction_key="logits", model=model) + + loss = loss_fn(batch) + base_ce = CrossEntropyLoss(reduction="mean")(logits.view(-1, logits.size(-1)), targets.view(-1)) + + assert torch.allclose(loss, base_ce + aux_1 + aux_2) + + +def test_moe_cross_entropy_loss_without_aux_matches_plain_ce(): + logits = torch.randn(2, 3, 5) + targets = torch.randint(0, 5, (2, 3), dtype=torch.long) + + batch = InferenceResultBatch( + targets={"labels": targets}, + predictions={"pred": logits}, + ) + + model = DummyModel(aux_losses=[None, None]) + loss_fn = MoECrossEntropyLoss(target_key="labels", prediction_key="pred", model=model) + + loss = loss_fn(batch) + expected = CrossEntropyLoss(reduction="mean")(logits.view(-1, logits.size(-1)), targets.view(-1)) + + assert torch.allclose(loss, expected) diff --git a/tests/models/moe/test_qwen_model.py b/tests/models/moe/test_qwen_model.py new file mode 100644 index 000000000..d4d90b592 --- /dev/null +++ b/tests/models/moe/test_qwen_model.py @@ -0,0 +1,60 @@ +import torch + +from modalities.models.moe.qwen_model import GroupedExperts, QwenModel + + +def _build_tiny_qwen_model() -> QwenModel: + return QwenModel( + vocab_size=32, + max_seq_len=16, + d_model=16, + n_heads=4, + n_kv_heads=2, + d_ff=32, + num_layers=1, + moe_d_ff=24, + moe_num_experts=4, + moe_top_k=2, + moe_capacity_factor=1.25, + moe_min_capacity=1, + moe_overflow_policy="residual", + moe_aux_loss_coef=0.01, + moe_z_loss_coef=0.0, + ) + + +def test_qwen_model_forward_dict_output_shape(): + torch.manual_seed(0) + model = _build_tiny_qwen_model() + batch_size, seq_len = 2, 5 + + input_ids = torch.randint(0, 32, (batch_size, seq_len), dtype=torch.long) + output = model({"input_ids": input_ids}) + + assert "logits" in output + assert output["logits"].shape == (batch_size, seq_len, 32) + + +def 
test_grouped_experts_forward_local_preserves_input_dtype(): + experts = GroupedExperts(num_experts=2, d_model=8, d_ff=12, ffn_dropout=0.0) + experts.reset_parameters() + + # Input in bf16 while expert weights are initialized in fp32. + routed_input = torch.randn(4, 8, dtype=torch.bfloat16) + num_tokens_per_expert = torch.tensor([2, 2], dtype=torch.long) + + out = experts._forward_local(routed_input=routed_input, num_tokens_per_expert=num_tokens_per_expert) + + assert out.shape == routed_input.shape + assert out.dtype == routed_input.dtype + + +def test_transformer_block_exposes_aux_loss_after_forward(): + torch.manual_seed(1) + model = _build_tiny_qwen_model() + input_ids = torch.randint(0, 32, (2, 4), dtype=torch.long) + + _ = model({"input_ids": input_ids}) + + first_layer = next(iter(model.layers.values())) + assert first_layer.aux_loss is not None diff --git a/tests/optimizers/test_ep_adamw.py b/tests/optimizers/test_ep_adamw.py new file mode 100644 index 000000000..bb366627b --- /dev/null +++ b/tests/optimizers/test_ep_adamw.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn + +from modalities.models.model import NNModel +from modalities.optimizers.ep_adamw import EPAdamW + + +class DummyDPShardMesh: + def __init__(self): + self._group = object() + + def get_group(self): + return self._group + + +class EPSubmodule(nn.Module): + def __init__(self): + super().__init__() + self.ep_weight = nn.Parameter(torch.tensor([1.0, -1.0])) + self._ep_enabled = True + + +class TinyModel(NNModel): + def __init__(self): + super().__init__(weight_decay_groups={"linear": ["linear"], "embedding": [], "layernorm": ["norm"]}) + self.linear = nn.Linear(2, 2, bias=False) + self.norm = nn.LayerNorm(2) + self.experts = EPSubmodule() + + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + x = inputs["x"] + return {"y": self.linear(x)} + + +def _patch_distributed_ops(monkeypatch): + from modalities.optimizers import ep_adamw as ep_adamw_module + + 
monkeypatch.setattr(ep_adamw_module.dist, "get_rank", lambda group=None: 0) + monkeypatch.setattr(ep_adamw_module.dist, "get_world_size", lambda group=None: 1) + monkeypatch.setattr(ep_adamw_module.dist, "all_reduce", lambda tensor, op=None, group=None: tensor) + monkeypatch.setattr(ep_adamw_module.dist, "broadcast", lambda tensor, src=0, group=None: tensor) + monkeypatch.setattr(ep_adamw_module.dist, "get_global_rank", lambda group, group_rank: group_rank) + + +def test_ep_adamw_state_dict_and_load_state_dict(monkeypatch): + _patch_distributed_ops(monkeypatch) + + model = TinyModel() + optimizer = EPAdamW( + model=model, + device_mesh={"dp_shard": DummyDPShardMesh()}, + lr=1e-2, + betas=(0.9, 0.95), + eps=1e-8, + weight_decay=0.1, + weight_decay_groups_excluded=["layernorm"], + ) + + state = optimizer.state_dict() + assert "ep_adamw" in state + assert "dense_adamw" in state + + optimizer.load_state_dict(state) + + +def test_ep_adamw_step_updates_parameters_and_zero_grad(monkeypatch): + _patch_distributed_ops(monkeypatch) + + model = TinyModel() + optimizer = EPAdamW( + model=model, + device_mesh={"dp_shard": DummyDPShardMesh()}, + lr=1e-2, + betas=(0.9, 0.95), + eps=1e-8, + weight_decay=0.1, + weight_decay_groups_excluded=["layernorm"], + ) + + before = [p.detach().clone() for p in model.parameters()] + for p in model.parameters(): + p.grad = torch.ones_like(p) + + optimizer.step() + after = list(model.parameters()) + + for p_before, p_after in zip(before, after): + assert not torch.allclose(p_before, p_after) + + optimizer.zero_grad(set_to_none=True) + for p in model.parameters(): + assert p.grad is None diff --git a/tests/training/gradient_clipping/test_ep_gradient_clipper.py b/tests/training/gradient_clipping/test_ep_gradient_clipper.py new file mode 100644 index 000000000..322ece24f --- /dev/null +++ b/tests/training/gradient_clipping/test_ep_gradient_clipper.py @@ -0,0 +1,50 @@ +import pytest +import torch +import torch.nn as nn + +from 
modalities.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper +from modalities.training.gradient_clipping.fsdp_gradient_clipper import GradientClippingMode + + +class MockModel(nn.Module): + def __init__(self): + super().__init__() + self.param1 = nn.Parameter(torch.tensor([1.0, 2.0])) + self.param2 = nn.Parameter(torch.tensor([3.0, 4.0])) + + +def test_ep_gradient_clipper_clips_gradients(): + model = MockModel() + model.param1.grad = torch.tensor([1.0, 1.0]) + model.param2.grad = torch.tensor([1.0, 1.0]) + + clipper = EPGradientClipper(model_parts=model, max_norm=1.0, norm_type=GradientClippingMode.P2_NORM) + total_norm = clipper.clip_gradients() + + assert torch.allclose(total_norm, torch.tensor(2.0)) + assert torch.allclose(model.param1.grad, torch.tensor([0.5, 0.5]), atol=1e-6) + assert torch.allclose(model.param2.grad, torch.tensor([0.5, 0.5]), atol=1e-6) + + +def test_ep_gradient_clipper_returns_zero_for_no_gradients(): + model = MockModel() + + clipper = EPGradientClipper(model_parts=model, max_norm=1.0, norm_type=GradientClippingMode.P2_NORM) + total_norm = clipper.clip_gradients() + + assert torch.allclose(total_norm.cpu(), torch.tensor(0.0)) + + +def test_ep_gradient_clipper_raises_for_nonfinite_norm(): + model = MockModel() + model.param1.grad = torch.tensor([float("nan"), 1.0]) + + clipper = EPGradientClipper( + model_parts=model, + max_norm=1.0, + norm_type=GradientClippingMode.P2_NORM, + error_if_nonfinite=True, + ) + + with pytest.raises(RuntimeError, match="non-finite"): + clipper.clip_gradients() From 2a3b81a36378f6e84eab07edaa8fbfe8f505739a Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 11 Jun 2026 09:58:49 +0000 Subject: [PATCH 07/12] test: Add e2e moe test --- tests/end2end_tests/test_moe_ep_fsdp2_e2e.py | 142 +++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 tests/end2end_tests/test_moe_ep_fsdp2_e2e.py diff --git a/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py 
b/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py new file mode 100644 index 000000000..70e7203e9 --- /dev/null +++ b/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py @@ -0,0 +1,142 @@ +import logging +import multiprocessing as py_mp +import os +import traceback +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.multiprocessing as mp + +from modalities.__main__ import Main, load_app_config_dict +from modalities.batch import EvaluationResultBatch +from modalities.config.config import ProcessGroupBackendType +from modalities.config.instantiation_models import TrainingComponentsInstantiationModel +from modalities.logging_broker.messages import Message +from tests.end2end_tests.custom_components import ( + MultiProcessingCudaEnv, + SaveAllResultSubscriber, + SaveAllResultSubscriberConfig, +) +from tests.utility import find_free_port, monitor_child_processes + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="This E2E test requires 4 CUDA devices.") +class TestMoEEPFSDP2E2E: + @staticmethod + def _patch_for_short_test_run(config_dict: dict[str, Any], checkpoint_root_path: Path) -> None: + # Keep runtime short while preserving EP + FSDP2 wiring. 
+ config_dict["settings"]["intervals"]["training_log_interval_in_steps"] = 1 + config_dict["settings"]["intervals"]["checkpointing_interval_in_steps"] = 1 + config_dict["settings"]["intervals"]["evaluation_interval_in_steps"] = 1000 + + config_dict["settings"]["step_profile"]["sequence_length"] = 64 + config_dict["settings"]["step_profile"]["local_train_micro_batch_size"] = 1 + config_dict["settings"]["step_profile"]["gradient_accumulation_steps"] = 1 + + config_dict["settings"]["training_target"]["num_target_tokens"] = 512 + config_dict["settings"]["training_target"]["num_target_steps"] = 2 + config_dict["lr_scheduler"]["config"]["total_steps"] = 2 + + config_dict["train_dataset"]["config"]["sequence_length"] = 64 + config_dict["test_dataset"]["config"]["sequence_length"] = 64 + config_dict["train_dataloader"]["config"]["num_workers"] = 0 + config_dict["test_dataloader"]["config"]["num_workers"] = 0 + config_dict["train_dataloader"]["config"]["pin_memory"] = False + config_dict["test_dataloader"]["config"]["pin_memory"] = False + + config_dict["settings"]["paths"]["checkpoint_saving_path"] = checkpoint_root_path + config_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][ + "checkpoint_path" + ] = checkpoint_root_path + + @staticmethod + def _worker_wrapper( + process_id: int, + world_size: int, + rdvz_port: int, + config_file_path: Path, + tmp_path: Path, + error_queue: Any, + ) -> None: + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=rdvz_port, + ): + try: + TestMoEEPFSDP2E2E._worker_impl( + process_id=process_id, + config_file_path=config_file_path, + tmp_path=tmp_path, + ) + except Exception as exc: + tb = traceback.format_exc() + logging.error(f"Process {process_id} failed: {exc}\n{tb}") + try: + error_queue.put((process_id, tb)) + except Exception: + logging.error("Failed to write child exception to queue.") + 
os._exit(1) + + @staticmethod + def _worker_impl(process_id: int, config_file_path: Path, tmp_path: Path) -> None: + experiment_id = "moe-ep-fsdp2-e2e" + checkpoint_root_path = tmp_path / experiment_id / "checkpoints" + cfg = load_app_config_dict( + config_file_path=config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id + ) + TestMoEEPFSDP2E2E._patch_for_short_test_run(cfg, checkpoint_root_path) + + main_obj = Main(config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id) + main_obj.config_dict = cfg + main_obj.add_custom_component( + component_key="results_subscriber", + variant_key="save_all", + custom_component=SaveAllResultSubscriber, + custom_config=SaveAllResultSubscriberConfig, + ) + main_obj.config_dict["evaluation_subscriber"]["variant_key"] = "save_all" + main_obj.config_dict["evaluation_subscriber"]["config"] = {} + + components: TrainingComponentsInstantiationModel = main_obj.build_components( + components_model_type=TrainingComponentsInstantiationModel + ) + + assert getattr(components.model_raw, "_ep_wrapped", False), "Expected EP wrapping marker on raw model." + first_layer = next(iter(components.model_raw.layers.values())) + assert getattr(first_layer.ffn.experts, "_ep_enabled", False), "Expected experts to be EP-enabled." + + main_obj.run(components) + + result_messages: list[Message[EvaluationResultBatch]] = components.evaluation_subscriber.message_list + assert len(result_messages) > 0, "Expected training messages in evaluation subscriber." + for message in result_messages: + loss_value = message.payload.losses["train loss avg"].value + assert torch.isfinite(loss_value), f"Found non-finite train loss: {loss_value}" + + if process_id == 0: + checkpoint_info_file_path = checkpoint_root_path / "last_checkpoint_info.json" + assert checkpoint_info_file_path.exists(), "Expected checkpoint info file from DCP save." 
+ + @staticmethod + def test_moe_ep_fsdp2_training_and_checkpointing(tmp_path: Path) -> None: + repo_root = Path(__file__).resolve().parents[2] + config_file_path = repo_root / "config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml" + + world_size = 4 + rdvz_port = find_free_port() + + manager = py_mp.Manager() + error_queue = manager.Queue() + proc_ctx = mp.spawn( + TestMoEEPFSDP2E2E._worker_wrapper, + args=(world_size, rdvz_port, config_file_path, tmp_path, error_queue), + nprocs=world_size, + join=False, + ) + + monitor_child_processes(manager, error_queue, proc_ctx) From f7f87b1ab07735079efad5bc105878fad84b8fb5 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 11 Jun 2026 11:53:34 +0000 Subject: [PATCH 08/12] refactor: Merge shared rotary embedding logic --- .../models/components/rotary_embedding.py | 126 +++++++++++++++ src/modalities/models/gpt2/gpt2_model.py | 149 ++++-------------- src/modalities/models/moe/qwen_model.py | 38 +++-- 3 files changed, 184 insertions(+), 129 deletions(-) create mode 100644 src/modalities/models/components/rotary_embedding.py diff --git a/src/modalities/models/components/rotary_embedding.py b/src/modalities/models/components/rotary_embedding.py new file mode 100644 index 000000000..c569a787e --- /dev/null +++ b/src/modalities/models/components/rotary_embedding.py @@ -0,0 +1,126 @@ +import math +from typing import Optional + +import torch + + +def compute_default_inv_freq(dim_model: int, base_freq: float, device: Optional[torch.device] = None) -> torch.Tensor: + return 1.0 / (base_freq ** (torch.arange(0, dim_model, 2, device=device).float() / dim_model)) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seq_length_dim: int) -> torch.Tensor: + cos = cos[:, :, : x.shape[seq_length_dim], :] + sin = sin[:, :, : x.shape[seq_length_dim], :] + return (x * cos) + 
(rotate_half(x) * sin) + + +def update_cos_sin_tables( + x: torch.Tensor, + inv_freq: torch.Tensor, + attention_scaling: float, + seq_length_dim: int, + seq_len_cached: Optional[int], + cos_cached: Optional[torch.Tensor], + sin_cached: Optional[torch.Tensor], +) -> tuple[int, torch.Tensor, torch.Tensor]: + seq_len = x.shape[seq_length_dim] + + if ( + seq_len != seq_len_cached + or cos_cached is None + or sin_cached is None + or cos_cached.device != x.device + or cos_cached.dtype != x.dtype + ): + t = torch.arange(seq_len, device=x.device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + cos_cached = (emb.cos() * attention_scaling)[None, None, :, :].to(x.dtype) + sin_cached = (emb.sin() * attention_scaling)[None, None, :, :].to(x.dtype) + seq_len_cached = seq_len + + return seq_len_cached, cos_cached, sin_cached + + +def compute_yarn_inv_freq_and_attention_scaling( + dim_model: int, + base_freq: float, + max_position_embeddings: int, + original_max_position_embeddings: int, + factor: Optional[float], + attention_factor: Optional[float], + mscale: Optional[float], + mscale_all_dim: Optional[float], + beta_fast: float, + beta_slow: float, + truncate: bool, + device: Optional[torch.device] = None, +) -> tuple[torch.Tensor, float]: + factor_float = ( + float(factor) if factor is not None else float(max_position_embeddings / original_max_position_embeddings) + ) + + def get_mscale(scale: float, mscale_value: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale_value * math.log(scale) + 1.0 + + if attention_factor is None: + if mscale is not None and mscale_all_dim is not None: + attention_factor = float( + get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) + ) + else: + attention_factor = get_mscale(factor_float) + + def find_correction_dim(num_rotations: float, dim: int, base: float, max_pos_emb: int) -> float: + return 
(dim * math.log(max_pos_emb / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range( + low_rot: float, + high_rot: float, + dim: int, + base: float, + max_pos_emb: int, + do_truncate: bool, + ) -> tuple[float, float]: + low = find_correction_dim(low_rot, dim, base, max_pos_emb) + high = find_correction_dim(high_rot, dim, base, max_pos_emb) + if do_truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: + if min_value == max_value: + max_value += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) + return torch.clamp(linear_func, 0, 1) + + pos_freqs = base_freq ** (torch.arange(0, dim_model, 2, device=device, dtype=torch.float) / dim_model) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor_float * pos_freqs) + + low, high = find_correction_range( + beta_fast, + beta_slow, + dim_model, + base_freq, + original_max_position_embeddings, + bool(truncate), + ) + + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim_model // 2).to( + device=device, dtype=torch.float + ) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, float(attention_factor) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index f43e6e87b..2e93b0be1 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -3,7 +3,7 @@ from abc import abstractmethod from enum import Enum from numbers import Real -from typing import Annotated, Literal, Optional, overload +from typing import Annotated, Literal, Optional, cast, overload import torch import torch.nn as nn @@ -17,6 +17,13 @@ RMSLayerNorm, RMSLayerNormConfig, ) +from 
modalities.models.components.rotary_embedding import ( + apply_rotary_pos_emb, + compute_default_inv_freq, + compute_yarn_inv_freq_and_attention_scaling, + rotate_half, + update_cos_sin_tables, +) from modalities.models.model import ActivationType, NNModel, SwiGLU from modalities.util import parse_enum_by_name @@ -221,9 +228,7 @@ def reset_parameters(self): if rope_type == "yarn": inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device) else: - inv_freq = 1.0 / ( - self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) - ) + inv_freq = compute_default_inv_freq(dim_model=self.dim_model, base_freq=self.base_freq, device=device) self.attention_scaling = 1.0 self.register_buffer("inv_freq", inv_freq) @@ -243,8 +248,7 @@ def rotate_half(self, x: torch.Tensor): torch.Tensor: The output tensor. """ - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) + return rotate_half(x) def apply_rotary_pos_emb(self, x, cos, sin): """ @@ -258,16 +262,7 @@ def apply_rotary_pos_emb(self, x, cos, sin): Returns: torch.Tensor: Tensor after applying rotary positional embedding. 
""" - # NOTE: This could probably be moved to Triton - - # Handle a possible sequence length mismatch in between q and k - cos = cos[:, :, : x.shape[self.seq_length_dim], :] - sin = sin[:, :, : x.shape[self.seq_length_dim], :] - - # the rotation is not really a rotation in higher dimensions, - # It merely swaps and negates certain dimensions to make - # the rotation below work - return (x * cos) + (self.rotate_half(x) * sin) + return apply_rotary_pos_emb(x=x, cos=cos, sin=sin, seq_length_dim=self.seq_length_dim) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor @@ -297,109 +292,31 @@ def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.T if self.max_position_embeddings is None: raise ValueError("YaRN requires max_position_embeddings to be set.") - original_max_position_embeddings = self.rope_scaling.original_max_position_embeddings - factor = self.rope_scaling.factor - if factor is None: - factor = self.max_position_embeddings / original_max_position_embeddings - factor_float = float(factor) - - attention_factor = self.rope_scaling.attention_factor - mscale_pair = None - if self.rope_scaling.mscale is not None and self.rope_scaling.mscale_all_dim is not None: - mscale_pair = (self.rope_scaling.mscale, self.rope_scaling.mscale_all_dim) - - beta_fast = self.rope_scaling.beta_fast - beta_slow = self.rope_scaling.beta_slow - truncate = self.rope_scaling.truncate - - def get_mscale(scale: float, mscale: float = 1.0) -> float: - """Return the YaRN mscale coefficient for a given scaling factor.""" - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - if attention_factor is None: - if mscale_pair is not None: - mscale, mscale_all_dim = mscale_pair - attention_factor = float( - get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) - ) - else: - attention_factor = get_mscale(factor_float) - - def find_correction_dim(num_rotations: float, dim: int, base: int, 
max_position_embeddings: int) -> float: - """Map a target number of rotations to a rotary dimension index.""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range( - low_rot: float, - high_rot: float, - dim: int, - base: int, - max_position_embeddings: int, - truncate: bool, - ) -> tuple[float, float]: - """Compute the lower and upper rotary-dimension correction bounds for YaRN.""" - low = find_correction_dim(low_rot, dim, base, max_position_embeddings) - high = find_correction_dim(high_rot, dim, base, max_position_embeddings) - if truncate: - low = math.floor(low) - high = math.ceil(high) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: - """Create a clamped linear ramp used to blend interpolation and extrapolation.""" - if min_value == max_value: - max_value += 0.001 - linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - dim = self.dim_model - base = self.base_freq - - pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (factor_float * pos_freqs) - - low, high = find_correction_range( - beta_fast, - beta_slow, - dim, - base, - original_max_position_embeddings, - bool(truncate), - ) - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor + return compute_yarn_inv_freq_and_attention_scaling( + dim_model=self.dim_model, + base_freq=self.base_freq, + max_position_embeddings=self.max_position_embeddings, + 
original_max_position_embeddings=self.rope_scaling.original_max_position_embeddings, + factor=self.rope_scaling.factor, + attention_factor=self.rope_scaling.attention_factor, + mscale=self.rope_scaling.mscale, + mscale_all_dim=self.rope_scaling.mscale_all_dim, + beta_fast=self.rope_scaling.beta_fast, + beta_slow=self.rope_scaling.beta_slow, + truncate=self.rope_scaling.truncate, + device=device, ) - return inv_freq, float(attention_factor) - def _update_cos_sin_tables(self, x): - # Update the cosine and sine tables. - seq_len = x.shape[self.seq_length_dim] - - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seq_len != self._seq_len_cached - or self._cos_cached is None - or self._sin_cached is None - or self._cos_cached.device != x.device - or self._cos_cached.dtype != x.dtype - ): - self._seq_len_cached = seq_len - t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32) - freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) - emb = torch.cat((freqs, freqs), dim=-1).to( - x.device - ) # here, we combine the two matrices (not zipping them). 
- self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype) - + self._seq_len_cached, self._cos_cached, self._sin_cached = update_cos_sin_tables( + x=x, + inv_freq=cast(torch.Tensor, self.inv_freq), + attention_scaling=self.attention_scaling, + seq_length_dim=self.seq_length_dim, + seq_len_cached=self._seq_len_cached, + cos_cached=self._cos_cached, + sin_cached=self._sin_cached, + ) return self._cos_cached, self._sin_cached diff --git a/src/modalities/models/moe/qwen_model.py b/src/modalities/models/moe/qwen_model.py index ac4cab752..b20dd05ef 100644 --- a/src/modalities/models/moe/qwen_model.py +++ b/src/modalities/models/moe/qwen_model.py @@ -6,6 +6,11 @@ import torch.nn.functional as F from pydantic import BaseModel +from modalities.models.components.rotary_embedding import ( + apply_rotary_pos_emb, + compute_default_inv_freq, + update_cos_sin_tables, +) from modalities.models.model import NNModel try: @@ -62,33 +67,40 @@ def __init__(self, head_dim: int, max_seq_len: int, base: float = 1000000.0): self.head_dim = head_dim self.max_seq_len = max_seq_len self.base = base + self.register_buffer("inv_freq", None, persistent=False) self.register_buffer("cos_cached", None, persistent=False) self.register_buffer("sin_cached", None, persistent=False) + self._seq_len_cached: Optional[int] = None def _compute_cache(self, device): - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim)) - t = torch.arange(self.max_seq_len, device=device).float() - freqs = torch.outer(t, inv_freq) - emb = torch.cat([freqs, freqs], dim=-1) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] + self.inv_freq = compute_default_inv_freq(dim_model=self.head_dim, base_freq=self.base, device=device) + self._seq_len_cached = None + self.cos_cached = None + self.sin_cached = None def 
forward(self, x: torch.Tensor, seq_len: int): - if self.cos_cached is None: + if self.inv_freq is None: self._compute_cache(x.device) + self._seq_len_cached, self.cos_cached, self.sin_cached = update_cos_sin_tables( + x=x, + inv_freq=self.inv_freq, + attention_scaling=1.0, + seq_length_dim=-2, + seq_len_cached=self._seq_len_cached, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + ) return ( self.cos_cached[:, :, :seq_len, :].to(x.dtype), self.sin_cached[:, :, :seq_len, :].to(x.dtype), ) -def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return torch.cat([-x2, x1], dim=-1) - - def apply_rotary_emb(q, k, cos, sin): - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return ( + apply_rotary_pos_emb(x=q, cos=cos, sin=sin, seq_length_dim=-2), + apply_rotary_pos_emb(x=k, cos=cos, sin=sin, seq_length_dim=-2), + ) class GroupedQueryAttention(nn.Module): From 3eb373e10ec04b2196975275d977593423c7cbb2 Mon Sep 17 00:00:00 2001 From: Giovanni Esposito Date: Wed, 17 Jun 2026 06:40:12 +0200 Subject: [PATCH 09/12] implemented torch native expert parallelism, torchtitan dependency removed --- src/modalities/models/moe/model_factory.py | 7 +- .../models/parallelism/expert_parallelism.py | 151 ++++++++++++++++++ 2 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 src/modalities/models/parallelism/expert_parallelism.py diff --git a/src/modalities/models/moe/model_factory.py b/src/modalities/models/moe/model_factory.py index d5b95d9bb..4348add09 100644 --- a/src/modalities/models/moe/model_factory.py +++ b/src/modalities/models/moe/model_factory.py @@ -5,8 +5,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp import MixedPrecisionPolicy from torch.distributed.device_mesh import DeviceMesh -from torchtitan.distributed.expert_parallel import ExpertParallel - +from modalities.models.parallelism.expert_parallelism import ExpertParallel from modalities.util import 
get_module_class_from_name @@ -65,7 +64,7 @@ def _attach_ep_metadata(module, ep_mesh) -> None: setattr(module, "_ep_rank", ep_mesh.get_local_rank()) -def _apply_torchtitan_ep(module, ep_mesh) -> None: +def _apply_ep(module, ep_mesh) -> None: module.experts = ExpertParallel()._apply(module.experts, ep_mesh) setattr(module.experts, "_ep_enabled", True) @@ -116,7 +115,7 @@ def get_ep_wrapped_model( _validate_moe_block_for_ep(ep_target_module) _attach_ep_metadata(ep_target_module, ep_mesh) - _apply_torchtitan_ep(ep_target_module, ep_mesh) + _apply_ep(ep_target_module, ep_mesh) wrapped_blocks += 1 diff --git a/src/modalities/models/parallelism/expert_parallelism.py b/src/modalities/models/parallelism/expert_parallelism.py new file mode 100644 index 000000000..a2030e3cc --- /dev/null +++ b/src/modalities/models/parallelism/expert_parallelism.py @@ -0,0 +1,151 @@ +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. + +import torch +import torch.nn as nn +from torch import Tensor +from torch.distributed._functional_collectives import all_to_all_single, all_to_all_single_autograd +from torch.distributed.tensor import DeviceMesh, Shard, distribute_module, distribute_tensor + + +def _permute_tokens( + x: Tensor, + num_tokens_per_expert_group: Tensor, + ep_degree: int, + num_local_experts: int, +) -> tuple[tuple, Tensor, Tensor, Tensor]: + """ + Reorder tokens from the post-all-to-all layout to per-local-expert contiguous layout. + + After the all-to-all, received tokens are ordered as: + [e0_from_rank0 tokens, e1_from_rank0 tokens, ..., e0_from_rank1 tokens, ...] + + We reorder to: + [all tokens for local_expert_0, all tokens for local_expert_1, ...] + + Returns (original_shape, x_permuted, permuted_indices, new_num_tokens_per_expert). 
+ """ + counts = num_tokens_per_expert_group.view(ep_degree, num_local_experts) # (ep_degree, num_local_experts) + + flat_counts = counts.flatten() # length = ep_degree * num_local_experts + + offsets = flat_counts.cumsum(0) - flat_counts + + # build permuted_indices + indices_per_expert: list[Tensor] = [] + for e in range(num_local_experts): + for r in range(ep_degree): + count = int(counts[r, e].item()) + if count > 0: + start = int(offsets[r * num_local_experts + e].item()) + indices_per_expert.append( + torch.arange(start, start + count, device=x.device, dtype=torch.long) + ) + + if indices_per_expert: + permuted_indices = torch.cat(indices_per_expert) + else: + permuted_indices = torch.zeros(0, dtype=torch.long, device=x.device) + + new_num_tokens_per_expert = counts.sum(dim=0) # (num_local_experts,) + original_shape = x.shape + x_permuted = x[permuted_indices] if permuted_indices.numel() > 0 else x.new_zeros((0, x.shape[-1])) + return original_shape, x_permuted, permuted_indices, new_num_tokens_per_expert + + +def _unpermute_tokens(out: Tensor, original_shape: tuple, permuted_indices: Tensor) -> Tensor: + """ + Inverse of _permute_tokens: scatter expert outputs back to the all-to-all layout. + """ + out_unpermuted = out.new_zeros(original_shape) + if permuted_indices.numel() > 0: + out_unpermuted[permuted_indices] = out + return out_unpermuted + + +class ExpertParallel: + """ + Expert Parallelism for grouped-expert MoE layers. + + Shards GroupedExperts parameters on the expert dimension (Shard(0)) across EP ranks, + and wraps forward() with all-to-all token dispatch/combine collectives. 
+ + Usage: + module.experts = ExpertParallel()._apply(module.experts, ep_mesh) + """ + + def __init__(self) -> None: + self.input_splits: list[int] | None = None + self.output_splits: list[int] | None = None + self.original_shape: tuple | None = None + self.permuted_indices: Tensor | None = None + + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: + for param_name, param in mod.named_parameters(recurse=False): + mod.register_parameter( + param_name, + nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])), + ) + + def _token_dispatch( + self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh + ) -> tuple[Tensor, Tensor]: + routed_input, num_tokens_per_expert = inputs + ep_degree = device_mesh.shape[0] + num_local_experts = num_tokens_per_expert.shape[0] // ep_degree + + with torch.no_grad(): + num_tokens_per_expert_group = all_to_all_single( + num_tokens_per_expert, None, None, group=device_mesh.get_group() + ) + + num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor( + num_tokens_per_expert_group + ) + input_splits = ( + num_tokens_per_expert.view(ep_degree, -1) + .sum(dim=1) + .to(torch.device("cpu"), non_blocking=True) + ) + + output_splits = ( + num_tokens_per_expert_group.view(ep_degree, -1) + .sum(dim=1) + .to(torch.device("cpu"), non_blocking=False) + ) + self.input_splits = input_splits.tolist() + self.output_splits = output_splits.tolist() + + routed_input = all_to_all_single_autograd( + routed_input, + self.output_splits, + self.input_splits, + device_mesh.get_group(), + ) + + self.original_shape, routed_input, self.permuted_indices, num_tokens_per_expert_group = ( + _permute_tokens(routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts) + ) + return routed_input, num_tokens_per_expert_group + + def _token_combine( + self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh + ) -> Tensor: + routed_output = _unpermute_tokens(routed_output, self.original_shape, 
self.permuted_indices) + routed_output = all_to_all_single_autograd( + routed_output, + self.input_splits, + self.output_splits, + device_mesh.get_group(), + ) + return routed_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) From e0217fbcc84d9340d0fd7f80fa1b156c0325d349 Mon Sep 17 00:00:00 2001 From: Giovanni Esposito Date: Wed, 17 Jun 2026 07:36:01 +0200 Subject: [PATCH 10/12] feat: add dedicated EP dimension to device mesh --- .../config_lorem_ipsum_long_moe_ep_fsdp2.yaml | 3 +- src/modalities/config/config.py | 1 - src/modalities/models/moe/model_factory.py | 32 ++++--------------- .../running_env/fsdp/device_mesh.py | 11 +++++++ 4 files changed, 18 insertions(+), 29 deletions(-) diff --git a/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml index 577656b3b..073ef1a3f 100644 --- a/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml +++ b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml @@ -184,7 +184,7 @@ device_mesh: device_type: cuda data_parallel_replicate_degree: 1 data_parallel_shard_degree: -1 - tensor_parallel_degree: 4 + expert_parallel_degree: 4 world_size: ${settings.cuda_env.world_size} dp_degree: @@ -238,7 +238,6 @@ ep_model: device_mesh: instance_key: device_mesh pass_type: BY_REFERENCE - ep_mesh_dim_name: tp block_names: [TransformerBlock] ac_model: diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index af280a0a4..fada33791 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -338,7 +338,6 @@ class EPWrappedModelConfig(BaseModel): model: PydanticPytorchModuleOrListType block_names: list[str] device_mesh: PydanticDeviceMeshIFType - ep_mesh_dim_name: str | None = None class 
DebuggingEnrichedModelConfig(BaseModel): diff --git a/src/modalities/models/moe/model_factory.py b/src/modalities/models/moe/model_factory.py index 4348add09..2b56f8f51 100644 --- a/src/modalities/models/moe/model_factory.py +++ b/src/modalities/models/moe/model_factory.py @@ -5,28 +5,12 @@ import torch.nn as nn from torch.distributed._composable.fsdp import MixedPrecisionPolicy from torch.distributed.device_mesh import DeviceMesh + from modalities.models.parallelism.expert_parallelism import ExpertParallel +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method from modalities.util import get_module_class_from_name -# TODO refactor these funtions into a utils -def _resolve_ep_mesh(device_mesh: DeviceMesh, ep_mesh_dim_name: str | None) -> DeviceMesh: - mesh_dim_names = tuple(device_mesh.mesh_dim_names or ()) - - if ep_mesh_dim_name is not None: - if ep_mesh_dim_name not in mesh_dim_names: - raise ValueError(f"ep_mesh_dim_name='{ep_mesh_dim_name}' not in mesh_dim_names={mesh_dim_names}") - return device_mesh[ep_mesh_dim_name] - - if len(mesh_dim_names) <= 1: - return device_mesh - - raise ValueError( - "DeviceMesh has multiple dimensions. Pass ep_mesh_dim_name explicitly. 
" - f"Available dimensions: {mesh_dim_names}" - ) - - def _validate_moe_block_for_ep(module) -> None: if not hasattr(module, "experts"): raise ValueError(f"Module {type(module).__name__} has no 'experts' attribute") @@ -64,16 +48,10 @@ def _attach_ep_metadata(module, ep_mesh) -> None: setattr(module, "_ep_rank", ep_mesh.get_local_rank()) -def _apply_ep(module, ep_mesh) -> None: - module.experts = ExpertParallel()._apply(module.experts, ep_mesh) - setattr(module.experts, "_ep_enabled", True) - - def get_ep_wrapped_model( model, block_names: list[str], device_mesh: DeviceMesh, - ep_mesh_dim_name: str | None = None, mp_param_dtype=torch.bfloat16, mp_reduce_dtype=torch.bfloat16, ) -> nn.Module: @@ -97,7 +75,7 @@ def get_ep_wrapped_model( if len(block_types) == 0: raise ValueError(f"None of the requested MoE block names were found: {block_names}") - ep_mesh = _resolve_ep_mesh(device_mesh, ep_mesh_dim_name) + ep_mesh = get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.EP) MixedPrecisionPolicy(param_dtype=mp_param_dtype, reduce_dtype=mp_reduce_dtype) wrapped_blocks = 0 @@ -115,7 +93,9 @@ def get_ep_wrapped_model( _validate_moe_block_for_ep(ep_target_module) _attach_ep_metadata(ep_target_module, ep_mesh) - _apply_ep(ep_target_module, ep_mesh) + + ep_target_module.experts = ExpertParallel()._apply(ep_target_module.experts, ep_mesh) + setattr(ep_target_module.experts, "_ep_enabled", True) wrapped_blocks += 1 diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index cd456938c..f4f3b7e26 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -21,6 +21,7 @@ class DeviceMeshConfig(BaseModel): tensor_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1 pipeline_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1 context_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1 + expert_parallel_degree: Annotated[int, 
Field(strict=True, gt=0)] = 1 enable_loss_parallel: Optional[bool] = False world_size: Annotated[int, Field(strict=True, gt=0)] @@ -28,6 +29,7 @@ class DeviceMeshConfig(BaseModel): def _validate(self): for d in ( self.context_parallel_degree, + self.expert_parallel_degree, self.tensor_parallel_degree, self.pipeline_parallel_degree, ): @@ -50,6 +52,7 @@ def _validate(self): self.data_parallel_shard_degree = self.world_size // ( self.data_parallel_replicate_degree * self.context_parallel_degree + * self.expert_parallel_degree * self.tensor_parallel_degree * self.pipeline_parallel_degree ) @@ -58,12 +61,14 @@ def _validate(self): self.data_parallel_replicate_degree = self.world_size // ( self.data_parallel_shard_degree * self.context_parallel_degree + * self.expert_parallel_degree * self.tensor_parallel_degree * self.pipeline_parallel_degree ) if ( self.data_parallel_shard_degree * self.data_parallel_replicate_degree + * self.expert_parallel_degree * self.tensor_parallel_degree * self.pipeline_parallel_degree * self.context_parallel_degree @@ -72,6 +77,7 @@ def _validate(self): raise ConfigError( f"Invalid parallel dims: data_parallel_shard_degree({self.data_parallel_shard_degree}) * " f"data_parallel_replicate_degree({self.data_parallel_replicate_degree}) * " + f"expert_parallel_degree({self.expert_parallel_degree}) * " f"tensor_parallel_degree({self.tensor_parallel_degree}) *" f"* pipeline_parallel_degree({self.pipeline_parallel_degree}) *" f"context_parallel_degree({self.context_parallel_degree})!= WORLD_SIZE({self.world_size})" @@ -85,6 +91,7 @@ class ParallelismDegrees(Enum): DP_REPLICATE = "dp_replicate" DP_SHARD = "dp_shard" CP = "cp" + EP = "ep" TP = "tp" PP = "pp" @@ -96,6 +103,7 @@ def get_device_mesh( tensor_parallel_degree: int, pipeline_parallel_degree: int, context_parallel_degree: int, + expert_parallel_degree: int, enable_loss_parallel: bool, world_size: int, ) -> DeviceMesh: @@ -109,6 +117,7 @@ def get_device_mesh( tensor_parallel_degree (int): The 
tensor parallel degree. pipeline_parallel_degree (int): The pipeline parallel degree. context_parallel_degree (int): The context parallel degree. + expert_parallel_degree (int): The expert parallel degree. enable_loss_parallel (bool): Whether to enable loss parallelism. world_size (int): The world size. @@ -123,6 +132,7 @@ def get_device_mesh( data_parallel_replicate_degree, data_parallel_shard_degree, context_parallel_degree, + expert_parallel_degree, tensor_parallel_degree, ], [ @@ -130,6 +140,7 @@ def get_device_mesh( ParallelismDegrees.DP_REPLICATE.value, ParallelismDegrees.DP_SHARD.value, ParallelismDegrees.CP.value, + ParallelismDegrees.EP.value, ParallelismDegrees.TP.value, ], strict=True, From f638aecef3765206689a0de173cfae3e91926057 Mon Sep 17 00:00:00 2001 From: Giovanni Esposito Date: Wed, 17 Jun 2026 08:15:32 +0200 Subject: [PATCH 11/12] fix: mixed precision bug in ep layers --- .../config_lorem_ipsum_long_moe_ep_fsdp2.yaml | 3 + src/modalities/config/config.py | 1 + src/modalities/models/moe/model_factory.py | 91 ++++++++----------- 3 files changed, 43 insertions(+), 52 deletions(-) diff --git a/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml index 073ef1a3f..b17bb89e1 100644 --- a/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml +++ b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml @@ -238,6 +238,9 @@ ep_model: device_mesh: instance_key: device_mesh pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 block_names: [TransformerBlock] ac_model: diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index fada33791..fea034546 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -338,6 +338,7 @@ class EPWrappedModelConfig(BaseModel): model: PydanticPytorchModuleOrListType block_names: list[str] device_mesh: PydanticDeviceMeshIFType + 
mixed_precision_settings: FSDP2MixedPrecisionSettings class DebuggingEnrichedModelConfig(BaseModel): diff --git a/src/modalities/models/moe/model_factory.py b/src/modalities/models/moe/model_factory.py index 2b56f8f51..35e9b7e10 100644 --- a/src/modalities/models/moe/model_factory.py +++ b/src/modalities/models/moe/model_factory.py @@ -1,59 +1,21 @@ import warnings -import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed._composable.fsdp import MixedPrecisionPolicy from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor from modalities.models.parallelism.expert_parallelism import ExpertParallel +from modalities.running_env.env_utils import FSDP2MixedPrecisionSettings from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method from modalities.util import get_module_class_from_name -def _validate_moe_block_for_ep(module) -> None: - if not hasattr(module, "experts"): - raise ValueError(f"Module {type(module).__name__} has no 'experts' attribute") - - experts = module.experts - required_attrs = ["w1", "w2"] - missing = [attr for attr in required_attrs if not hasattr(experts, attr)] - if missing: - raise ValueError( - f"Module {type(module).__name__}.experts is not grouped-experts compatible. Missing: {missing}" - ) - - if experts.w1.ndim != 3 or experts.w2.ndim != 3: - raise ValueError( - f"Expected grouped expert parameters with ndim=3. 
Got w1.ndim={experts.w1.ndim}, " - f"w2.ndim={experts.w2.ndim}" - ) - - -def _get_ep_target_module(module): - if hasattr(module, "experts"): - return module - - ffn = getattr(module, "ffn", None) - if ffn is not None and hasattr(ffn, "experts"): - return ffn - - return None - - -def _attach_ep_metadata(module, ep_mesh) -> None: - setattr(module, "_ep_mesh", ep_mesh) - setattr(module, "_ep_group", ep_mesh.get_group()) - setattr(module, "_ep_size", ep_mesh.size()) - setattr(module, "_ep_rank", ep_mesh.get_local_rank()) - - def get_ep_wrapped_model( model, block_names: list[str], device_mesh: DeviceMesh, - mp_param_dtype=torch.bfloat16, - mp_reduce_dtype=torch.bfloat16, + mixed_precision_settings: FSDP2MixedPrecisionSettings, ) -> nn.Module: block_types = [] missing_block_names = [] @@ -76,34 +38,59 @@ def get_ep_wrapped_model( raise ValueError(f"None of the requested MoE block names were found: {block_names}") ep_mesh = get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.EP) - MixedPrecisionPolicy(param_dtype=mp_param_dtype, reduce_dtype=mp_reduce_dtype) + target_dtype = mixed_precision_settings.param_dtype.value wrapped_blocks = 0 for module in model.modules(): if isinstance(module, block_types): - ep_target_module = _get_ep_target_module(module) - if ep_target_module is None: + if hasattr(module, "experts"): + ep_target = module + elif (ffn := getattr(module, "ffn", None)) is not None and hasattr(ffn, "experts"): + ep_target = ffn + else: raise ValueError( f"Module {type(module).__name__} has no EP-compatible experts location. " "Expected `experts` or `ffn.experts`." 
) - if getattr(ep_target_module, "_ep_enabled", False): + if getattr(ep_target, "_ep_enabled", False): continue - _validate_moe_block_for_ep(ep_target_module) - _attach_ep_metadata(ep_target_module, ep_mesh) + experts = ep_target.experts + missing = [a for a in ("w1", "w2") if not hasattr(experts, a)] + if missing: + raise ValueError( + f"Module {type(ep_target).__name__}.experts is not grouped-experts compatible. Missing: {missing}" + ) + if experts.w1.ndim != 3 or experts.w2.ndim != 3: + raise ValueError( + f"Expected grouped expert parameters with ndim=3. Got w1.ndim={experts.w1.ndim}, " + f"w2.ndim={experts.w2.ndim}" + ) + + ep_target._ep_mesh = ep_mesh + ep_target._ep_group = ep_mesh.get_group() + ep_target._ep_size = ep_mesh.size() + ep_target._ep_rank = ep_mesh.get_local_rank() + + ep_target.experts = ExpertParallel()._apply(ep_target.experts, ep_mesh) + ep_target.experts._ep_enabled = True - ep_target_module.experts = ExpertParallel()._apply(ep_target_module.experts, ep_mesh) - setattr(ep_target_module.experts, "_ep_enabled", True) + for pname, p in list(ep_target.experts._parameters.items()): + if isinstance(p, DTensor) and p.dtype != target_dtype: + local = p.to_local().to(target_dtype) + ep_target.experts._parameters[pname] = nn.Parameter( + DTensor.from_local(local, p.device_mesh, p.placements, run_check=False), + requires_grad=p.requires_grad, + ) wrapped_blocks += 1 if wrapped_blocks == 0: raise ValueError(f"No blocks matched the requested types: {[t.__name__ for t in block_types]}") - setattr(model, "_ep_wrapped", True) - setattr(model, "_ep_mesh", ep_mesh) - setattr(model, "_ep_num_wrapped_blocks", wrapped_blocks) + model._ep_wrapped = True + model._ep_mesh = ep_mesh + model._ep_num_wrapped_blocks = wrapped_blocks return model From 3730c0906085ecb3dc4fd1ce6f3ccd500b958123 Mon Sep 17 00:00:00 2001 From: Giovanni Esposito Date: Wed, 17 Jun 2026 09:27:10 +0200 Subject: [PATCH 12/12] fix: apply black formatting to expert_parallel --- 
.../models/parallelism/expert_parallelism.py | 34 ++++++------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/src/modalities/models/parallelism/expert_parallelism.py b/src/modalities/models/parallelism/expert_parallelism.py index a2030e3cc..eaa33937c 100644 --- a/src/modalities/models/parallelism/expert_parallelism.py +++ b/src/modalities/models/parallelism/expert_parallelism.py @@ -29,7 +29,7 @@ def _permute_tokens( counts = num_tokens_per_expert_group.view(ep_degree, num_local_experts) # (ep_degree, num_local_experts) flat_counts = counts.flatten() # length = ep_degree * num_local_experts - + offsets = flat_counts.cumsum(0) - flat_counts # build permuted_indices @@ -39,9 +39,7 @@ def _permute_tokens( count = int(counts[r, e].item()) if count > 0: start = int(offsets[r * num_local_experts + e].item()) - indices_per_expert.append( - torch.arange(start, start + count, device=x.device, dtype=torch.long) - ) + indices_per_expert.append(torch.arange(start, start + count, device=x.device, dtype=torch.long)) if indices_per_expert: permuted_indices = torch.cat(indices_per_expert) @@ -88,9 +86,7 @@ def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> N nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])), ) - def _token_dispatch( - self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh - ) -> tuple[Tensor, Tensor]: + def _token_dispatch(self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh) -> tuple[Tensor, Tensor]: routed_input, num_tokens_per_expert = inputs ep_degree = device_mesh.shape[0] num_local_experts = num_tokens_per_expert.shape[0] // ep_degree @@ -99,20 +95,14 @@ def _token_dispatch( num_tokens_per_expert_group = all_to_all_single( num_tokens_per_expert, None, None, group=device_mesh.get_group() ) - - num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor( - num_tokens_per_expert_group - ) + + num_tokens_per_expert_group = 
torch.ops._c10d_functional.wait_tensor(num_tokens_per_expert_group) input_splits = ( - num_tokens_per_expert.view(ep_degree, -1) - .sum(dim=1) - .to(torch.device("cpu"), non_blocking=True) + num_tokens_per_expert.view(ep_degree, -1).sum(dim=1).to(torch.device("cpu"), non_blocking=True) ) - + output_splits = ( - num_tokens_per_expert_group.view(ep_degree, -1) - .sum(dim=1) - .to(torch.device("cpu"), non_blocking=False) + num_tokens_per_expert_group.view(ep_degree, -1).sum(dim=1).to(torch.device("cpu"), non_blocking=False) ) self.input_splits = input_splits.tolist() self.output_splits = output_splits.tolist() @@ -124,14 +114,12 @@ def _token_dispatch( device_mesh.get_group(), ) - self.original_shape, routed_input, self.permuted_indices, num_tokens_per_expert_group = ( - _permute_tokens(routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts) + self.original_shape, routed_input, self.permuted_indices, num_tokens_per_expert_group = _permute_tokens( + routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts ) return routed_input, num_tokens_per_expert_group - def _token_combine( - self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh - ) -> Tensor: + def _token_combine(self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh) -> Tensor: routed_output = _unpermute_tokens(routed_output, self.original_shape, self.permuted_indices) routed_output = all_to_all_single_autograd( routed_output,