@@ -517,19 +517,22 @@ def test_lora_checkpoint_excludes_frozen_embeddings(self, tmp_path: Path, base_c
517517 assert len (adapter_keys ) > 0 , "Checkpoint should still contain LoRA adapter keys."
518518
519519 @pytest .mark .parametrize (
520- "skip_module , expected_key_substr, lora_targets" ,
520+ "skip_freeze , expected_key_substr, lora_targets" ,
521521 [
522- ("word_embeddings" , "word_embeddings" , None ),
523- ("final_norm" , "final_norm" , None ),
524- ("dense" , "mixer.dense." , None ),
525- ("linear_fc2" , "mlp.linear_fc2." , ["dense_projection" , "linear_qkv" , "linear_proj" , "linear_fc1" ]),
522+ # word_embeddings and output_layer share a weight tensor when tying is enabled;
523+ # both must appear in skip_freeze to satisfy the symmetry contract.
524+ (["word_embeddings" , "output_layer" ], "word_embeddings" , None ),
525+ (["final_norm" ], "final_norm" , None ),
526+ (["dense" ], "mixer.dense." , None ),
527+ (["linear_fc2" ], "mlp.linear_fc2." , ["dense_projection" , "linear_qkv" , "linear_proj" , "linear_fc1" ]),
526528 ],
529+ ids = ["word_embeddings" , "final_norm" , "dense" , "linear_fc2" ],
527530 )
528531 def test_lora_skip_freeze_saves_and_trains_module (
529532 self ,
530533 tmp_path : Path ,
531534 base_ckpt : Path ,
532- skip_module : str ,
535+ skip_freeze : list [ str ] ,
533536 expected_key_substr : str ,
534537 lora_targets : list [str ] | None ,
535538 ):
@@ -538,12 +541,12 @@ def test_lora_skip_freeze_saves_and_trains_module(
538541
539542 from bionemo .evo2 .models .evo2_provider import hyena_forward_step
540543
541- lora_dir = tmp_path / f"lora_{ skip_module } "
544+ lora_dir = tmp_path / f"lora_{ skip_freeze [ 0 ] } "
542545 cfg = _build_pretrain_config (
543546 lora_dir ,
544547 train_iters = 1 ,
545548 lora = True ,
546- skip_freeze = [ skip_module ] ,
549+ skip_freeze = skip_freeze ,
547550 lora_target_modules = lora_targets ,
548551 pretrained_ckpt_dir = str (base_ckpt ),
549552 )
0 commit comments