Commit 73df1bb

Add more tests cases
1 parent 9ab73ee commit 73df1bb

2 files changed

Lines changed: 165 additions & 12 deletions

bionemo-recipes/recipes/evo2_megatron/README.md

Lines changed: 131 additions & 0 deletions
@@ -266,6 +266,137 @@ Options:
- `--mixed-precision-recipe` — precision recipe (default: `bf16_mixed`). NOTE: for checkpoints sensitive to FP8 and Hopper, run prediction/inference with `--mixed-precision-recipe bf16_mixed` and also supply the `--vortex-style-fp8` option. Do not use the fp8 recipe for those models, because they are sensitive to the exact FP8 configuration they were trained with in savanna; see the [table under the section on available nvidia checkpoints for download from NGC](#available-models-in-ngc-currently-nemo-format-so-first-convert-to-mbridge).
- `--verbose` / `-v` — enable debug logging.

## LoRA Fine-tuning

`Evo2LoRA` is a LoRA variant built on top of the Megatron Bridge PEFT stack. It
freezes the entire base model and attaches low-rank adapter matrices to the
modules you specify, with an optional escape hatch to keep selected modules
fully trainable.

### Basic usage

Add `--lora-finetune` to any `train_evo2` command alongside a checkpoint:

```bash
torchrun --nproc-per-node 2 --no-python \
  train_evo2 \
  --hf-tokenizer-model-path tokenizers/nucleotide_fast_tokenizer_512 \
  --model-size evo2_1b_base --max-steps 500 --eval-interval 100 \
  --eval-iters 3 --mock-data \
  --micro-batch-size 4 --global-batch-size 8 --seq-length 1024 \
  --mixed-precision-recipe bf16_mixed \
  --result-dir lora_run \
  --finetune-ckpt-dir $CKPT_OUT_DIR \
  --lora-finetune \
  --lora-dim 16 \
  --lora-alpha 32 \
  --lora-dropout 0.1 \
  --lora-target-modules "dense_projection,linear_qkv,linear_proj,linear_fc1,linear_fc2"
```

### LoRA configuration flags

| Flag                         | Default    | Description                                                                                  |
| ---------------------------- | ---------- | -------------------------------------------------------------------------------------------- |
| `--lora-finetune`            | *(absent)* | Presence flag. Pass to enable LoRA fine-tuning; omit for standard fine-tuning.               |
| `--lora-dim`                 | `16`       | Rank `r` of the low-rank decomposition                                                       |
| `--lora-alpha`               | `32`       | Scaling factor α; effective scale = α/r                                                      |
| `--lora-dropout`             | `0.1`      | Dropout applied to the LoRA path                                                             |
| `--lora-target-modules`      | see below  | Comma-separated list of module short-names to attach LoRA adapters to                        |
| `--lora-skip-freeze-modules` | `""`       | Comma-separated list of module short-names to leave **fully trainable** (no LoRA, no freeze) |

**Default `--lora-target-modules`:** `dense_projection,dense,linear_qkv,linear_proj,linear_fc1,linear_fc2`

These cover the dense projection inside each Hyena mixer (`dense_projection`,
`dense`) and the four standard transformer MLP/attention projections
(`linear_qkv`, `linear_proj`, `linear_fc1`, `linear_fc2`).
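
As a quick check of the "effective scale = α/r" entry in the table, the arithmetic for the default flags can be worked through in a few lines. This is a sketch that assumes the standard LoRA formulation, in which the low-rank update `B @ A` is multiplied by α/r before being added to the frozen weight. The helper names and the 1024 x 4096 layer shape are hypothetical and are not taken from `Evo2LoRA`.

```python
def lora_scale(alpha: float, rank: int) -> float:
    """Effective multiplier on the low-rank update: alpha / r."""
    if rank <= 0:
        raise ValueError("rank must be positive")
    return alpha / rank


def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    """Trainable parameters added by one adapter: A is (r x in), B is (out x r)."""
    return rank * in_features + out_features * rank


# Defaults from the flag table: --lora-dim 16, --lora-alpha 32.
scale = lora_scale(alpha=32, rank=16)

# Hypothetical 1024 -> 4096 projection, chosen only for illustration.
adapter_params = lora_param_count(1024, 4096, rank=16)
full_params = 1024 * 4096

print(f"scale={scale}, adapter={adapter_params}, full={full_params}")
```

With the defaults the update is scaled by 2.0. Under this convention, doubling `--lora-dim` without also doubling `--lora-alpha` halves the effective scale.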

### Module name matching

Both `--lora-target-modules` and `--lora-skip-freeze-modules` use the same
two-level matching syntax:

- **Short name** — matches any module whose immediate attribute name equals the
  pattern, regardless of depth (e.g. `"mixer"` matches
  `model.layers.3.mixer`).
- **Wildcard path** — if the pattern contains `*`, it is matched against the
  full dotted path using `*` as a substring wildcard (e.g.
  `"*.layers.0.*.mixer"` matches only layer 0).

A module that matches `--lora-target-modules` will have its base weights frozen
and LoRA adapter matrices attached. A module that matches
`--lora-skip-freeze-modules` is left entirely unfrozen — its full weight is
trainable — and no LoRA adapter is applied. If a module matches **both** lists,
`Evo2LoRA` raises a `ValueError` at startup.
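
The two matching levels and the overlap check described above can be sketched in plain Python. This is an illustration of the documented rules and not the `Evo2LoRA` implementation: `matches`, `check_no_overlap`, and the module paths are hypothetical, and treating `*` as "any run of characters, dots included" is an assumption about what "substring wildcard" means.

```python
import re


def matches(pattern: str, dotted_path: str) -> bool:
    """Wildcard patterns test the full dotted path; others test the last component."""
    if "*" in pattern:
        # Each '*' becomes "any run of characters"; everything else is literal.
        regex = ".*".join(re.escape(part) for part in pattern.split("*"))
        return re.fullmatch(regex, dotted_path) is not None
    return dotted_path.rsplit(".", 1)[-1] == pattern


def check_no_overlap(paths, target_patterns, skip_freeze_patterns):
    """Raise ValueError if any module path matches both pattern lists."""
    for path in paths:
        is_target = any(matches(p, path) for p in target_patterns)
        is_skipped = any(matches(p, path) for p in skip_freeze_patterns)
        if is_target and is_skipped:
            raise ValueError(f"{path!r} matches both target and skip-freeze patterns")


# Hypothetical module paths, used only for illustration.
paths = ["model.layers.0.mixer", "model.layers.3.mixer", "model.layers.3.mlp.linear_fc2"]

short_name_hits = [p for p in paths if matches("mixer", p)]
wildcard_hits = [p for p in paths if matches("*.layers.0.*", p)]
print(short_name_hits)
print(wildcard_hits)
```

Under these assumptions a bare `linear_*` is a wildcard pattern and is therefore compared against the whole path, so it does not match `mlp.linear_fc2`, while `*.linear_*` does.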

### Weight tying and shared embeddings

Evo2 models default to `share_embeddings_and_output_weights=True`. Under this
setting, the vocabulary embedding table and the output projection **share the
same weight tensor**: `embedding.word_embeddings.weight` owns the data and
`output_layer` allocates no weight of its own (`output_layer.weight is None`).
The output layer receives the embedding weight as a runtime argument during the
forward pass.

This has direct consequences when you try to apply LoRA or control freezing on
these layers.

**Design principle:** `Evo2LoRA` treats weight tying as a contract that must be
honoured in full. Any LoRA configuration that would apply adapters or change the
trainability of only one side of a tied pair is rejected with an error rather
than silently producing asymmetric behaviour. If you genuinely need to treat the
embedding and output projection as independent modules — for example to apply
LoRA to one but not the other — you must first opt out of weight tying by
setting `share_embeddings_and_output_weights=False` in the model config. Making
the intent explicit at the model level prevents hard-to-diagnose inconsistencies
during training and checkpoint export.

#### `--lora-target-modules` and weight tying

| `share_embeddings_and_output_weights` | `--lora-target-modules` includes | Behavior |
| :-----------------------------------: | -------------------------------- | -------- |
| `False` | `word_embeddings` only | LoRA adapter on the embedding lookup. Output projection weight is independent and frozen by default. |
| `False` | `output_layer` only | LoRA adapter on the output projection. Embedding weight is independent and frozen by default. |
| `False` | both | Independent LoRA adapters on both layers. Both base weights are frozen. |
| `True` | `word_embeddings` only | **Error.** Applying LoRA to only one side of a tied pair breaks the weight-tying invariant. Both must be listed together. |
| `True` | `output_layer` only | **Error.** Applying LoRA to only one side of a tied pair breaks the weight-tying invariant. Both must be listed together. |
| `True` | both | **Not yet implemented.** Symmetric LoRA on a tied weight pair requires a transpose-view adapter mechanism (see note below). This combination is accepted as a design goal and will raise a `NotImplementedError` until it is implemented. |

> **Symmetric LoRA on tied weights (future work).** When both `word_embeddings`
> and `output_layer` are targeted with weight tying enabled, the correct
> approach is to apply a single LoRA decomposition to the shared weight and
> expose it symmetrically to both the embedding lookup and the output
> projection — analogous to HuggingFace PEFT's `ensure_weight_tying` mechanism,
> which shares the adapter parameters via transposed views. This is not yet
> implemented.

#### `--lora-skip-freeze-modules` and weight tying

| `share_embeddings_and_output_weights` | `--lora-skip-freeze-modules` includes | Behavior |
| :-----------------------------------: | ------------------------------------- | -------- |
| `False` | `word_embeddings` only | Embedding weight is fully trainable. Output projection is frozen unless also listed. |
| `False` | `output_layer` only | Output projection weight is fully trainable. Embedding is frozen unless also listed. |
| `False` | both | Both weights are fully trainable. |
| `True` | `word_embeddings` only | **Error.** Listing only one side of a tied pair breaks the weight-tying invariant. Both must be listed together. |
| `True` | `output_layer` only | **Error.** Listing only one side of a tied pair breaks the weight-tying invariant. Both must be listed together. |
| `True` | both | Accepted. The shared weight (owned by `word_embeddings`) is unfrozen, so both the embedding lookup and the output projection train via the same tensor. **Note:** because `output_layer` allocates no weight of its own, gradient flow through the output projection path back to the shared tensor is a TODO item and may not be fully wired in all pipeline-parallel configurations. |
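
The decision rules in the two tables above can be condensed into one validation function. The sketch below only restates the tables: `check_tied_config` and `TIED_PAIR` are hypothetical names, the real checks inside `Evo2LoRA` may be structured differently, and for brevity it compares exact short names rather than wildcard patterns.

```python
TIED_PAIR = frozenset({"word_embeddings", "output_layer"})


def check_tied_config(share_weights: bool, target_modules, skip_freeze_modules) -> None:
    """Reject configurations that treat only one side of a tied pair."""
    if not share_weights:
        return  # Untied: the two layers are independent, any combination is allowed.

    for flag, modules in (
        ("--lora-target-modules", target_modules),
        ("--lora-skip-freeze-modules", skip_freeze_modules),
    ):
        listed = TIED_PAIR & set(modules)
        if len(listed) == 1:
            (only,) = listed
            raise ValueError(f"{flag} lists {only!r} without its tied partner")

    # Both sides targeted with tying enabled: symmetric LoRA, documented as future work.
    if TIED_PAIR <= set(target_modules):
        raise NotImplementedError("symmetric LoRA on tied weights is not implemented")


# Accepted: LoRA on inner layers, shared vocabulary weight fully trainable.
check_tied_config(True, ["linear_qkv", "linear_proj"], ["word_embeddings", "output_layer"])
print("configuration accepted")
```

Calling it with `share_weights=True` and only `word_embeddings` in either list raises `ValueError`, which corresponds to the **Error** rows of both tables.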

#### Recommendations

- **Default (vocabulary weights frozen, LoRA on inner layers):** omit both
  embedding/output modules from both flags. The default `--lora-target-modules`
  does not touch either layer.
- **Fully fine-tune the shared vocabulary weight alongside LoRA on inner
  layers:** list **both** `word_embeddings` and `output_layer` in
  `--lora-skip-freeze-modules`.
  ```
  --lora-skip-freeze-modules "word_embeddings,output_layer"
  ```
- **Never list only one of the two tied layers in either flag when
  `share_embeddings_and_output_weights=True`** — the invariant is that tied
  weights are always treated as a unit, and any asymmetric configuration will
  raise an error.

## Exporting to Vortex format

Vortex is ARC Institute's inference format for Evo2 Hyena models, used by the

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/test_evo2_lora_1.py

Lines changed: 34 additions & 12 deletions
@@ -39,20 +39,28 @@
 # ---------------------------------------------------------------------------
 
 
+class _MLP(nn.Module):
+    def __init__(self, hidden: int, ffn: int):
+        super().__init__()
+        self.linear_fc1 = nn.Linear(hidden, ffn)
+        self.linear_fc2 = nn.Linear(ffn, hidden)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear_fc2(torch.relu(self.linear_fc1(x)))
+
+
 class _SmallModel(nn.Module):
-    """Tiny model with an embedding layer and two linear layers for LoRA targeting."""
+    """Tiny model with nested structure so wildcard patterns like ``*.linear_fc2`` work."""
 
     def __init__(self, vocab_size: int = 64, hidden: int = 32, ffn: int = 64):
         super().__init__()
         self.embedding = nn.ModuleDict({"word_embeddings": nn.Embedding(vocab_size, hidden)})
-        self.linear_fc1 = nn.Linear(hidden, ffn)
-        self.linear_fc2 = nn.Linear(ffn, hidden)
+        self.mlp = _MLP(hidden, ffn)
         self.output_proj = nn.Linear(hidden, vocab_size)
 
     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         h = self.embedding.word_embeddings(input_ids)
-        h = self.linear_fc2(torch.relu(self.linear_fc1(h)))
-        return self.output_proj(h)
+        return self.output_proj(self.mlp(h))
 
 
 class TestEvo2LoRAAdapterWiring:
@@ -83,11 +91,13 @@ def test_adapter_params_always_trainable(self):
         "target_modules, skip_freeze",
         [
             (["linear_fc1", "linear_fc2"], ["linear_fc2"]),
-            (["linear_fc1", "*fc2"], ["linear_fc2"]),
-            (["linear_*"], ["linear_fc2"]),
             (["linear_fc1"], ["*"]),
+            (["*.linear_fc2"], ["linear_fc2"]),
+            (["linear_fc2"], ["*.linear_fc2"]),
+            (["mlp.*"], ["linear_fc2"]),
+            (["mlp.*"], ["*.linear_*"]),
         ],
-        ids=["exact", "wildcard_target", "wildcard_target_glob", "wildcard_skip"],
+        ids=["exact", "star_skip", "dotstar_target", "dotstar_skip", "parent_glob_target", "both_wildcards"],
     )
     def test_errors_on_target_skip_freeze_overlap(self, target_modules, skip_freeze):
         """Evo2LoRA must raise ValueError when target and skip-freeze patterns overlap."""
@@ -105,7 +115,7 @@ def test_errors_on_target_skip_freeze_overlap(self, target_modules, skip_freeze)
         "target_modules, skip_freeze",
         [
             (["linear_fc1", "linear_fc2"], ["word_embeddings"]),
-            (["linear_*"], ["do_not_exist"]),
+            (["*.linear_*"], ["do_not_exist"]),
             (["do_not_exist"], ["*"]),
         ],
         ids=["disjoint", "glob_target_no_skip_match", "no_target_match_star_skip"],
@@ -126,8 +136,6 @@ def test_no_error_when_skip_freeze_disjoint_from_targets(self, target_modules, s
 # Integration tests: pretrain() with LoRA + skip_freeze → checkpoint → verify
 # ---------------------------------------------------------------------------
 
-torch._dynamo.config.suppress_errors = True
-
 
 @dataclass
 class _TinyHyenaProvider(Hyena1bModelProvider):
@@ -254,8 +262,21 @@ def _load_dist_checkpoint_tensors(ckpt_dir: Path, keys: list[str]) -> dict[str,
     return state_dict
 
 
+@pytest.fixture(scope="module")
+def _suppress_dynamo_errors():
+    """Suppress torch.compile errors for integration tests (broken Triton env).
+
+    Restores the original value when the module's tests are done so other
+    test modules in the same process are unaffected.
+    """
+    old = torch._dynamo.config.suppress_errors
+    torch._dynamo.config.suppress_errors = True
+    yield
+    torch._dynamo.config.suppress_errors = old
+
+
 @pytest.fixture(scope="class")
-def base_ckpt(tmp_path_factory) -> Path:
+def base_ckpt(tmp_path_factory, _suppress_dynamo_errors) -> Path:
     """Pretrain a base model once for the entire integration test class."""
     base_dir = tmp_path_factory.mktemp("base")
     return _pretrain_base_model(base_dir)
@@ -264,6 +285,7 @@ def base_ckpt(tmp_path_factory) -> Path:
 @pytest.mark.timeout(300)
 @pytest.mark.slow
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+@pytest.mark.usefixtures("_suppress_dynamo_errors")
 class TestEvo2LoRAPretrainIntegration:
     """End-to-end: pretrain() with LoRA + skip_freeze → checkpoint → verify → resume.
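
The `_suppress_dynamo_errors` fixture in this diff replaces a module-level assignment with a save, override, restore sequence. The same pattern can be written as a standalone context manager, sketched below. `override_attr` and `fake_config` are hypothetical names; `fake_config` is a stand-in for `torch._dynamo.config` so that the example runs without PyTorch.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def override_attr(obj, name, value):
    """Temporarily set obj.<name> to value, restoring the old value on exit."""
    old = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        # Runs even if the body raises, so the override never leaks.
        setattr(obj, name, old)


# Stand-in for torch._dynamo.config (assumption made for this sketch only).
fake_config = SimpleNamespace(suppress_errors=False)

with override_attr(fake_config, "suppress_errors", True):
    inside = fake_config.suppress_errors
after = fake_config.suppress_errors

print(inside, after)
```

The standard library's `unittest.mock.patch.object` offers the same save and restore behaviour and may be a reasonable alternative where a hand-written helper is not wanted.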
