Skip to content

Commit ec9b8d7

Browse files
committed
Fix test
1 parent 7b1d88d commit ec9b8d7

1 file changed

Lines changed: 11 additions & 8 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/test_evo2_lora_1.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -517,19 +517,22 @@ def test_lora_checkpoint_excludes_frozen_embeddings(self, tmp_path: Path, base_c
517517
assert len(adapter_keys) > 0, "Checkpoint should still contain LoRA adapter keys."
518518

519519
@pytest.mark.parametrize(
520-
"skip_module, expected_key_substr, lora_targets",
520+
"skip_freeze, expected_key_substr, lora_targets",
521521
[
522-
("word_embeddings", "word_embeddings", None),
523-
("final_norm", "final_norm", None),
524-
("dense", "mixer.dense.", None),
525-
("linear_fc2", "mlp.linear_fc2.", ["dense_projection", "linear_qkv", "linear_proj", "linear_fc1"]),
522+
# word_embeddings and output_layer share a weight tensor when tying is enabled;
523+
# both must appear in skip_freeze to satisfy the symmetry contract.
524+
(["word_embeddings", "output_layer"], "word_embeddings", None),
525+
(["final_norm"], "final_norm", None),
526+
(["dense"], "mixer.dense.", None),
527+
(["linear_fc2"], "mlp.linear_fc2.", ["dense_projection", "linear_qkv", "linear_proj", "linear_fc1"]),
526528
],
529+
ids=["word_embeddings", "final_norm", "dense", "linear_fc2"],
527530
)
528531
def test_lora_skip_freeze_saves_and_trains_module(
529532
self,
530533
tmp_path: Path,
531534
base_ckpt: Path,
532-
skip_module: str,
535+
skip_freeze: list[str],
533536
expected_key_substr: str,
534537
lora_targets: list[str] | None,
535538
):
@@ -538,12 +541,12 @@ def test_lora_skip_freeze_saves_and_trains_module(
538541

539542
from bionemo.evo2.models.evo2_provider import hyena_forward_step
540543

541-
lora_dir = tmp_path / f"lora_{skip_module}"
544+
lora_dir = tmp_path / f"lora_{skip_freeze[0]}"
542545
cfg = _build_pretrain_config(
543546
lora_dir,
544547
train_iters=1,
545548
lora=True,
546-
skip_freeze=[skip_module],
549+
skip_freeze=skip_freeze,
547550
lora_target_modules=lora_targets,
548551
pretrained_ckpt_dir=str(base_ckpt),
549552
)

0 commit comments

Comments
 (0)