26 changes: 25 additions & 1 deletion src/megatron/bridge/models/mamba/mamba_builder.py
@@ -112,10 +112,18 @@ class MambaModelConfig(ModelConfig):
hybrid_layer_pattern: str | None = None
seq_length: int = 8192
# Mamba with no attention has no need for position embeddings, so none is default
position_embedding_type: Literal["learned_absolute", "rope", "none"] = "none"
position_embedding_type: Literal["learned_absolute", "rope", "yarn", "none"] = "none"
rotary_percent: float = 1.0
rotary_base: int = 10000
seq_len_interpolation_factor: float | None = None
# YaRN parameters — only applied when position_embedding_type == "yarn"
yarn_rotary_scaling_factor: float = 8.0
yarn_original_max_position_embeddings: int | None = None
yarn_beta_fast: float = 32.0
yarn_beta_slow: float = 1.0
yarn_mscale: float = 1.0
yarn_mscale_all_dim: float = 0.0
yarn_correction_range_round_to_int: bool = True
make_vocab_size_divisible_by: int = 128
mamba_stack_spec: ModuleSpec | Callable[[], ModuleSpec] | Callable[["MambaModelConfig"], ModuleSpec] = (
get_default_mamba_stack_spec
@@ -224,6 +232,22 @@ def build_model(

pre_process = pre_process if pre_process is not None else is_pp_first_stage(pg_collection.pp)
post_process = post_process if post_process is not None else is_pp_last_stage(pg_collection.pp)

if self._model_config.position_embedding_type == "yarn":
cfg = self._model_config
t = cfg.transformer
t.yarn_rotary_scaling_factor = cfg.yarn_rotary_scaling_factor
t.yarn_original_max_position_embeddings = (
cfg.yarn_original_max_position_embeddings
if cfg.yarn_original_max_position_embeddings is not None
else int(cfg.seq_length / cfg.yarn_rotary_scaling_factor)
)
t.yarn_beta_fast = cfg.yarn_beta_fast
t.yarn_beta_slow = cfg.yarn_beta_slow
t.yarn_mscale = cfg.yarn_mscale
t.yarn_mscale_all_dim = cfg.yarn_mscale_all_dim
t.yarn_correction_range_round_to_int = cfg.yarn_correction_range_round_to_int

return MCoreMambaModel(
config=self._model_config.transformer,
mamba_stack_spec=mamba_stack_spec,
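Note on the new defaulting rule in build_model(): when position_embedding_type == "yarn" and yarn_original_max_position_embeddings is left as None, it is derived from the configured sequence length. A minimal sketch of that arithmetic, using the illustrative values from the unit tests below (not part of the diff itself):

# Illustration only: mirrors the defaulting expression added to build_model() above.
seq_length = 4096
yarn_rotary_scaling_factor = 8.0
yarn_original_max_position_embeddings = int(seq_length / yarn_rotary_scaling_factor)
print(yarn_original_max_position_embeddings)  # 512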
13 changes: 12 additions & 1 deletion src/megatron/bridge/models/mamba/mamba_provider.py
@@ -148,11 +148,19 @@ class MambaModelProvider(TransformerConfig, ModelProviderMixin[MCoreMambaModel])
hybrid_layer_pattern: Optional[str] = None
seq_length: int = 8192
# Mamba with no attention has no need for position embeddings, so none is default
position_embedding_type: Literal["learned_absolute", "rope", "none"] = "none"
position_embedding_type: Literal["learned_absolute", "rope", "yarn", "none"] = "none"
rotary_percent: float = 1.0
rotary_base: int = 10000
seq_len_interpolation_factor: Optional[float] = None
apply_rope_fusion: bool = True
# YaRN parameters — only applied when position_embedding_type == "yarn"
yarn_rotary_scaling_factor: float = 8.0
yarn_original_max_position_embeddings: Optional[int] = None
yarn_beta_fast: float = 32.0
yarn_beta_slow: float = 1.0
yarn_mscale: float = 1.0
yarn_mscale_all_dim: float = 0.0
yarn_correction_range_round_to_int: bool = True
make_vocab_size_divisible_by: int = 128
gated_linear_unit: bool = False
normalization: str = "RMSNorm"
@@ -294,6 +302,9 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreMa
"models due to upstream MCore MambaModel API dependency"
)

if self.position_embedding_type == "yarn" and self.yarn_original_max_position_embeddings is None:
self.yarn_original_max_position_embeddings = int(self.seq_length / self.yarn_rotary_scaling_factor)

assert self.vocab_size is not None, "vocab_size must be configured before calling provide()"
if self.should_pad_vocab:
padded_vocab_size = calculate_padded_vocab_size(
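The provider applies the same default inside provide(). A minimal usage sketch, assuming the constructor arguments and import path used by the unit tests below (the small layer and hidden sizes are illustrative only):

from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider

# Hypothetical configuration: leave yarn_original_max_position_embeddings unset so that
# provide() fills it in as int(seq_length / yarn_rotary_scaling_factor), i.e. int(4096 / 8.0) -> 512.
provider = MambaModelProvider(
    num_layers=2,
    hidden_size=128,
    num_attention_heads=1,
    vocab_size=1000,
    seq_length=4096,
    position_embedding_type="yarn",
    yarn_rotary_scaling_factor=8.0,
)
# provider.provide(...) additionally expects model-parallel process groups to be set up;
# the unit tests below replace those (and MCoreMambaModel itself) with mocks.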
112 changes: 112 additions & 0 deletions tests/unit_tests/models/mamba/test_mamba_builder.py
@@ -527,3 +527,115 @@ def test_default_parameters_forwarded(self, mock_unimodal, mock_compose):
assert args[8] is False # data_parallel_random_init
assert args[9] is Float16Module # mixed_precision_wrapper
assert args[11] is ModelType.encoder_or_decoder # model_type


# =============================================================================
# Section 4 — YaRN positional embedding support
# =============================================================================


class TestMambaModelConfigYarnDefaults:
"""Tests that MambaModelConfig exposes the expected YaRN field defaults."""

def test_yarn_field_defaults(self):
config = _make_mamba_config()
assert config.yarn_rotary_scaling_factor == 8.0
assert config.yarn_original_max_position_embeddings is None
assert config.yarn_beta_fast == 32.0
assert config.yarn_beta_slow == 1.0
assert config.yarn_mscale == 1.0
assert config.yarn_mscale_all_dim == 0.0
assert config.yarn_correction_range_round_to_int is True

def test_yarn_position_embedding_type_accepted(self):
config = _make_mamba_config(position_embedding_type="yarn")
assert config.position_embedding_type == "yarn"

def test_yarn_fields_settable(self):
config = _make_mamba_config(
position_embedding_type="yarn",
yarn_rotary_scaling_factor=4.0,
yarn_original_max_position_embeddings=1024,
yarn_beta_fast=16.0,
yarn_beta_slow=0.5,
yarn_mscale=0.8,
yarn_mscale_all_dim=1.0,
yarn_correction_range_round_to_int=False,
)
assert config.yarn_rotary_scaling_factor == 4.0
assert config.yarn_original_max_position_embeddings == 1024
assert config.yarn_beta_fast == 16.0
assert config.yarn_beta_slow == 0.5
assert config.yarn_mscale == 0.8
assert config.yarn_mscale_all_dim == 1.0
assert config.yarn_correction_range_round_to_int is False


class TestMambaModelBuilderBuildModelWithYarn:
"""Tests for YaRN attribute injection in MambaModelBuilder.build_model()."""

def setup_method(self):
self.config = _make_mamba_config(
vocab_size=32000,
seq_length=4096,
position_embedding_type="yarn",
yarn_rotary_scaling_factor=8.0,
)
self.builder = MambaModelBuilder(self.config)
self.pg = Mock()
self.pg.pp = Mock()

@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
def test_yarn_attrs_injected_onto_transformer(self, mock_model, *_):
"""All YaRN attrs must be set on the embedded TransformerConfig before MCoreMambaModel is called."""
self.builder.build_model(self.pg, pre_process=True, post_process=True)
t = self.config.transformer
assert t.yarn_rotary_scaling_factor == 8.0
assert t.yarn_beta_fast == 32.0
assert t.yarn_beta_slow == 1.0
assert t.yarn_mscale == 1.0
assert t.yarn_mscale_all_dim == 0.0
assert t.yarn_correction_range_round_to_int is True

@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
def test_yarn_original_max_defaulted_from_seq_length(self, mock_model, *_):
"""None yarn_original_max_position_embeddings defaults to seq_length / scaling_factor."""
assert self.config.yarn_original_max_position_embeddings is None
self.builder.build_model(self.pg, pre_process=True, post_process=True)
expected = int(self.config.seq_length / self.config.yarn_rotary_scaling_factor)
assert self.config.transformer.yarn_original_max_position_embeddings == expected

@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
def test_yarn_original_max_explicit_value_preserved(self, mock_model, *_):
"""An explicit yarn_original_max_position_embeddings is passed through unchanged."""
self.config.__dict__["yarn_original_max_position_embeddings"] = 512
self.builder.build_model(self.pg, pre_process=True, post_process=True)
assert self.config.transformer.yarn_original_max_position_embeddings == 512

@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
def test_no_yarn_injection_for_rope(self, mock_model, *_):
"""YaRN attrs must NOT be injected when position_embedding_type is 'rope'."""
config = _make_mamba_config(vocab_size=32000, position_embedding_type="rope")
MambaModelBuilder(config).build_model(self.pg, pre_process=True, post_process=True)
assert not hasattr(config.transformer, "yarn_rotary_scaling_factor")

@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
def test_position_embedding_type_yarn_forwarded_to_mcore(self, mock_model, *_):
"""position_embedding_type='yarn' must be forwarded to MCoreMambaModel."""
self.builder.build_model(self.pg, pre_process=True, post_process=True)
assert mock_model.call_args.kwargs["position_embedding_type"] == "yarn"
118 changes: 118 additions & 0 deletions tests/unit_tests/models/mamba/test_mamba_provider.py
@@ -337,3 +337,121 @@ def test_finalize_uses_compatible_hybrid_layer_count(self):

assert provider.num_layers == 9
mock_finalize.assert_called_once_with(provider)


# =============================================================================
# YaRN positional embedding support
# =============================================================================


class TestMambaModelProviderYarnDefaults:
"""Tests that MambaModelProvider exposes the expected YaRN field defaults."""

def test_yarn_field_defaults(self):
provider = MambaModelProvider(num_layers=2, hidden_size=128, num_attention_heads=1)
assert provider.yarn_rotary_scaling_factor == 8.0
assert provider.yarn_original_max_position_embeddings is None
assert provider.yarn_beta_fast == 32.0
assert provider.yarn_beta_slow == 1.0
assert provider.yarn_mscale == 1.0
assert provider.yarn_mscale_all_dim == 0.0
assert provider.yarn_correction_range_round_to_int is True

def test_yarn_position_embedding_type_accepted(self):
provider = MambaModelProvider(
num_layers=2, hidden_size=128, num_attention_heads=1, position_embedding_type="yarn"
)
assert provider.position_embedding_type == "yarn"

def test_yarn_custom_fields(self):
provider = MambaModelProvider(
num_layers=2,
hidden_size=128,
num_attention_heads=1,
position_embedding_type="yarn",
yarn_rotary_scaling_factor=4.0,
yarn_original_max_position_embeddings=512,
yarn_beta_fast=16.0,
yarn_beta_slow=0.5,
yarn_mscale=0.8,
yarn_mscale_all_dim=1.0,
yarn_correction_range_round_to_int=False,
)
assert provider.yarn_rotary_scaling_factor == 4.0
assert provider.yarn_original_max_position_embeddings == 512
assert provider.yarn_beta_fast == 16.0
assert provider.yarn_beta_slow == 0.5
assert provider.yarn_mscale == 0.8
assert provider.yarn_mscale_all_dim == 1.0
assert provider.yarn_correction_range_round_to_int is False


class TestMambaModelProviderProvideWithYarn:
"""Tests for YaRN handling in MambaModelProvider.provide()."""

def _make_provider(self, **kwargs):
defaults = dict(
num_layers=2,
hidden_size=128,
num_attention_heads=1,
vocab_size=1000,
tensor_model_parallel_size=1,
make_vocab_size_divisible_by=128,
position_embedding_type="yarn",
seq_length=4096,
yarn_rotary_scaling_factor=8.0,
)
defaults.update(kwargs)
provider = MambaModelProvider(**defaults)
provider._pg_collection = type("PG", (), {"pp": object()})()
return provider

def test_yarn_original_max_defaulted_from_seq_length(self):
"""When yarn_original_max_position_embeddings is None, provide() fills it in."""
provider = self._make_provider()
assert provider.yarn_original_max_position_embeddings is None

with patch("megatron.bridge.models.mamba.mamba_provider.calculate_padded_vocab_size", return_value=1024):
with patch("megatron.bridge.models.mamba.mamba_provider.MCoreMambaModel"):
provider.provide(pre_process=True, post_process=True)

assert provider.yarn_original_max_position_embeddings == int(4096 / 8.0)

def test_yarn_original_max_explicit_value_preserved(self):
"""An explicit yarn_original_max_position_embeddings is not overwritten."""
provider = self._make_provider(yarn_original_max_position_embeddings=256)

with patch("megatron.bridge.models.mamba.mamba_provider.calculate_padded_vocab_size", return_value=1024):
with patch("megatron.bridge.models.mamba.mamba_provider.MCoreMambaModel"):
provider.provide(pre_process=True, post_process=True)

assert provider.yarn_original_max_position_embeddings == 256

def test_no_yarn_default_injection_for_rope(self):
"""yarn_original_max_position_embeddings should stay None when using rope."""
provider = MambaModelProvider(
num_layers=2,
hidden_size=128,
num_attention_heads=1,
vocab_size=1000,
tensor_model_parallel_size=1,
position_embedding_type="rope",
)
provider._pg_collection = type("PG", (), {"pp": object()})()

with patch("megatron.bridge.models.mamba.mamba_provider.calculate_padded_vocab_size", return_value=1024):
with patch("megatron.bridge.models.mamba.mamba_provider.MCoreMambaModel"):
provider.provide(pre_process=True, post_process=True)

assert provider.yarn_original_max_position_embeddings is None

def test_position_embedding_type_yarn_forwarded_to_mcore(self):
"""position_embedding_type='yarn' must be forwarded to MCoreMambaModel."""
provider = self._make_provider()

with patch("megatron.bridge.models.mamba.mamba_provider.calculate_padded_vocab_size", return_value=1024):
with patch("megatron.bridge.models.mamba.mamba_provider.MCoreMambaModel") as mock_mamba:
mock_mamba.return_value = Mock()
provider.provide(pre_process=True, post_process=True)

assert mock_mamba.call_args.kwargs["position_embedding_type"] == "yarn"