Skip to content

Commit bcedecd

Browse files
mosheislandMoshe Island
andauthored
Support MoE for GPTModelPipe (#373)
* MOE: Support MoE layers creation for GPTModelPipe Signed-off-by: Moshe Island <misland@habana.ai> * MOE: Support MoE aux loss for GPTModelPipe Propagate aux loss along GPTModelPipe layers by forwarding the aggregated loss from each transformer layer to the next transformer layer. In addition, add a layer to GPTModelPipe, after the last transformer layer, to catch the final aggregated aux loss and cache it for use in the loss function. Signed-off-by: Moshe Island <misland@habana.ai> * MOE: Support display of MoE loss for GPTModelPipe Signed-off-by: Moshe Island <misland@habana.ai> * MOE: Verify MoE with no pipe/grad partitioned Currently PipelineEngine supports only a single tensor partitioning with grad. MoE model requires to forward with grad both the activations and the aux_loss. Therefore, until PilelineEngine limitation is removed, verify no partitioning when using MoE. Signed-off-by: Moshe Island <misland@habana.ai> --------- Signed-off-by: Moshe Island <misland@habana.ai> Co-authored-by: Moshe Island <misland@habana.ai>
1 parent 3c5f475 commit bcedecd

3 files changed

Lines changed: 121 additions & 27 deletions

File tree

megatron/model/gpt_model.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""GPT-2 model."""
44

55
import torch
6+
from collections import OrderedDict
67

78
from megatron import get_args
89
from megatron.core import mpu, tensor_parallel, sequence_parallel
@@ -16,7 +17,7 @@
1617

1718
from megatron.model import LayerNorm, RMSNorm
1819
from .language_model import EmbeddingPipe
19-
from .transformer import ParallelTransformerLayerPipe, LMHeadPipe
20+
from .transformer import ParallelTransformerLayerPipe, LMHeadPipe, get_num_experts_per_layer
2021
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
2122

2223

@@ -360,12 +361,33 @@ def _to_float16(inputs):
360361
embedding_weights_in_fp32=args.embedding_weights_in_fp32,
361362
tied_weight_attr='word_embeddings_weight'))
362363

364+
experts_per_layer = get_num_experts_per_layer(args.num_experts, args.num_layers, args.expert_interval)
365+
self.is_moe_model = any(n_experts > 1 for n_experts in experts_per_layer)
366+
367+
# Currently PipelineEngine does not support more than 1 pipe and/or grad partitioned tensors that
368+
# require grads.
369+
# When using MoE, we have 2 tensors that are passed along pipeline stages and both require grads.
370+
# Therefore, verify that both pipe_partitioned / grad_partitioned are not enabled
371+
if self.is_moe_model and args.pipeline_model_parallel_size > 1 and args.tensor_model_parallel_size > 1:
372+
pipe_partitioned_enabled = args.deepspeed_config_dict.get('pipeline', {}).get('pipe_partitioned', False)
373+
grad_partitioned_enabled = args.deepspeed_config_dict.get('pipeline', {}).get('grad_partitioned', False)
374+
assert not pipe_partitioned_enabled and not grad_partitioned_enabled, \
375+
'Pipe and/or Grad partitioning are not supported for MoE model'
376+
363377
for layer_idx in range(args.num_layers):
364378
self.specs.append(
365379
LayerSpec(ParallelTransformerLayerPipe,
366-
config,
367-
layer_number=layer_idx,
368-
self_attn_mask_type=AttnMaskType.causal))
380+
config,
381+
layer_number=layer_idx,
382+
self_attn_mask_type=AttnMaskType.causal,
383+
num_experts=experts_per_layer[layer_idx],
384+
input_aggregated_moe_loss=(self.is_moe_model and layer_idx > 0),
385+
return_aggregated_moe_loss=self.is_moe_model))
386+
387+
# if model has experts, add a layer to get and cache the aggregated moe loss from the
388+
# last transformer layer
389+
if self.is_moe_model:
390+
self.specs.append(self._calculate_moe_loss)
369391

370392
# Final layernorm after transformer layers
371393
if args.normalization == 'layernorm':
@@ -404,6 +426,11 @@ def _logits_helper(embedding, lm_output):
404426
if args.fp16 or args.bf16:
405427
self.specs.append(float16_to_fp32)
406428

429+
# Cache losses
430+
self.moe_loss = None
431+
self.last_lm_loss = None # detached, for display only
432+
self.last_moe_loss = None # detached, for display only
433+
407434
if args.checkpoint_activations:
408435
interval = args.checkpoint_num_layers
409436
elif args.recompute_granularity == "full" and args.recompute_method == 'uniform':
@@ -418,10 +445,34 @@ def _logits_helper(embedding, lm_output):
418445
num_dp=mpu.get_data_parallel_world_size())
419446

420447
super().__init__(layers=self.specs,
421-
loss_fn=CrossEntropy,
448+
loss_fn=self.loss_func,
422449
topology=topo,
423450
activation_checkpoint_interval=interval,
424451
partition_method='type:transformer')
425452

453+
def _calculate_moe_loss(self, inputs):
454+
""" Calculate MoE auxiliary loss """
455+
assert isinstance(inputs, tuple) and len(inputs) == 2
456+
hidden, aggregated_moe_loss = inputs[0], inputs[1]
457+
args = get_args()
458+
self.moe_loss = aggregated_moe_loss * args.moe_loss_coeff
459+
return hidden
460+
461+
def loss_func(self, output, labels):
462+
loss = CrossEntropy(output, labels)
463+
self.last_lm_loss = loss.clone().detach()
464+
if self.moe_loss is not None:
465+
loss += self.moe_loss
466+
self.last_moe_loss = self.moe_loss.clone().detach()
467+
return loss
468+
426469
def universal_checkpoint_info(self):
427470
return UniversalCheckpointInfo(using_model_pipe=True).get()
471+
472+
def get_additional_losses(self):
473+
if not self.is_moe_model:
474+
return None
475+
return OrderedDict({
476+
'lm loss': self.last_lm_loss,
477+
'moe loss': self.last_moe_loss
478+
})

megatron/model/transformer.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,8 @@ def forward(self, hidden_states, attention_mask=None,
12291229
retriever_output=None,
12301230
retriever_attn_mask=None,
12311231
inference_params=None,
1232-
rotary_pos_emb=None):
1232+
rotary_pos_emb=None,
1233+
aggregated_moe_loss=None):
12331234
# hidden_states: [s, b, h]
12341235

12351236
# Layer norm at the beginning of the transformer layer.
@@ -1321,6 +1322,10 @@ def forward(self, hidden_states, attention_mask=None,
13211322
else:
13221323
mlp_output, moe_loss, _ = self.mlp(layernorm_output)
13231324

1325+
# when aggregated_moe_loss received, returned moe_loss is the aggregated moe loss
1326+
if aggregated_moe_loss is not None:
1327+
moe_loss += aggregated_moe_loss
1328+
13241329
# Second residual connection.
13251330
if self.apply_residual_connection_post_layernorm:
13261331
residual = layernorm_output
@@ -1381,23 +1386,51 @@ class ParallelTransformerLayerPipe(ParallelTransformerLayer):
13811386
If no mask is provided, the module will query `self._args.attn_mask`
13821387
for the mask and only return `super().forward(...)`
13831388
"""
1389+
def __init__(self, config,
1390+
layer_number, layer_type=LayerType.encoder,
1391+
self_attn_mask_type=AttnMaskType.padding,
1392+
drop_path_rate=0., num_experts=1,
1393+
input_aggregated_moe_loss=False, return_aggregated_moe_loss=False):
1394+
self.input_aggregated_moe_loss = input_aggregated_moe_loss
1395+
self.return_aggregated_moe_loss = return_aggregated_moe_loss
1396+
super().__init__(config, layer_number, layer_type, self_attn_mask_type, drop_path_rate, num_experts)
1397+
13841398
def forward(self, inputs, **kwargs):
13851399
assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
13861400
if not hasattr(self, '_args'):
13871401
self._args = get_args()
13881402
rotary_pos_emb = self._args.rotary_pos_emb if self._args.use_rotary_position_embeddings else None
13891403
if torch.is_tensor(inputs) or len(inputs) == 1:
1404+
assert not self.input_aggregated_moe_loss, f'Expecting an input tuple of size >= 2'
13901405
# No attention mask forwarded, search for args.attn_mask
13911406
hidden_states, attention_mask = inputs, self._args.attn_mask
1392-
# HACK: currently MoE model does not support pipeline parallel, so
1393-
# here we just ignore the moe_loss returned by forward()
1394-
return super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)[0]
1395-
elif len(inputs) == 2:
1396-
# Attention mask is an activation.
1397-
hidden_states, attention_mask = inputs[0], inputs[1]
1398-
# HACK: currently MoE model does not support pipeline parallel, so
1399-
# here we just ignore the moe_loss returned by forward()
1400-
return super().forward(*inputs, **kwargs, rotary_pos_emb=rotary_pos_emb)[0], attention_mask
1407+
output, moe_loss = super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)
1408+
return (output, moe_loss) if self.return_aggregated_moe_loss else output
1409+
elif len(inputs) in (2, 3):
1410+
# Attention mask and aggregated_moe can both be activations.
1411+
return_attention_mask = False
1412+
if len(inputs) == 2:
1413+
if self.input_aggregated_moe_loss:
1414+
hidden_states, aggregated_moe_loss = inputs[0], inputs[1]
1415+
attention_mask = self._args.attn_mask
1416+
else:
1417+
hidden_states, attention_mask = inputs[0], inputs[1]
1418+
return_attention_mask = True
1419+
else:
1420+
hidden_states, attention_mask, aggregated_moe_loss = inputs[0], inputs[1], inputs[2]
1421+
1422+
# Forward aggregated_moe_loss to ParallelTransformerLayer for further accumulation
1423+
if self.input_aggregated_moe_loss:
1424+
kwargs.update({'aggregated_moe_loss': aggregated_moe_loss})
1425+
1426+
output, moe_loss = super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)
1427+
1428+
ret = (output, )
1429+
if return_attention_mask:
1430+
ret += (attention_mask, )
1431+
if self.return_aggregated_moe_loss:
1432+
ret += (moe_loss, )
1433+
return ret
14011434
else:
14021435
raise RuntimeError('Received more inputs than understood.')
14031436

