Skip to content

Commit 0acb904

Browse files
committed
add yarn support for mamba_model
Signed-off-by: guihong-nv <guihongl@nvidia.com>
1 parent 0bf1333 commit 0acb904

4 files changed

Lines changed: 267 additions & 2 deletions

File tree

src/megatron/bridge/models/mamba/mamba_builder.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,18 @@ class MambaModelConfig(ModelConfig):
112112
hybrid_layer_pattern: str | None = None
113113
seq_length: int = 8192
114114
# Mamba with no attention has no need for position embeddings, so none is default
115-
position_embedding_type: Literal["learned_absolute", "rope", "none"] = "none"
115+
position_embedding_type: Literal["learned_absolute", "rope", "yarn", "none"] = "none"
116116
rotary_percent: float = 1.0
117117
rotary_base: int = 10000
118118
seq_len_interpolation_factor: float | None = None
119+
# YaRN parameters — only applied when position_embedding_type == "yarn"
120+
yarn_rotary_scaling_factor: float = 8.0
121+
yarn_original_max_position_embeddings: int | None = None
122+
yarn_beta_fast: float = 32.0
123+
yarn_beta_slow: float = 1.0
124+
yarn_mscale: float = 1.0
125+
yarn_mscale_all_dim: float = 0.0
126+
yarn_correction_range_round_to_int: bool = True
119127
make_vocab_size_divisible_by: int = 128
120128
mamba_stack_spec: ModuleSpec | Callable[[], ModuleSpec] | Callable[["MambaModelConfig"], ModuleSpec] = (
121129
get_default_mamba_stack_spec
@@ -224,6 +232,22 @@ def build_model(
224232

225233
pre_process = pre_process if pre_process is not None else is_pp_first_stage(pg_collection.pp)
226234
post_process = post_process if post_process is not None else is_pp_last_stage(pg_collection.pp)
235+
236+
if self._model_config.position_embedding_type == "yarn":
237+
cfg = self._model_config
238+
t = cfg.transformer
239+
t.yarn_rotary_scaling_factor = cfg.yarn_rotary_scaling_factor
240+
t.yarn_original_max_position_embeddings = (
241+
cfg.yarn_original_max_position_embeddings
242+
if cfg.yarn_original_max_position_embeddings is not None
243+
else int(cfg.seq_length / cfg.yarn_rotary_scaling_factor)
244+
)
245+
t.yarn_beta_fast = cfg.yarn_beta_fast
246+
t.yarn_beta_slow = cfg.yarn_beta_slow
247+
t.yarn_mscale = cfg.yarn_mscale
248+
t.yarn_mscale_all_dim = cfg.yarn_mscale_all_dim
249+
t.yarn_correction_range_round_to_int = cfg.yarn_correction_range_round_to_int
250+
227251
return MCoreMambaModel(
228252
config=self._model_config.transformer,
229253
mamba_stack_spec=mamba_stack_spec,

src/megatron/bridge/models/mamba/mamba_provider.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,19 @@ class MambaModelProvider(TransformerConfig, ModelProviderMixin[MCoreMambaModel])
148148
hybrid_layer_pattern: Optional[str] = None
149149
seq_length: int = 8192
150150
# Mamba with no attention has no need for position embeddings, so none is default
151-
position_embedding_type: Literal["learned_absolute", "rope", "none"] = "none"
151+
position_embedding_type: Literal["learned_absolute", "rope", "yarn", "none"] = "none"
152152
rotary_percent: float = 1.0
153153
rotary_base: int = 10000
154154
seq_len_interpolation_factor: Optional[float] = None
155155
apply_rope_fusion: bool = True
156+
# YaRN parameters — only applied when position_embedding_type == "yarn"
157+
yarn_rotary_scaling_factor: float = 8.0
158+
yarn_original_max_position_embeddings: Optional[int] = None
159+
yarn_beta_fast: float = 32.0
160+
yarn_beta_slow: float = 1.0
161+
yarn_mscale: float = 1.0
162+
yarn_mscale_all_dim: float = 0.0
163+
yarn_correction_range_round_to_int: bool = True
156164
make_vocab_size_divisible_by: int = 128
157165
gated_linear_unit: bool = False
158166
normalization: str = "RMSNorm"
@@ -294,6 +302,9 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreMa
294302
"models due to upstream MCore MambaModel API dependency"
295303
)
296304

305+
if self.position_embedding_type == "yarn" and self.yarn_original_max_position_embeddings is None:
306+
self.yarn_original_max_position_embeddings = int(self.seq_length / self.yarn_rotary_scaling_factor)
307+
297308
assert self.vocab_size is not None, "vocab_size must be configured before calling provide()"
298309
if self.should_pad_vocab:
299310
padded_vocab_size = calculate_padded_vocab_size(

tests/unit_tests/models/mamba/test_mamba_builder.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,3 +527,115 @@ def test_default_parameters_forwarded(self, mock_unimodal, mock_compose):
527527
assert args[8] is False # data_parallel_random_init
528528
assert args[9] is Float16Module # mixed_precision_wrapper
529529
assert args[11] is ModelType.encoder_or_decoder # model_type
530+
531+
532+
# =============================================================================
533+
# Section 4 — YaRN positional embedding support
534+
# =============================================================================
535+
536+
537+
class TestMambaModelConfigYarnDefaults:
538+
"""Tests that MambaModelConfig exposes the expected YaRN field defaults."""
539+
540+
def test_yarn_field_defaults(self):
541+
config = _make_mamba_config()
542+
assert config.yarn_rotary_scaling_factor == 8.0
543+
assert config.yarn_original_max_position_embeddings is None
544+
assert config.yarn_beta_fast == 32.0
545+
assert config.yarn_beta_slow == 1.0
546+
assert config.yarn_mscale == 1.0
547+
assert config.yarn_mscale_all_dim == 0.0
548+
assert config.yarn_correction_range_round_to_int is True
549+
550+
def test_yarn_position_embedding_type_accepted(self):
551+
config = _make_mamba_config(position_embedding_type="yarn")
552+
assert config.position_embedding_type == "yarn"
553+
554+
def test_yarn_fields_settable(self):
555+
config = _make_mamba_config(
556+
position_embedding_type="yarn",
557+
yarn_rotary_scaling_factor=4.0,
558+
yarn_original_max_position_embeddings=1024,
559+
yarn_beta_fast=16.0,
560+
yarn_beta_slow=0.5,
561+
yarn_mscale=0.8,
562+
yarn_mscale_all_dim=1.0,
563+
yarn_correction_range_round_to_int=False,
564+
)
565+
assert config.yarn_rotary_scaling_factor == 4.0
566+
assert config.yarn_original_max_position_embeddings == 1024
567+
assert config.yarn_beta_fast == 16.0
568+
assert config.yarn_beta_slow == 0.5
569+
assert config.yarn_mscale == 0.8
570+
assert config.yarn_mscale_all_dim == 1.0
571+
assert config.yarn_correction_range_round_to_int is False
572+
573+
574+
class TestMambaModelBuilderBuildModelWithYarn:
575+
"""Tests for YaRN attribute injection in MambaModelBuilder.build_model()."""
576+
577+
def setup_method(self):
578+
self.config = _make_mamba_config(
579+
vocab_size=32000,
580+
seq_length=4096,
581+
position_embedding_type="yarn",
582+
yarn_rotary_scaling_factor=8.0,
583+
)
584+
self.builder = MambaModelBuilder(self.config)
585+
self.pg = Mock()
586+
self.pg.pp = Mock()
587+
588+
@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
589+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
590+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
591+
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
592+
def test_yarn_attrs_injected_onto_transformer(self, mock_model, *_):
593+
"""All YaRN attrs must be set on the embedded TransformerConfig before MCoreMambaModel is called."""
594+
self.builder.build_model(self.pg, pre_process=True, post_process=True)
595+
t = self.config.transformer
596+
assert t.yarn_rotary_scaling_factor == 8.0
597+
assert t.yarn_beta_fast == 32.0
598+
assert t.yarn_beta_slow == 1.0
599+
assert t.yarn_mscale == 1.0
600+
assert t.yarn_mscale_all_dim == 0.0
601+
assert t.yarn_correction_range_round_to_int is True
602+
603+
@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
604+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
605+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
606+
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
607+
def test_yarn_original_max_defaulted_from_seq_length(self, mock_model, *_):
608+
"""None yarn_original_max_position_embeddings defaults to seq_length / scaling_factor."""
609+
assert self.config.yarn_original_max_position_embeddings is None
610+
self.builder.build_model(self.pg, pre_process=True, post_process=True)
611+
expected = int(self.config.seq_length / self.config.yarn_rotary_scaling_factor)
612+
assert self.config.transformer.yarn_original_max_position_embeddings == expected
613+
614+
@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
615+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
616+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
617+
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
618+
def test_yarn_original_max_explicit_value_preserved(self, mock_model, *_):
619+
"""An explicit yarn_original_max_position_embeddings is passed through unchanged."""
620+
self.config.__dict__["yarn_original_max_position_embeddings"] = 512
621+
self.builder.build_model(self.pg, pre_process=True, post_process=True)
622+
assert self.config.transformer.yarn_original_max_position_embeddings == 512
623+
624+
@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
625+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
626+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
627+
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
628+
def test_no_yarn_injection_for_rope(self, mock_model, *_):
629+
"""YaRN attrs must NOT be injected when position_embedding_type is 'rope'."""
630+
config = _make_mamba_config(vocab_size=32000, position_embedding_type="rope")
631+
MambaModelBuilder(config).build_model(self.pg, pre_process=True, post_process=True)
632+
assert not hasattr(config.transformer, "yarn_rotary_scaling_factor")
633+
634+
@patch("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size")
635+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage", return_value=True)
636+
@patch("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage", return_value=True)
637+
@patch("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel")
638+
def test_position_embedding_type_yarn_forwarded_to_mcore(self, mock_model, *_):
639+
"""position_embedding_type='yarn' must be forwarded to MCoreMambaModel."""
640+
self.builder.build_model(self.pg, pre_process=True, post_process=True)
641+
assert mock_model.call_args.kwargs["position_embedding_type"] == "yarn"

tests/unit_tests/models/mamba/test_mamba_provider.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,121 @@ def test_finalize_uses_compatible_hybrid_layer_count(self):
337337

338338
assert provider.num_layers == 9
339339
mock_finalize.assert_called_once_with(provider)
340+
341+
342+
# =============================================================================
343+
# YaRN positional embedding support
344+
# =============================================================================
345+
346+
347+
class TestMambaModelProviderYarnDefaults:
348+
"""Tests that MambaModelProvider exposes the expected YaRN field defaults."""
349+
350+
def test_yarn_field_defaults(self):
351+
provider = MambaModelProvider(num_layers=2, hidden_size=128, num_attention_heads=1)
352+
assert provider.yarn_rotary_scaling_factor == 8.0
353+
assert provider.yarn_original_max_position_embeddings is None
354+
assert provider.yarn_beta_fast == 32.0
355+
assert provider.yarn_beta_slow == 1.0
356+
assert provider.yarn_mscale == 1.0
357+
assert provider.yarn_mscale_all_dim == 0.0
358+
assert provider.yarn_correction_range_round_to_int is True
359+
360+
def test_yarn_position_embedding_type_accepted(self):
361+
provider = MambaModelProvider(
362+
num_layers=2, hidden_size=128, num_attention_heads=1, position_embedding_type="yarn"
363+
)
364+
assert provider.position_embedding_type == "yarn"
365+
366+
def test_yarn_custom_fields(self):
367+
provider = MambaModelProvider(
368+
num_layers=2,
369+
hidden_size=128,
370+
num_attention_heads=1,
371+
position_embedding_type="yarn",
372+
yarn_rotary_scaling_factor=4.0,
373+
yarn_original_max_position_embeddings=512,
374+
yarn_beta_fast=16.0,
375+
yarn_beta_slow=0.5,
376+
yarn_mscale=0.8,
377+
yarn_mscale_all_dim=1.0,
378+
yarn_correction_range_round_to_int=False,
379+
)
380+
assert provider.yarn_rotary_scaling_factor == 4.0
381+
assert provider.yarn_original_max_position_embeddings == 512
382+
assert provider.yarn_beta_fast == 16.0
383+
assert provider.yarn_beta_slow == 0.5
384+
assert provider.yarn_mscale == 0.8
385+
assert provider.yarn_mscale_all_dim == 1.0
386+
assert provider.yarn_correction_range_round_to_int is False
387+
388+
389+
class TestMambaModelProviderProvideWithYarn:
390+
"""Tests for YaRN handling in MambaModelProvider.provide()."""
391+
392+
def _make_provider(self, **kwargs):
393+
defaults = dict(
394+
num_layers=2,
395+
hidden_size=128,
396+
num_attention_heads=1,
397+
vocab_size=1000,
398+
tensor_model_parallel_size=1,
399+
make_vocab_size_divisible_by=128,
400+
position_embedding_type="yarn",
401+
seq_length=4096,
402+
yarn_rotary_scaling_factor=8.0,
403+
)
404+
defaults.update(kwargs)
405+
provider = MambaModelProvider(**defaults)
406+
provider._pg_collection = type("PG", (), {"pp": object()})()
407+
return provider
408+
409+
def test_yarn_original_max_defaulted_from_seq_length(self):
410+
"""When yarn_original_max_position_embeddings is None, provide() fills it in."""
411+
provider = self._make_provider()
412+
assert provider.yarn_original_max_position_embeddings is None
413+
414+
with patch("megatron.bridge.models.mamba.mamba_provider.calculate_padded_vocab_size", return_value=1024):
415+
with patch("megatron.bridge.models.mamba.mamba_provider.MCoreMambaModel"):
416+
provider.provide(pre_process=True, post_process=True)
417+
418+
assert provider.yarn_original_max_position_embeddings == int(4096 / 8.0)
419+
420+
def test_yarn_original_max_explicit_value_preserved(self):
421+
"""An explicit yarn_original_max_position_embeddings is not overwritten."""
422+
provider = self._make_provider(yarn_original_max_position_embeddings=256)
423+
424+
with patch("megatron.bridge.models.mamba.mamba_provider.calculate_padded_vocab_size", return_value=1024):
425+
with patch("megatron.bridge.models.mamba.mamba_provider.MCoreMambaModel"):
426+
provider.provide(pre_process=True, post_process=True)
427+
428+
assert provider.yarn_original_max_position_embeddings == 256
429+
430+
def test_no_yarn_default_injection_for_rope(self):
431+
"""yarn_original_max_position_embeddings should stay None when using rope."""
432+
provider = MambaModelProvider(
433+
num_layers=2,
434+
hidden_size=128,
435+
num_attention_heads=1,
436+
vocab_size=1000,
437+
tensor_model_parallel_size=1,
438+
position_embedding_type="rope",
439+
)
440+
provider._pg_collection = type("PG", (), {"pp": object()})()
441+
442+
with patch("megatron.bridge.models.mamba.mamba_provider.calculate_padded_vocab_size", return_value=1024):
443+
with patch("megatron.bridge.models.mamba.mamba_provider.MCoreMambaModel"):
444+
provider.provide(pre_process=True, post_process=True)
445+
446+
assert provider.yarn_original_max_position_embeddings is None
447+
448+
def test_position_embedding_type_yarn_forwarded_to_mcore(self):
449+
"""position_embedding_type='yarn' must be forwarded to MCoreMambaModel."""
450+
provider = self._make_provider()
451+
452+
with patch("megatron.bridge.models.mamba.mamba_provider.calculate_padded_vocab_size", return_value=1024):
453+
with patch("megatron.bridge.models.mamba.mamba_provider.MCoreMambaModel") as mock_mamba:
454+
mock_mamba.return_value = Mock()
455+
provider.provide(pre_process=True, post_process=True)
456+
457+
assert mock_mamba.call_args.kwargs["position_embedding_type"] == "yarn"

0 commit comments

Comments
 (0)