megatron/core/distributed/finalize_model_grads.py (23 changes: 20 additions & 3 deletions)
@@ -328,7 +328,11 @@ def reset_model_temporary_tensors(config: TransformerConfig, model: List[torch.nn.Module]):
module.reset_global_aux_loss_tracker()


-def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
+def _update_router_expert_bias(
+    model: List[torch.nn.Module],
+    config: TransformerConfig,
+    tp_dp_cp_group: Optional[torch.distributed.ProcessGroup] = None,
+):
"""
Update the expert bias of the router for a global batch.
This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks
@@ -350,7 +354,10 @@ def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
stacked_updated_expert_bias = get_updated_expert_bias(
-        stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
+        stacked_tokens_per_expert,
+        stacked_expert_bias,
+        config.moe_router_bias_update_rate,
+        tp_dp_cp_group=tp_dp_cp_group,
)

for expert_bias, updated_expert_bias in zip(expert_bias_list, stacked_updated_expert_bias):
@@ -448,6 +455,7 @@ def finalize_model_grads(
"""

config = get_model_config(model[0])
+    tp_dp_cp_group = None
if pg_collection is not None:
assert hasattr(pg_collection, 'tp')
assert hasattr(pg_collection, 'pp')
@@ -466,6 +474,11 @@
"If you don't need pos_embd_group, you need to explicitly set it to None."
)
assert hasattr(pg_collection, 'dp_cp')
+        if config.moe_router_enable_expert_bias:
+            assert hasattr(pg_collection, 'tp_dp_cp') and pg_collection.tp_dp_cp is not None, (
+                "pg_collection must have tp_dp_cp when moe_router_enable_expert_bias is enabled."
+            )
+            tp_dp_cp_group = pg_collection.tp_dp_cp
tp_group = pg_collection.tp
pp_group = pg_collection.pp
embd_group = pg_collection.embd
@@ -519,7 +532,11 @@
config.timers('embedding-grads-all-reduce').stop()

if config.moe_router_enable_expert_bias:
-        _update_router_expert_bias(model, config)
+        if pg_collection is None:
+            tp_dp_cp_group = parallel_state.get_tensor_and_data_parallel_group(
+                with_context_parallel=True
+            )
+        _update_router_expert_bias(model, config, tp_dp_cp_group=tp_dp_cp_group)

reset_model_temporary_tensors(config, model)

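Taken together, the changes in this file let a caller that manages its own process groups route the expert-bias all-reduce through an explicit tp_dp_cp group, falling back to the global parallel_state lookup only when no pg_collection is supplied. A minimal sketch of the new calling convention, mirroring the ProcessGroupCollection construction in the unit tests below; the use of dist.group.WORLD for every field and the model_chunks name are illustrative, not part of this diff:

    import torch.distributed as dist

    from megatron.core.distributed.finalize_model_grads import finalize_model_grads
    from megatron.core.process_groups_config import ProcessGroupCollection

    # In real use each field would be a distinct subgroup from torch.distributed.new_group.
    pg_collection = ProcessGroupCollection(
        tp=dist.group.WORLD,        # tensor parallel
        pp=dist.group.WORLD,        # pipeline parallel
        embd=None,
        pos_embd=None,
        dp_cp=dist.group.WORLD,     # data + context parallel
        tp_dp_cp=dist.group.WORLD,  # required once moe_router_enable_expert_bias is set
    )
    finalize_model_grads(model_chunks, pg_collection=pg_collection)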
megatron/core/pipeline_parallel/schedules.py (9 changes: 9 additions & 0 deletions)
@@ -623,6 +623,9 @@ def forward_backward_no_pipelining(
pg_collection.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
+        pg_collection.tp_dp_cp = parallel_state.get_tensor_and_data_parallel_group(
+            with_context_parallel=True
+        )

elif pg_collection is not None:
assert hasattr(pg_collection, 'tp'), "pg_collection must have tp"
@@ -943,6 +946,9 @@ def forward_backward_pipelining_with_interleaving(
pg_collection.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
+        pg_collection.tp_dp_cp = parallel_state.get_tensor_and_data_parallel_group(
+            with_context_parallel=True
+        )

elif p2p_communicator is not None and pg_collection is not None:
model_type = get_model_type(model[0])
@@ -2100,6 +2106,9 @@ def forward_backward_pipelining_without_interleaving(
pg_collection.dp_cp = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
+        pg_collection.tp_dp_cp = parallel_state.get_tensor_and_data_parallel_group(
+            with_context_parallel=True
+        )

elif p2p_communicator is not None and pg_collection is not None:
assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config"
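The same three-line addition appears in all three schedules so that the default ProcessGroupCollection built from the global parallel_state also carries tp_dp_cp. A condensed sketch of the shared fallback pattern; the helper name is hypothetical, and it assumes ProcessGroupCollection can be constructed empty and the global parallel_state is initialized:

    from megatron.core import parallel_state
    from megatron.core.process_groups_config import ProcessGroupCollection

    def build_default_pg_collection(pg_collection):
        # Hypothetical helper: fill in the groups the schedules now expect
        # when the caller passes pg_collection=None.
        if pg_collection is None:
            pg_collection = ProcessGroupCollection()
            pg_collection.tp = parallel_state.get_tensor_model_parallel_group()
            pg_collection.pp = parallel_state.get_pipeline_model_parallel_group()
            pg_collection.dp_cp = parallel_state.get_data_parallel_group(
                with_context_parallel=True, partial_data_parallel=False
            )
            pg_collection.tp_dp_cp = parallel_state.get_tensor_and_data_parallel_group(
                with_context_parallel=True
            )
        return pg_collection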
megatron/core/transformer/moe/moe_layer.py (6 changes: 3 additions & 3 deletions)
@@ -8,7 +8,7 @@

import torch

-from megatron.core import parallel_state, tensor_parallel, utils
+from megatron.core import tensor_parallel, utils
from megatron.core.extensions.transformer_engine import HAVE_TE
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.module import MegatronModule
@@ -509,7 +509,7 @@ def shared_experts_compute(self, hidden_states: torch.Tensor):
apply_module(self.shared_experts),
False,
tensor_parallel.random.get_cuda_rng_tracker,
-                parallel_state.get_tensor_model_parallel_group(),
+                self.tp_group,
hidden_states,
)
else:
@@ -672,7 +672,7 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None):
custom_forward,
False,
tensor_parallel.random.get_cuda_rng_tracker,
-                parallel_state.get_tensor_model_parallel_group(),
+                self.tp_group,
hidden_states,
intermediate_tensors,
padding_mask,
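Both checkpointing call sites in the MoE layer now use the tp_group held by the module instead of reaching into the global parallel_state, which also lets the parallel_state import be dropped. The underlying dependency-injection pattern, shown schematically with hypothetical names (this is not the MoELayer code itself):

    import torch
    import torch.distributed as dist

    class GroupAwareModule(torch.nn.Module):
        """Hypothetical module illustrating the injected-group pattern."""

        def __init__(self, tp_group: dist.ProcessGroup):
            super().__init__()
            # The group is stored once at construction time ...
            self.tp_group = tp_group

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # ... and every collective references it, so the module works with
            # any group layout rather than only the global default.
            dist.all_reduce(x, group=self.tp_group)
            return x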
megatron/core/transformer/moe/moe_utils.py (19 changes: 13 additions & 6 deletions)
@@ -1161,25 +1161,32 @@ def track_moe_metrics(


def get_updated_expert_bias(
-    tokens_per_expert: torch.Tensor, expert_bias: torch.Tensor, expert_bias_update_rate: float
+    tokens_per_expert: torch.Tensor,
+    expert_bias: torch.Tensor,
+    expert_bias_update_rate: float,
+    tp_dp_cp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> torch.Tensor:
"""Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1#

Args:
tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert.
expert_bias (torch.Tensor): The bias for each expert.
        expert_bias_update_rate (float): The update rate for the expert bias.
+        tp_dp_cp_group (torch.distributed.ProcessGroup, optional): The group spanning the tensor,
+            data, and context parallel ranks that share the router expert-bias update.

Returns:
torch.Tensor: The updated expert bias.
"""
with torch.no_grad():
-        # All Reduce Across TPxCPxDP group
-        torch.distributed.all_reduce(
-            tokens_per_expert,
+        if tp_dp_cp_group is None:
            # TODO(Hepteract): delete the usage of the global parallel_state.
-            group=parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True),
-        )
+            tp_dp_cp_group = parallel_state.get_tensor_and_data_parallel_group(
+                with_context_parallel=True
+            )
+
+        # All Reduce Across TPxCPxDP group
+        torch.distributed.all_reduce(tokens_per_expert, group=tp_dp_cp_group)
average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
offset = average_tokens - tokens_per_expert
updated_expert_bias = expert_bias + torch.sign(offset) * expert_bias_update_rate
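The update rule itself is unchanged: an expert that received fewer tokens than the batch average gets its bias nudged up by expert_bias_update_rate, and an over-loaded expert gets nudged down (sign-only updates, per the paper linked in the docstring). A worked example with two experts, matching the unit test below, where rank 0 reports [0., 2.] and every other rank reports zeros, so the all-reduced counts are [0., 2.]:

    import torch

    tokens_per_expert = torch.tensor([0.0, 2.0])  # global counts after the all-reduce
    expert_bias = torch.zeros(2)
    update_rate = 0.25

    average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
    offset = average_tokens - tokens_per_expert                    # tensor([ 1., -1.])
    updated_bias = expert_bias + torch.sign(offset) * update_rate
    print(updated_bias)                                            # tensor([ 0.2500, -0.2500])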
megatron/core/transformer/moe/shared_experts.py (10 changes: 6 additions & 4 deletions)
@@ -229,10 +229,12 @@ def pre_forward_comm(self, input, wait_current_stream=True):
self.gate_score = torch.nn.functional.sigmoid(logits)
if self.config.sequence_parallel:
self.cached_fc1_input = gather_from_sequence_parallel_region(
-                    input, tensor_parallel_output_grad=True
+                    input, tensor_parallel_output_grad=True, group=self.tp_group
)
else:
-                self.cached_fc1_input = copy_to_tensor_model_parallel_region(input)
+                self.cached_fc1_input = copy_to_tensor_model_parallel_region(
+                    input, group=self.tp_group
+                )
set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max)

@overlap_state_check(
Expand Down Expand Up @@ -321,11 +323,11 @@ def post_forward_comm(self):
with torch.cuda.stream(self.stream):
if self.config.sequence_parallel:
self.cached_output = reduce_scatter_to_sequence_parallel_region(
-                        self.cached_fc2_output
+                        self.cached_fc2_output, group=self.tp_group
)
else:
self.cached_output = reduce_from_tensor_model_parallel_region(
-                        self.cached_fc2_output
+                        self.cached_fc2_output, group=self.tp_group
)
self.cached_fc2_output = None
set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max)
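In sequence-parallel mode the shared-expert path gathers the full sequence before fc1 and reduce-scatters the partial outputs after fc2; all four call sites (including the non-sequence-parallel copy/reduce pair) now pass the module's tp_group explicitly. A rough sketch of the two sequence-parallel collectives using raw torch.distributed, to show the symmetry; shapes and function names are illustrative, not Megatron's implementation:

    import torch
    import torch.distributed as dist

    def sp_gather(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        # Concatenate sequence shards from all TP ranks: [s/t, b, h] -> [s, b, h].
        t = dist.get_world_size(group)
        out = torch.empty((x.shape[0] * t, *x.shape[1:]), dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(out, x.contiguous(), group=group)
        return out

    def sp_reduce_scatter(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        # Sum partial results across TP ranks and keep only this rank's shard:
        # [s, b, h] -> [s/t, b, h].
        t = dist.get_world_size(group)
        out = torch.empty((x.shape[0] // t, *x.shape[1:]), dtype=x.dtype, device=x.device)
        dist.reduce_scatter_tensor(out, x.contiguous(), group=group)
        return out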
tests/unit_tests/distributed/test_finalize_model_grads.py (100 changes: 99 additions & 1 deletion)
@@ -1,24 +1,122 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import inspect
import os

import pytest
import torch
import torch.distributed as dist

from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed.finalize_model_grads import (
_allreduce_non_tensor_model_parallel_grads,
_allreduce_word_embedding_grads,
finalize_model_grads,
)
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class _RouterExpertBiasModel(torch.nn.Module):
def __init__(self, config, local_tokens_per_expert):
super().__init__()
self.config = config
self.ddp_config = DistributedDataParallelConfig()
self.router = torch.nn.Module()
self.router.register_buffer("local_tokens_per_expert", local_tokens_per_expert)
self.router.register_buffer("expert_bias", torch.zeros_like(local_tokens_per_expert))
self.finish_grad_sync_calls = 0

def finish_grad_sync(self, force_all_reduce=False):
del force_all_reduce
self.finish_grad_sync_calls += 1


def _router_expert_bias_config():
return TransformerConfig(
num_layers=1,
hidden_size=8,
num_attention_heads=1,
use_cpu_initialization=True,
moe_router_enable_expert_bias=True,
moe_router_score_function="sigmoid",
moe_router_bias_update_rate=0.25,
moe_router_load_balancing_type="none",
)


_NO_TP_DP_CP = object()


def _router_bias_pg_collection(tp_dp_cp=_NO_TP_DP_CP):
kwargs = {
'tp': dist.group.WORLD,
'pp': dist.group.WORLD,
'embd': None,
'pos_embd': None,
'dp_cp': dist.group.WORLD,
}
if tp_dp_cp is not _NO_TP_DP_CP:
kwargs['tp_dp_cp'] = tp_dp_cp
return ProcessGroupCollection(**kwargs)


class TestFinalizeModelGradsMoEExpertBias:
def setup_method(self, method):
os.environ.pop('NVTE_FUSED_ATTN', None)
os.environ.pop('NVTE_FLASH_ATTN', None)
os.environ.pop('NVTE_UNFUSED_ATTN', None)
Utils.destroy_model_parallel()
Utils.initialize_distributed()
parallel_state.destroy_model_parallel()

def teardown_method(self, method):
Utils.destroy_model_parallel()

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_finalize_model_grads_updates_router_expert_bias_with_custom_group(self):
assert not parallel_state.model_parallel_is_initialized()

config = _router_expert_bias_config()
device = torch.device("cuda", torch.cuda.current_device())
local_tokens = torch.tensor(
[0.0, 2.0] if dist.get_rank() == 0 else [0.0, 0.0], device=device
)
model = _RouterExpertBiasModel(config, local_tokens)

finalize_model_grads(
[model], pg_collection=_router_bias_pg_collection(tp_dp_cp=dist.group.WORLD)
)

expected_bias = torch.tensor([0.25, -0.25], device=device)
torch.testing.assert_close(model.router.expert_bias, expected_bias)
torch.testing.assert_close(
model.router.local_tokens_per_expert, torch.zeros_like(local_tokens)
)
assert model.finish_grad_sync_calls == 1

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_finalize_model_grads_requires_custom_group_before_grad_sync(self):
assert not parallel_state.model_parallel_is_initialized()
config = _router_expert_bias_config()
device = torch.device("cuda", torch.cuda.current_device())
pg_collections = [
_router_bias_pg_collection(),
_router_bias_pg_collection(tp_dp_cp=dist.group.WORLD),
]
pg_collections[1].tp_dp_cp = None

for pg_collection in pg_collections:
model = _RouterExpertBiasModel(config, torch.tensor([1.0, 0.0], device=device))
with pytest.raises(AssertionError, match="tp_dp_cp"):
finalize_model_grads([model], pg_collection=pg_collection)
assert model.finish_grad_sync_calls == 0


class TestAllReduceLNGrads:

def init_model(self, share_embeddings_and_output_weights: bool = False):
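A note on exercising the new tests: they require an initialized torch.distributed world (Utils.initialize_distributed() attaches each process to it), so they are meant to run under a multi-process launcher. An illustrative invocation, with the rank count an assumption rather than a requirement of the tests:

    torchrun --nproc_per_node=2 -m pytest tests/unit_tests/distributed/test_finalize_model_grads.py -k ExpertBias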