Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions megatron/core/models/hybrid/hybrid_layer_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ def select_pipeline_segment(
vp_stage: Optional[int],
first_stage_layers: Optional[int] = None,
last_stage_layers: Optional[int] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
dp_cp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Tuple[List[str], int]:
"""Select and validate the pipeline segment for the given PP rank and VP stage.

Expand All @@ -352,6 +354,8 @@ def select_pipeline_segment(
uneven PP. Only valid when the pattern has no pipe separators.
last_stage_layers: Number of layers on the last pipeline stage for
uneven PP. Only valid when the pattern has no pipe separators.
tp_group: Optional tensor-parallel process group used for per-stage logging.
dp_cp_group: Optional data/context-parallel process group used for per-stage logging.

Returns:
Tuple of (layer_type_list, layer_offset) where layer_type_list is
Expand Down Expand Up @@ -445,6 +449,8 @@ def select_pipeline_segment(
f"HybridModel: pp_rank={pp_rank}/{pp_size}, vp_stage={vp_stage}, "
f"layers='{''.join(selected)}' ({len(selected)} layers), "
f"layer_offset={offset} (auto-split)",
tp_group=tp_group,
dp_cp_group=dp_cp_group,
)
return selected, offset

Expand Down Expand Up @@ -479,6 +485,8 @@ def select_pipeline_segment(
f"segment_index={segment_index}/{len(segments)}, "
f"layers='{my_segment}' ({len(layer_type_list)} layers), "
f"layer_offset={layer_offset}",
tp_group=tp_group,
dp_cp_group=dp_cp_group,
)

return layer_type_list, layer_offset
Expand Down
15 changes: 15 additions & 0 deletions megatron/core/models/hybrid/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@
logger = logging.getLogger(__name__)


def _hybrid_logging_pg_kwargs(pg_collection: ProcessGroupCollection) -> dict:
tp_group = getattr(pg_collection, 'tp', None)
dp_cp_group = getattr(pg_collection, 'dp_cp', None)
if (tp_group is None) != (dp_cp_group is None):
raise ValueError(
"pg_collection.tp and pg_collection.dp_cp must both be set or both be unset."
)
if tp_group is None:
return {}
return {'tp_group': tp_group, 'dp_cp_group': dp_cp_group}


class HybridModel(LanguageModule, GraphableMegatronModule):
"""Hybrid language model.

Expand Down Expand Up @@ -186,12 +198,15 @@ def __init__(
self.mtp_pattern = parsed.mtp_pattern
self.mtp_num_depths = parsed.mtp_num_depths

logging_pg_kwargs = _hybrid_logging_pg_kwargs(self.pg_collection)

layer_type_list, layer_offset = select_pipeline_segment(
parsed.main_pattern or '',
self.pg_collection.pp,
vp_stage,
first_stage_layers=self.config.num_layers_in_first_pipeline_stage,
last_stage_layers=self.config.num_layers_in_last_pipeline_stage,
**logging_pg_kwargs,
)

# Determine if MTP is needed (based on pattern parsing)
Expand Down
23 changes: 22 additions & 1 deletion tests/unit_tests/models/test_hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from datetime import timedelta
from itertools import accumulate
from types import SimpleNamespace

import pytest
import torch
Expand All @@ -17,7 +18,7 @@
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import YarnRotaryEmbedding
from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec
from megatron.core.models.hybrid.hybrid_model import HybridModel
from megatron.core.models.hybrid.hybrid_model import HybridModel, _hybrid_logging_pg_kwargs
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer import TransformerConfig
Expand All @@ -27,6 +28,26 @@
from tests.unit_tests.test_utilities import Utils


def test_hybrid_logging_process_groups_are_paired():
tp_group = object()
dp_cp_group = object()

assert _hybrid_logging_pg_kwargs(SimpleNamespace()) == {}
assert _hybrid_logging_pg_kwargs(SimpleNamespace(tp=tp_group, dp_cp=dp_cp_group)) == {
'tp_group': tp_group,
'dp_cp_group': dp_cp_group,
}

with pytest.raises(ValueError, match="tp.*dp_cp"):
_hybrid_logging_pg_kwargs(SimpleNamespace(tp=tp_group))
with pytest.raises(ValueError, match="tp.*dp_cp"):
_hybrid_logging_pg_kwargs(SimpleNamespace(dp_cp=dp_cp_group))
with pytest.raises(ValueError, match="tp.*dp_cp"):
_hybrid_logging_pg_kwargs(SimpleNamespace(tp=tp_group, dp_cp=None))
with pytest.raises(ValueError, match="tp.*dp_cp"):
_hybrid_logging_pg_kwargs(SimpleNamespace(tp=None, dp_cp=dp_cp_group))


class TestHybridModel:

def setup_method(self, method):
Expand Down
10 changes: 10 additions & 0 deletions tests/unit_tests/ssm/test_hybrid_layer_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,16 @@ def test_logging_is_called(self, mock_log):
select_pipeline_segment("M*M*", pp_group=None, vp_stage=None)
mock_log.assert_called_once()

@patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage')
def test_logging_receives_explicit_groups(self, mock_log):
tp_group = object()
dp_cp_group = object()
select_pipeline_segment(
"M*M*", pp_group=None, vp_stage=None, tp_group=tp_group, dp_cp_group=dp_cp_group
)
assert mock_log.call_args.kwargs["tp_group"] is tp_group
assert mock_log.call_args.kwargs["dp_cp_group"] is dp_cp_group

@patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage')
def test_mutual_exclusivity_pipes_with_first_stage(self, mock_log):
"""Pipe separators + first_stage_layers should raise ValueError."""
Expand Down
Loading