@@ -1499,6 +1532,19 @@ def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
14991532
return default_layer_type
15001533

15011534

1535+
def get_num_experts_per_layer(num_experts: list, num_layers: int, expert_interval: int, offset: int = 0) -> list:
1536+
assert len(num_experts) == 1 or len(num_experts) == num_layers // expert_interval, \
1537+
'num_experts must be either a single value or a list of the same length as the number of MoE layers'
1538+
if len(num_experts) == 1:
1539+
num_experts = num_experts * (num_layers // expert_interval)
1540+
experts_per_layer = []
1541+
for i in range(num_layers):
1542+
layer_num = i + 1 + offset
1543+
n_e = num_experts[(layer_num-1) // expert_interval] if layer_num % expert_interval == 0 else 1
1544+
experts_per_layer.append(n_e)
1545+
return experts_per_layer
1546+
1547+
15021548
class ParallelTransformer(MegatronModule):
15031549
"""Transformer class."""
15041550

@@ -1682,21 +1728,12 @@ def build_layer(layer_number, n_e):
16821728
self.num_layers = 1
16831729
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
16841730
else:
1685-
assert len(num_experts) == 1 or len(num_experts) == args.num_layers // args.expert_interval, \
1686-
'num_experts must be either a single value or a list of the same length as the number of MoE layers'
1687-
1688-
# Create the list of MoE experts
1689-
if len(num_experts) == 1:
1690-
num_experts = num_experts * (args.num_layers // args.expert_interval)
1691-
16921731
# Build the layers
16931732
self.layers = []
1733+
experts_per_layer = get_num_experts_per_layer(num_experts, self.num_layers, args.expert_interval, offset)
16941734
for i in range(self.num_layers):
16951735
layer_num = i + 1 + offset
1696-
if layer_num % args.expert_interval == 0:
1697-
n_e = num_experts[(layer_num-1) // args.expert_interval]
1698-
else:
1699-
n_e = 1
1736+
n_e = experts_per_layer[i]
17001737
self.layers.append(build_layer(layer_num, n_e))
17011738
self.layers = torch.nn.ModuleList(self.layers)
17021739

megatron/training.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# The earliest we can measure the start time.
1111
_TRAIN_START_TIME = time.time()
1212
import torch
13+
from collections import OrderedDict
1314
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
1415

1516
from megatron import get_args
@@ -667,8 +668,13 @@ def train_step(forward_step_func, data_iterator,
667668
num_zeros_in_grad = 0
668669
assert isinstance(model[0], deepspeed.PipelineEngine)
669670
loss = model[0].train_batch(data_iter=data_iterator)
671+
additional_losses = model[0].get_additional_losses()
672+
loss_key = 'lm loss' if additional_losses is None else 'loss' # use "lm loss" for backward compatibility
673+
loss_dict = OrderedDict({loss_key: loss})
674+
if additional_losses is not None:
675+
loss_dict.update(additional_losses)
670676
grad_norm = model[0].get_global_grad_norm()
671-
return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
677+
return loss_dict, skipped_iter, grad_norm, num_zeros_in_grad
672678

673679
# Set grad to zero.
674680
if not args.deepspeed:

0 commit comments

Comments
 (0)