diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py
index 9c31b280875..7d9179d1c50 100644
--- a/megatron/core/distributed/finalize_model_grads.py
+++ b/megatron/core/distributed/finalize_model_grads.py
@@ -328,7 +328,11 @@ def reset_model_temporary_tensors(config: TransformerConfig, model: List[torch.n
                 module.reset_global_aux_loss_tracker()
 
 
-def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
+def _update_router_expert_bias(
+    model: List[torch.nn.Module],
+    config: TransformerConfig,
+    tp_dp_cp_group: Optional[torch.distributed.ProcessGroup] = None,
+):
     """
     Update the expert bias of the router for a global batch.
     This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks
@@ -350,7 +354,10 @@ def _update_router_expert_bias(model: List[torch.nn.Module], config: Transformer
     stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
     stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
     stacked_updated_expert_bias = get_updated_expert_bias(
-        stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
+        stacked_tokens_per_expert,
+        stacked_expert_bias,
+        config.moe_router_bias_update_rate,
+        tp_dp_cp_group=tp_dp_cp_group,
     )
 
     for expert_bias, updated_expert_bias in zip(expert_bias_list, stacked_updated_expert_bias):
@@ -448,6 +455,7 @@ def finalize_model_grads(
     """
 
     config = get_model_config(model[0])
+    tp_dp_cp_group = None
     if pg_collection is not None:
         assert hasattr(pg_collection, 'tp')
         assert hasattr(pg_collection, 'pp')
@@ -466,6 +474,11 @@ def finalize_model_grads(
             "If you don't need pos_embd_group, you need to explicitly set it to None."
         )
        assert hasattr(pg_collection, 'dp_cp')
+        if config.moe_router_enable_expert_bias:
+            assert hasattr(pg_collection, 'tp_dp_cp') and pg_collection.tp_dp_cp is not None, (
+                "pg_collection must have tp_dp_cp when moe_router_enable_expert_bias is enabled."
+            )
+            tp_dp_cp_group = pg_collection.tp_dp_cp
         tp_group = pg_collection.tp
         pp_group = pg_collection.pp
         embd_group = pg_collection.embd
@@ -519,7 +532,11 @@ def finalize_model_grads(
         config.timers('embedding-grads-all-reduce').stop()
 
     if config.moe_router_enable_expert_bias:
-        _update_router_expert_bias(model, config)
+        if pg_collection is None:
+            tp_dp_cp_group = parallel_state.get_tensor_and_data_parallel_group(
+                with_context_parallel=True
+            )
+        _update_router_expert_bias(model, config, tp_dp_cp_group=tp_dp_cp_group)
 
     reset_model_temporary_tensors(config, model)
diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py
index 14fc6041574..a4d79a1b21c 100644
--- a/megatron/core/pipeline_parallel/schedules.py
+++ b/megatron/core/pipeline_parallel/schedules.py
@@ -623,6 +623,9 @@ def forward_backward_no_pipelining(
         pg_collection.dp_cp = parallel_state.get_data_parallel_group(
             with_context_parallel=True, partial_data_parallel=False
         )
+        pg_collection.tp_dp_cp = parallel_state.get_tensor_and_data_parallel_group(
+            with_context_parallel=True
+        )
     elif pg_collection is not None:
         assert hasattr(pg_collection, 'tp'), "pg_collection must have tp"
@@ -943,6 +946,9 @@ def forward_backward_pipelining_with_interleaving(
         pg_collection.dp_cp = parallel_state.get_data_parallel_group(
             with_context_parallel=True, partial_data_parallel=False
         )
+        pg_collection.tp_dp_cp = parallel_state.get_tensor_and_data_parallel_group(
+            with_context_parallel=True
+        )
     elif p2p_communicator is not None and pg_collection is not None:
         model_type = get_model_type(model[0])
@@ -2100,6 +2106,9 @@ def forward_backward_pipelining_without_interleaving(
         pg_collection.dp_cp = parallel_state.get_data_parallel_group(
             with_context_parallel=True, partial_data_parallel=False
         )
+        pg_collection.tp_dp_cp = parallel_state.get_tensor_and_data_parallel_group(
+            with_context_parallel=True
+        )
     elif p2p_communicator is not None and pg_collection is not None:
         assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config"
diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py
index 2ddc17a567a..0baf2c65cc5 100644
--- a/megatron/core/transformer/moe/moe_layer.py
+++ b/megatron/core/transformer/moe/moe_layer.py
@@ -8,7 +8,7 @@
 
 import torch
 
-from megatron.core import parallel_state, tensor_parallel, utils
+from megatron.core import tensor_parallel, utils
 from megatron.core.extensions.transformer_engine import HAVE_TE
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.module import MegatronModule
@@ -509,7 +509,7 @@ def shared_experts_compute(self, hidden_states: torch.Tensor):
                 apply_module(self.shared_experts),
                 False,
                 tensor_parallel.random.get_cuda_rng_tracker,
-                parallel_state.get_tensor_model_parallel_group(),
+                self.tp_group,
                 hidden_states,
             )
         else:
@@ -672,7 +672,7 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None):
                 custom_forward,
                 False,
                 tensor_parallel.random.get_cuda_rng_tracker,
-                parallel_state.get_tensor_model_parallel_group(),
+                self.tp_group,
                 hidden_states,
                 intermediate_tensors,
                 padding_mask,
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index f258f3474ae..e1b581898ab 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -1161,7 +1161,10 @@ def track_moe_metrics(
 
 
 def get_updated_expert_bias(
-    tokens_per_expert: torch.Tensor, expert_bias: torch.Tensor, expert_bias_update_rate: float
+    tokens_per_expert: torch.Tensor,
+    expert_bias: torch.Tensor,
+    expert_bias_update_rate: float,
+    tp_dp_cp_group: Optional[torch.distributed.ProcessGroup] = None,
 ) -> torch.Tensor:
     """Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1#
@@ -1169,17 +1172,21 @@ def get_updated_expert_bias(
         tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert.
         expert_bias (torch.Tensor): The bias for each expert.
         expert_bias_update_rate (float): The update rate for the expert bias.
+        tp_dp_cp_group (torch.distributed.ProcessGroup, optional): The group spanning the tensor,
+            data, and context parallel ranks that share the router expert-bias update.
 
     Returns:
         torch.Tensor: The updated expert bias.
     """
     with torch.no_grad():
-        # All Reduce Across TPxCPxDP group
-        torch.distributed.all_reduce(
-            tokens_per_expert,
+        if tp_dp_cp_group is None:
             # TODO(Hepteract): delete the usage of the global parallel_state.
-            group=parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True),
-        )
+            tp_dp_cp_group = parallel_state.get_tensor_and_data_parallel_group(
+                with_context_parallel=True
+            )
+
+        # All Reduce Across TPxCPxDP group
+        torch.distributed.all_reduce(tokens_per_expert, group=tp_dp_cp_group)
         average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
         offset = average_tokens - tokens_per_expert
         updated_expert_bias = expert_bias + torch.sign(offset) * expert_bias_update_rate
diff --git a/megatron/core/transformer/moe/shared_experts.py b/megatron/core/transformer/moe/shared_experts.py
index 61ea47955b8..a565e2ec718 100644
--- a/megatron/core/transformer/moe/shared_experts.py
+++ b/megatron/core/transformer/moe/shared_experts.py
@@ -229,10 +229,12 @@ def pre_forward_comm(self, input, wait_current_stream=True):
             self.gate_score = torch.nn.functional.sigmoid(logits)
         if self.config.sequence_parallel:
             self.cached_fc1_input = gather_from_sequence_parallel_region(
-                input, tensor_parallel_output_grad=True
+                input, tensor_parallel_output_grad=True, group=self.tp_group
             )
         else:
-            self.cached_fc1_input = copy_to_tensor_model_parallel_region(input)
+            self.cached_fc1_input = copy_to_tensor_model_parallel_region(
+                input, group=self.tp_group
+            )
         set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max)
 
     @overlap_state_check(
@@ -321,11 +323,11 @@ def post_forward_comm(self):
         with torch.cuda.stream(self.stream):
             if self.config.sequence_parallel:
                 self.cached_output = reduce_scatter_to_sequence_parallel_region(
-                    self.cached_fc2_output
+                    self.cached_fc2_output, group=self.tp_group
                 )
             else:
                 self.cached_output = reduce_from_tensor_model_parallel_region(
-                    self.cached_fc2_output
+                    self.cached_fc2_output, group=self.tp_group
                 )
             self.cached_fc2_output = None
             set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max)
diff --git a/tests/unit_tests/distributed/test_finalize_model_grads.py b/tests/unit_tests/distributed/test_finalize_model_grads.py
index e1e2e760693..ee535c29baf 100644
--- a/tests/unit_tests/distributed/test_finalize_model_grads.py
+++ b/tests/unit_tests/distributed/test_finalize_model_grads.py
@@ -1,24 +1,122 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
 import inspect
 import os
 
 import pytest
 import torch
+import torch.distributed as dist
 
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.distributed.finalize_model_grads import (
     _allreduce_non_tensor_model_parallel_grads,
     _allreduce_word_embedding_grads,
+    finalize_model_grads,
 )
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.models.gpt.gpt_model import GPTModel
+from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
 
 
+class _RouterExpertBiasModel(torch.nn.Module):
+    def __init__(self, config, local_tokens_per_expert):
+        super().__init__()
+        self.config = config
+        self.ddp_config = DistributedDataParallelConfig()
+        self.router = torch.nn.Module()
+        self.router.register_buffer("local_tokens_per_expert", local_tokens_per_expert)
+        self.router.register_buffer("expert_bias", torch.zeros_like(local_tokens_per_expert))
+        self.finish_grad_sync_calls = 0
+
+    def finish_grad_sync(self, force_all_reduce=False):
+        del force_all_reduce
+        self.finish_grad_sync_calls += 1
+
+
+def _router_expert_bias_config():
+    return TransformerConfig(
+        num_layers=1,
+        hidden_size=8,
+        num_attention_heads=1,
+        use_cpu_initialization=True,
+        moe_router_enable_expert_bias=True,
+        moe_router_score_function="sigmoid",
+        moe_router_bias_update_rate=0.25,
+        moe_router_load_balancing_type="none",
+    )
+
+
+_NO_TP_DP_CP = object()
+
+
+def _router_bias_pg_collection(tp_dp_cp=_NO_TP_DP_CP):
+    kwargs = {
+        'tp': dist.group.WORLD,
+        'pp': dist.group.WORLD,
+        'embd': None,
+        'pos_embd': None,
+        'dp_cp': dist.group.WORLD,
+    }
+    if tp_dp_cp is not _NO_TP_DP_CP:
+        kwargs['tp_dp_cp'] = tp_dp_cp
+    return ProcessGroupCollection(**kwargs)
+
+
+class TestFinalizeModelGradsMoEExpertBias:
+    def setup_method(self, method):
+        os.environ.pop('NVTE_FUSED_ATTN', None)
+        os.environ.pop('NVTE_FLASH_ATTN', None)
+        os.environ.pop('NVTE_UNFUSED_ATTN', None)
+        Utils.destroy_model_parallel()
+        Utils.initialize_distributed()
+        parallel_state.destroy_model_parallel()
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_finalize_model_grads_updates_router_expert_bias_with_custom_group(self):
+        assert not parallel_state.model_parallel_is_initialized()
+
+        config = _router_expert_bias_config()
+        device = torch.device("cuda", torch.cuda.current_device())
+        local_tokens = torch.tensor(
+            [0.0, 2.0] if dist.get_rank() == 0 else [0.0, 0.0], device=device
+        )
+        model = _RouterExpertBiasModel(config, local_tokens)
+
+        finalize_model_grads(
+            [model], pg_collection=_router_bias_pg_collection(tp_dp_cp=dist.group.WORLD)
+        )
+
+        expected_bias = torch.tensor([0.25, -0.25], device=device)
+        torch.testing.assert_close(model.router.expert_bias, expected_bias)
+        torch.testing.assert_close(
+            model.router.local_tokens_per_expert, torch.zeros_like(local_tokens)
+        )
+        assert model.finish_grad_sync_calls == 1
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_finalize_model_grads_requires_custom_group_before_grad_sync(self):
+        assert not parallel_state.model_parallel_is_initialized()
+        config = _router_expert_bias_config()
+        device = torch.device("cuda", torch.cuda.current_device())
+        pg_collections = [
+            _router_bias_pg_collection(),
+            _router_bias_pg_collection(tp_dp_cp=dist.group.WORLD),
+        ]
+        pg_collections[1].tp_dp_cp = None
+
+        for pg_collection in pg_collections:
+            model = _RouterExpertBiasModel(config, torch.tensor([1.0, 0.0], device=device))
+            with pytest.raises(AssertionError, match="tp_dp_cp"):
+                finalize_model_grads([model], pg_collection=pg_collection)
+            assert model.finish_grad_sync_calls == 0
+
+
 class TestAllReduceLNGrads:
 
     def init_model(self, share_embeddings_and_output_weights: bool = False):
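
Reviewer note: the arithmetic behind this patch is easy to sanity-check by hand. Below is a minimal single-process sketch of the update rule in get_updated_expert_bias, with the all-reduce over the TPxCPxDP group replaced by a plain sum over per-rank tensors; the helper name and that stand-in reduction are illustrative only, not part of the patch. The inputs mirror the two-rank unit test above (rank 0 routes two tokens to expert 1, rank 1 routes none, bias update rate 0.25).

import torch


def updated_expert_bias_sketch(tokens_per_rank, expert_bias, update_rate):
    # Stand-in for torch.distributed.all_reduce across the TPxCPxDP group:
    # sum the per-rank token counts into one global count per expert.
    tokens_per_expert = torch.stack(tokens_per_rank).sum(dim=0)
    # Same rule as get_updated_expert_bias (see arXiv:2408.15664): experts
    # below the average load get a positive offset, experts above it a
    # negative one, and each bias moves by +/- update_rate accordingly.
    average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
    offset = average_tokens - tokens_per_expert
    return expert_bias + torch.sign(offset) * update_rate


bias = updated_expert_bias_sketch(
    [torch.tensor([0.0, 2.0]), torch.tensor([0.0, 0.0])],
    expert_bias=torch.zeros(2),
    update_rate=0.25,
)
print(bias)  # tensor([ 0.2500, -0.2500])

The underloaded expert 0 gets its bias raised and the overloaded expert 1 gets it lowered, which is exactly the expected_bias that the new test asserts after finalize_model_grads runs with the custom tp_dp_cp group.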