@@ -527,3 +527,115 @@ def test_default_parameters_forwarded(self, mock_unimodal, mock_compose):
527527 assert args [8 ] is False # data_parallel_random_init
528528 assert args [9 ] is Float16Module # mixed_precision_wrapper
529529 assert args [11 ] is ModelType .encoder_or_decoder # model_type
530+
531+
532+ # =============================================================================
533+ # Section 4 — YaRN positional embedding support
534+ # =============================================================================
535+
536+
537+ class TestMambaModelConfigYarnDefaults :
538+ """Tests that MambaModelConfig exposes the expected YaRN field defaults."""
539+
540+ def test_yarn_field_defaults (self ):
541+ config = _make_mamba_config ()
542+ assert config .yarn_rotary_scaling_factor == 8.0
543+ assert config .yarn_original_max_position_embeddings is None
544+ assert config .yarn_beta_fast == 32.0
545+ assert config .yarn_beta_slow == 1.0
546+ assert config .yarn_mscale == 1.0
547+ assert config .yarn_mscale_all_dim == 0.0
548+ assert config .yarn_correction_range_round_to_int is True
549+
550+ def test_yarn_position_embedding_type_accepted (self ):
551+ config = _make_mamba_config (position_embedding_type = "yarn" )
552+ assert config .position_embedding_type == "yarn"
553+
554+ def test_yarn_fields_settable (self ):
555+ config = _make_mamba_config (
556+ position_embedding_type = "yarn" ,
557+ yarn_rotary_scaling_factor = 4.0 ,
558+ yarn_original_max_position_embeddings = 1024 ,
559+ yarn_beta_fast = 16.0 ,
560+ yarn_beta_slow = 0.5 ,
561+ yarn_mscale = 0.8 ,
562+ yarn_mscale_all_dim = 1.0 ,
563+ yarn_correction_range_round_to_int = False ,
564+ )
565+ assert config .yarn_rotary_scaling_factor == 4.0
566+ assert config .yarn_original_max_position_embeddings == 1024
567+ assert config .yarn_beta_fast == 16.0
568+ assert config .yarn_beta_slow == 0.5
569+ assert config .yarn_mscale == 0.8
570+ assert config .yarn_mscale_all_dim == 1.0
571+ assert config .yarn_correction_range_round_to_int is False
572+
573+
574+ class TestMambaModelBuilderBuildModelWithYarn :
575+ """Tests for YaRN attribute injection in MambaModelBuilder.build_model()."""
576+
577+ def setup_method (self ):
578+ self .config = _make_mamba_config (
579+ vocab_size = 32000 ,
580+ seq_length = 4096 ,
581+ position_embedding_type = "yarn" ,
582+ yarn_rotary_scaling_factor = 8.0 ,
583+ )
584+ self .builder = MambaModelBuilder (self .config )
585+ self .pg = Mock ()
586+ self .pg .pp = Mock ()
587+
588+ @patch ("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size" )
589+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage" , return_value = True )
590+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage" , return_value = True )
591+ @patch ("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel" )
592+ def test_yarn_attrs_injected_onto_transformer (self , mock_model , * _ ):
593+ """All YaRN attrs must be set on the embedded TransformerConfig before MCoreMambaModel is called."""
594+ self .builder .build_model (self .pg , pre_process = True , post_process = True )
595+ t = self .config .transformer
596+ assert t .yarn_rotary_scaling_factor == 8.0
597+ assert t .yarn_beta_fast == 32.0
598+ assert t .yarn_beta_slow == 1.0
599+ assert t .yarn_mscale == 1.0
600+ assert t .yarn_mscale_all_dim == 0.0
601+ assert t .yarn_correction_range_round_to_int is True
602+
603+ @patch ("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size" )
604+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage" , return_value = True )
605+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage" , return_value = True )
606+ @patch ("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel" )
607+ def test_yarn_original_max_defaulted_from_seq_length (self , mock_model , * _ ):
608+ """None yarn_original_max_position_embeddings defaults to seq_length / scaling_factor."""
609+ assert self .config .yarn_original_max_position_embeddings is None
610+ self .builder .build_model (self .pg , pre_process = True , post_process = True )
611+ expected = int (self .config .seq_length / self .config .yarn_rotary_scaling_factor )
612+ assert self .config .transformer .yarn_original_max_position_embeddings == expected
613+
614+ @patch ("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size" )
615+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage" , return_value = True )
616+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage" , return_value = True )
617+ @patch ("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel" )
618+ def test_yarn_original_max_explicit_value_preserved (self , mock_model , * _ ):
619+ """An explicit yarn_original_max_position_embeddings is passed through unchanged."""
620+ self .config .__dict__ ["yarn_original_max_position_embeddings" ] = 512
621+ self .builder .build_model (self .pg , pre_process = True , post_process = True )
622+ assert self .config .transformer .yarn_original_max_position_embeddings == 512
623+
624+ @patch ("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size" )
625+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage" , return_value = True )
626+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage" , return_value = True )
627+ @patch ("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel" )
628+ def test_no_yarn_injection_for_rope (self , mock_model , * _ ):
629+ """YaRN attrs must NOT be injected when position_embedding_type is 'rope'."""
630+ config = _make_mamba_config (vocab_size = 32000 , position_embedding_type = "rope" )
631+ MambaModelBuilder (config ).build_model (self .pg , pre_process = True , post_process = True )
632+ assert not hasattr (config .transformer , "yarn_rotary_scaling_factor" )
633+
634+ @patch ("megatron.bridge.models.mamba.mamba_builder.calculate_padded_vocab_size" )
635+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_last_stage" , return_value = True )
636+ @patch ("megatron.bridge.models.mamba.mamba_builder.is_pp_first_stage" , return_value = True )
637+ @patch ("megatron.bridge.models.mamba.mamba_builder.MCoreMambaModel" )
638+ def test_position_embedding_type_yarn_forwarded_to_mcore (self , mock_model , * _ ):
639+ """position_embedding_type='yarn' must be forwarded to MCoreMambaModel."""
640+ self .builder .build_model (self .pg , pre_process = True , post_process = True )
641+ assert mock_model .call_args .kwargs ["position_embedding_type" ] == "yarn"
0 commit comments