deepspeed/moe/sharded_moe.py (11 additions, 5 deletions)
@@ -215,7 +215,7 @@ def top1gating(logits: Tensor,
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
         # Communicate across expert processes to pick the maximum capacity.
-        if ep_group is not None:
+        if ep_group is not None and dist.get_world_size(group=ep_group) > 1:
             dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
@@ -335,7 +335,7 @@ def top2gating(logits: Tensor,
     else:
         # Do not drop tokens - set capacity according to current expert assignments
         new_capacity = torch.max(exp_counts)
-        if ep_group is not None:
+        if ep_group is not None and dist.get_world_size(group=ep_group) > 1:
             dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
@@ -421,7 +421,7 @@ def topkgating(
     else:
         # Do not drop tokens - set capacity according to current expert assignments
         new_capacity = torch.max(exp_counts)
-        if ep_group is not None:
+        if ep_group is not None and dist.get_world_size(group=ep_group) > 1:
             dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
@@ -628,7 +628,10 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             # an allgather to ensure correctness,
             dispatched_input = drop_tokens(dispatched_input, dim=1)

-        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
+        if self.ep_size > 1:
+            dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
+        else:
+            dispatched_input = dispatched_input.contiguous()

         if self.wall_clock_breakdown:
             self.timers(FIRST_ALLTOALL_TIMER).stop()
@@ -654,7 +657,10 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).start()

-        expert_output = _AllToAll.apply(self.ep_group, expert_output)
+        if self.ep_size > 1:
+            expert_output = _AllToAll.apply(self.ep_group, expert_output)
+        else:
+            expert_output = expert_output.contiguous()

         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).stop()
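Note on the pattern in the hunks above: the new guard skips expert-parallel collectives entirely when the expert-parallel group spans a single rank, where an all-reduce or all-to-all is a no-op that still pays communication-launch overhead; the else branch calls .contiguous() presumably so the skipped path still hands downstream code a contiguous tensor, matching what the all-to-all would have produced. Below is a minimal standalone sketch of the same guard, assuming a single-process gloo setup; the maybe_all_reduce_max helper is illustrative and not part of DeepSpeed.

import os
import torch
import torch.distributed as dist

def maybe_all_reduce_max(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # Only reduce when the group actually spans multiple ranks.
    if group is not None and dist.get_world_size(group=group) > 1:
        dist.all_reduce(tensor, op=dist.ReduceOp.MAX, group=group)
    return tensor

if __name__ == "__main__":
    # Single-process setup purely for illustration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    single_rank_group = dist.new_group([0])
    capacity = torch.tensor(7)
    # With a one-rank group the guard skips the collective; the tensor is returned unchanged.
    print(maybe_all_reduce_max(capacity, group=single_rank_group))
    dist.destroy_process_group()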
tests/unit/moe/test_moe.py (90 additions, 1 deletion)
@@ -11,8 +11,10 @@
 from unit.common import DistributedTest
 from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
 import deepspeed.comm as dist
+import deepspeed.moe.sharded_moe as sharded_moe
 from deepspeed import get_accelerator
-from deepspeed.moe.sharded_moe import top1gating, topkgating
+from deepspeed.moe.layer import MoE
+from deepspeed.moe.sharded_moe import top1gating, top2gating, topkgating
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
 from deepspeed.utils.torch import required_torch_version

@@ -209,6 +211,93 @@ def test(self):
                             use_tutel=False)


+class TestMoESingleton(DistributedTest):
+    world_size = 2
+
+    @pytest.mark.parametrize("ep_size, expected_calls", [(1, 0), (2, 2)], ids=["single", "multi"])
+    def test_all_to_all(self, monkeypatch, ep_size, expected_calls):
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        config_dict = {"train_micro_batch_size_per_gpu": 1, "steps_per_print": 1}
+        hidden_dim = 8
+        expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim))
+        model = MoE(hidden_size=hidden_dim, expert=expert, num_experts=2, ep_size=ep_size, k=1, min_capacity=0)
+        optimizer = torch.optim.AdamW(params=model.parameters())
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
+                                              model=model,
+                                              optimizer=optimizer,
+                                              dist_init_required=False)
+
+        all_to_all_calls = []
+
+        def counted_all_to_all(group, input):
+            all_to_all_calls.append((group, input.shape))
+            return input
+
+        monkeypatch.setattr(sharded_moe._AllToAll, "apply", counted_all_to_all)
+
+        x = torch.randn(1, 4, hidden_dim, device=model.device, requires_grad=True)
+        output, l_aux, _ = model(x)
+        assert len(all_to_all_calls) == expected_calls
+
+        loss = output.float().sum() + l_aux.float()
+        model.backward(loss)
+        assert len(all_to_all_calls) == expected_calls
+        assert x.grad is not None
+        assert any(param.grad is not None for param in model.module.parameters())
+
+    @pytest.mark.parametrize("gate_fn, capacity_args", [(top1gating, (1, 0)), (top2gating, (1, 0)),
+                                                        (topkgating, (3, 1, 0))],
+                             ids=["top1", "top2", "topk"])
+    @pytest.mark.parametrize("ep_world_size, expected_calls", [(1, 0), (2, 1)], ids=["single", "multi"])
+    def test_capacity(self, monkeypatch, gate_fn, capacity_args, ep_world_size, expected_calls):
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        ep_group = None
+        if ep_world_size == 1:
+            for rank in range(dist.get_world_size()):
+                group = dist.new_group([rank])
+                if rank == dist.get_rank():
+                    ep_group = group
+        else:
+            ep_group = dist.new_group(list(range(dist.get_world_size())))
+
+        all_reduce_calls = []
+        original_all_reduce = sharded_moe.dist.all_reduce
+
+        def counted_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None):
+            all_reduce_calls.append((tensor, op, group))
+            return original_all_reduce(tensor, op=op, group=group)
+
+        monkeypatch.setattr(sharded_moe.dist, "all_reduce", counted_all_reduce)
+
+        device = get_accelerator().current_device_name()
+        logits = torch.randn(8, 4, device=device)
+        gate_fn(logits, *capacity_args, drop_tokens=False, ep_group=ep_group)
+
+        assert len(all_reduce_calls) == expected_calls
+        if all_reduce_calls:
+            _, op, group = all_reduce_calls[0]
+            assert op == dist.ReduceOp.MAX
+            assert group is ep_group
+
+    def test_no_ep_group(self, monkeypatch):
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        def fail_collective(*args, **kwargs):
+            raise AssertionError("ep_group=None should not enter expert-parallel collective code")
+
+        monkeypatch.setattr(sharded_moe.dist, "get_world_size", fail_collective)
+        monkeypatch.setattr(sharded_moe.dist, "all_reduce", fail_collective)
+
+        device = get_accelerator().current_device_name()
+        logits = torch.randn(8, 4, device=device)
+        top2gating(logits, 1, 0, drop_tokens=False, ep_group=None, top2_2nd_expert_sampling=False)
+
+
 class TestTopkGate(DistributedTest):

     def test(self):
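One detail in test_capacity above is worth calling out: new_group is itself a collective (deepspeed.comm's dist.new_group follows the same semantics as torch.distributed), so every rank has to call it for every group being created, including groups it will not belong to. That is why the single-expert-parallel case loops over all ranks, creates a one-rank group per rank, and keeps only the local one. A standalone sketch of that idiom, assuming the default process group is already initialized; the make_single_rank_group name is illustrative:

import torch.distributed as dist

def make_single_rank_group():
    # new_group is collective: all ranks must participate in every call,
    # even for groups they are not members of.
    my_group = None
    for rank in range(dist.get_world_size()):
        group = dist.new_group([rank])
        if rank == dist.get_rank():
            my_group = group  # keep only the group this rank belongs to
    return my_group

With ep_group built this way, dist.get_world_size(group=ep_group) is 1, so the guarded gating code above takes the no-collective path and the test expects zero all_reduce calls.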