Skip to content

Commit d4d8ee4

Browse files
committed
Fix test
1 parent 7b1d88d commit d4d8ee4

2 files changed

Lines changed: 33 additions & 12 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_lora.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,21 +157,39 @@ def _validate_tied_weight_config(self, model: ModelType) -> None:
157157
``output_layer`` share the same weight tensor. Both lists must treat the
158158
tied pair as a unit: list both layers or neither. Listing only one side
159159
raises ``ValueError``.
160+
161+
The check walks the actual model so that wildcard patterns (e.g.
162+
``"embedding.*"``) are evaluated against real module paths rather than
163+
synthetic names.
160164
"""
161165
if not self._get_is_tied(model):
162166
return
163167

164-
target_word_emb = self._matches_lora_target("word_embeddings", "word_embeddings")
165-
target_output = self._matches_lora_target("output_layer", "output_layer")
168+
targeted_short_names: set[str] = set()
169+
skip_frozen_short_names: set[str] = set()
170+
171+
def _collect(module: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module:
172+
full_name = f"{prefix}.{name}" if prefix else (name or "")
173+
short_name = name or ""
174+
if self._matches_lora_target(short_name, full_name):
175+
targeted_short_names.add(short_name)
176+
if self._matches_skip_freeze(short_name, full_name):
177+
skip_frozen_short_names.add(short_name)
178+
return module
179+
180+
self._walk_model(model, _collect)
181+
182+
target_word_emb = "word_embeddings" in targeted_short_names
183+
target_output = "output_layer" in targeted_short_names
166184
if target_word_emb != target_output:
167185
raise ValueError(
168186
"share_embeddings_and_output_weights is enabled: target_modules must "
169187
"include both word_embeddings and output_layer, or neither. "
170188
f"word_embeddings matched: {target_word_emb}, output_layer matched: {target_output}."
171189
)
172190

173-
skip_word_emb = self._matches_skip_freeze("word_embeddings", "word_embeddings")
174-
skip_output = self._matches_skip_freeze("output_layer", "output_layer")
191+
skip_word_emb = "word_embeddings" in skip_frozen_short_names
192+
skip_output = "output_layer" in skip_frozen_short_names
175193
if skip_word_emb != skip_output:
176194
raise ValueError(
177195
"share_embeddings_and_output_weights is enabled: skip_freeze_modules must "

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/test_evo2_lora_1.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -517,19 +517,22 @@ def test_lora_checkpoint_excludes_frozen_embeddings(self, tmp_path: Path, base_c
517517
assert len(adapter_keys) > 0, "Checkpoint should still contain LoRA adapter keys."
518518

519519
@pytest.mark.parametrize(
520-
"skip_module, expected_key_substr, lora_targets",
520+
"skip_freeze, expected_key_substr, lora_targets",
521521
[
522-
("word_embeddings", "word_embeddings", None),
523-
("final_norm", "final_norm", None),
524-
("dense", "mixer.dense.", None),
525-
("linear_fc2", "mlp.linear_fc2.", ["dense_projection", "linear_qkv", "linear_proj", "linear_fc1"]),
522+
# word_embeddings and output_layer share a weight tensor when tying is enabled;
523+
# both must appear in skip_freeze to satisfy the symmetry contract.
524+
(["word_embeddings", "output_layer"], "word_embeddings", None),
525+
(["final_norm"], "final_norm", None),
526+
(["dense"], "mixer.dense.", None),
527+
(["linear_fc2"], "mlp.linear_fc2.", ["dense_projection", "linear_qkv", "linear_proj", "linear_fc1"]),
526528
],
529+
ids=["word_embeddings", "final_norm", "dense", "linear_fc2"],
527530
)
528531
def test_lora_skip_freeze_saves_and_trains_module(
529532
self,
530533
tmp_path: Path,
531534
base_ckpt: Path,
532-
skip_module: str,
535+
skip_freeze: list[str],
533536
expected_key_substr: str,
534537
lora_targets: list[str] | None,
535538
):
@@ -538,12 +541,12 @@ def test_lora_skip_freeze_saves_and_trains_module(
538541

539542
from bionemo.evo2.models.evo2_provider import hyena_forward_step
540543

541-
lora_dir = tmp_path / f"lora_{skip_module}"
544+
lora_dir = tmp_path / f"lora_{skip_freeze[0]}"
542545
cfg = _build_pretrain_config(
543546
lora_dir,
544547
train_iters=1,
545548
lora=True,
546-
skip_freeze=[skip_module],
549+
skip_freeze=skip_freeze,
547550
lora_target_modules=lora_targets,
548551
pretrained_ckpt_dir=str(base_ckpt),
549552
)

0 commit comments

Comments
 (0)