
Commit 1263f64

[BIO-303] support LoRA in evo2 mbridge (#1550)
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent e66c666 commit 1263f64

12 files changed

Lines changed: 1701 additions & 146 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/README.md

Lines changed: 163 additions & 0 deletions
@@ -266,6 +266,169 @@ Options:
- `--mixed-precision-recipe`: precision recipe (default: `bf16_mixed`). NOTE: for checkpoints sensitive to FP8 and Hopper, run prediction/inference with `--mixed-precision-recipe bf16_mixed` and also supply the `--vortex-style-fp8` option. Do not use the fp8 recipe for those models, because they are sensitive to the exact FP8 configuration they were trained with in savanna. See the [table under the section on available nvidia checkpoints for download from NGC](#available-models-in-ngc-currently-nemo-format-so-first-convert-to-mbridge).
- `--verbose` / `-v`: enable debug logging.

## LoRA Fine-tuning

`Evo2LoRA` is a LoRA variant built on top of the Megatron Bridge PEFT stack. It
freezes the entire base model and attaches low-rank adapter matrices to the
modules you specify, with an optional escape hatch to keep selected modules
fully trainable.

### Basic usage

Add `--lora-finetune` to any `train_evo2` command that also supplies a base
checkpoint (here via `--finetune-ckpt-dir`):

```bash
torchrun --nproc-per-node 2 --no-python \
  train_evo2 \
  --hf-tokenizer-model-path tokenizers/nucleotide_fast_tokenizer_512 \
  --model-size evo2_1b_base --max-steps 500 --eval-interval 100 \
  --eval-iters 3 --mock-data \
  --micro-batch-size 4 --global-batch-size 8 --seq-length 1024 \
  --mixed-precision-recipe bf16_mixed \
  --result-dir lora_run \
  --finetune-ckpt-dir $CKPT_OUT_DIR \
  --lora-finetune \
  --lora-dim 16 \
  --lora-alpha 32 \
  --lora-dropout 0.1 \
  --lora-target-modules "dense_projection,linear_qkv,linear_proj,linear_fc1,linear_fc2"
```

### LoRA configuration flags

| Flag                         | Default    | Description                                                                                  |
| ---------------------------- | ---------- | -------------------------------------------------------------------------------------------- |
| `--lora-finetune`            | *(absent)* | Presence flag. Pass to enable LoRA fine-tuning; omit for standard fine-tuning.               |
| `--lora-dim`                 | `16`       | Rank `r` of the low-rank decomposition                                                       |
| `--lora-alpha`               | `32`       | Scaling factor α; effective scale = α/r                                                      |
| `--lora-dropout`             | `0.1`      | Dropout applied to the LoRA path                                                             |
| `--lora-target-modules`      | see below  | Comma-separated list of module short-names to attach LoRA adapters to                        |
| `--lora-skip-freeze-modules` | `""`       | Comma-separated list of module short-names to leave **fully trainable** (no LoRA, no freeze) |

**Default `--lora-target-modules`:** `dense_projection,dense,linear_qkv,linear_proj,linear_fc1,linear_fc2`

These cover the dense projection inside each Hyena mixer (`dense_projection`,
`dense`) and the four standard transformer MLP/attention projections
(`linear_qkv`, `linear_proj`, `linear_fc1`, `linear_fc2`).

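
To make the "effective scale = α/r" entry in the flags table concrete, here is a minimal pure-Python sketch of a LoRA forward pass, `W x + (α/r) · B (A x)`. It is an illustration only, not the Megatron Bridge implementation: the helper names are hypothetical, dropout on the LoRA path is omitted, and the zero-initialised `B` follows the usual LoRA convention, which is assumed rather than confirmed for this code base.

```python
def matvec(matrix, vector):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(x, base_w, lora_a, lora_b, alpha):
    """Return W x + (alpha / r) * B (A x), where r is the number of rows of A."""
    rank = len(lora_a)
    scale = alpha / rank
    base_out = matvec(base_w, x)
    low_rank_out = matvec(lora_b, matvec(lora_a, x))
    return [b + scale * l for b, l in zip(base_out, low_rank_out)]


# Toy sizes: 2 inputs, 2 outputs, rank 1 (the CLI default rank is 16).
x = [1.0, 2.0]
base_w = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight W
lora_a = [[0.5, 0.5]]              # A has shape r x in
lora_b_zero = [[0.0], [0.0]]       # B has shape out x r; assumed zero at init
lora_b_trained = [[1.0], [-1.0]]   # hypothetical values after some training

# With the CLI defaults (--lora-alpha 32, --lora-dim 16) the scale is 2.0.
default_scale = 32 / 16

untrained = lora_forward(x, base_w, lora_a, lora_b_zero, alpha=2.0)
trained = lora_forward(x, base_w, lora_a, lora_b_trained, alpha=2.0)
```

With `B` at zero the adapter contributes nothing, so `untrained` equals the base output `W x`; only once `B` moves away from zero does the scaled low-rank term change the result.
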
### Module name matching

Both `--lora-target-modules` and `--lora-skip-freeze-modules` use the same
two-level matching syntax:

- **Short name**: matches any module whose immediate attribute name equals the
  pattern, regardless of depth (e.g. `"mixer"` matches
  `model.layers.3.mixer`).
- **Wildcard path**: if the pattern contains `*`, it is matched against the
  full dotted path using `*` as a substring wildcard (e.g.
  `"*.layers.0.*.mixer"` matches only layer 0).

A module that matches `--lora-target-modules` will have its base weights frozen
and LoRA adapter matrices attached. A module that matches
`--lora-skip-freeze-modules` is left entirely unfrozen (its full weight is
trainable) and no LoRA adapter is applied. If a module matches **both** lists,
`Evo2LoRA` raises a `ValueError` at startup.

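
The matching rules above can be sketched in a few lines of Python. This is an approximation written from the description in this section, not the `Evo2LoRA` source: the function names `matches` and `classify` are hypothetical, and the module paths used with them are made up for illustration.

```python
import re


def matches(pattern, dotted_path):
    """Apply the two-level rule: wildcard path if '*' is present, else short name."""
    if "*" in pattern:
        # Treat '*' as "any run of characters"; every other character is literal.
        regex = ".*".join(re.escape(part) for part in pattern.split("*"))
        return re.fullmatch(regex, dotted_path) is not None
    # Short name: compare against the last component of the dotted path.
    return dotted_path.rsplit(".", 1)[-1] == pattern


def classify(dotted_path, target_patterns, skip_freeze_patterns):
    """Return 'lora', 'trainable', or 'frozen' for one module path."""
    is_target = any(matches(p, dotted_path) for p in target_patterns)
    is_skip = any(matches(p, dotted_path) for p in skip_freeze_patterns)
    if is_target and is_skip:
        raise ValueError(f"{dotted_path} matches both module lists")
    if is_target:
        return "lora"       # base weight frozen, adapter attached
    if is_skip:
        return "trainable"  # full weight trainable, no adapter
    return "frozen"         # everything else in the base model stays frozen
```

For example, `classify("model.layers.3.mixer", ["mixer"], [])` returns `"lora"`, while passing `"mixer"` in both lists raises the `ValueError`.
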
### Weight tying and shared embeddings

Evo2 models default to `share_embeddings_and_output_weights=True`. Under this
setting, the vocabulary embedding table and the output projection **share the
same weight tensor**: `embedding.word_embeddings.weight` owns the data and
`output_layer` allocates no weight of its own (`output_layer.weight is None`).
The output layer receives the embedding weight as a runtime argument during the
forward pass.

This has direct consequences when you try to apply LoRA or control freezing on
these layers.

**Constraint on `--lora-target-modules`:** `word_embeddings` is a
`VocabParallelEmbedding` and does not support LoRA adapters in Megatron Bridge.
Including it in `--lora-target-modules` always raises a `ValueError`, regardless
of `share_embeddings_and_output_weights`. `output_layer` is a
`ColumnParallelLinear` and *does* support LoRA, but only when
`share_embeddings_and_output_weights=False`; when weight tying is enabled,
`output_layer.weight` is `None` and there is no independent weight tensor to
attach an adapter to.

**Design principle for `--lora-skip-freeze-modules`:** `Evo2LoRA` treats weight
tying as a contract that must be honoured in full. Any configuration that would
change the trainability of only one side of a tied pair is rejected with an error
rather than silently producing asymmetric behaviour.

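
The tied-weight arrangement described in this section can be pictured with two toy classes. This is a deliberately simplified, pure-Python sketch: `ToyEmbedding` and `ToyOutputLayer` are hypothetical stand-ins, not the Megatron classes, and they ignore tensor parallelism entirely.

```python
class ToyEmbedding:
    """Owns the shared [vocab, hidden] weight."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, token_ids):
        return [self.weight[t] for t in token_ids]


class ToyOutputLayer:
    """Allocates no weight when tied; the caller supplies one at forward time."""

    def __init__(self, weight=None):
        self.weight = weight

    def forward(self, hidden, weight=None):
        w = self.weight if weight is None else weight
        if w is None:
            raise ValueError("no weight allocated and none supplied at runtime")
        # logits[v] = dot(hidden, w[v]) for every vocabulary row v
        return [sum(h * wv for h, wv in zip(hidden, row)) for row in w]


embedding = ToyEmbedding([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # vocab 3, hidden 2
output_layer = ToyOutputLayer()  # tied: its own weight stays None

hidden = embedding.forward([2])[0]
logits = output_layer.forward(hidden, weight=embedding.weight)
```

Because `output_layer` holds no tensor of its own, there is nothing on that side to freeze, unfreeze, or wrap independently; any change to `embedding.weight` is immediately visible in the logits.
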
#### `--lora-target-modules` and weight tying

| `share_embeddings_and_output_weights` | `--lora-target-modules` includes                          | Behavior                                                                 |
| :-----------------------------------: | --------------------------------------------------------- | ------------------------------------------------------------------------ |
|                Either                 | `word_embeddings` (alone or combined with `output_layer`) | **Error.** `VocabParallelEmbedding` does not support LoRA adapters.      |
|                `True`                 | `output_layer` only                                       | **Error.** `output_layer.weight` is `None` when weight tying is enabled. |
|                `False`                | `output_layer` only                                       | Valid: LoRA adapter on the independent output projection.                |

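
The three rows of the `--lora-target-modules` table can be expressed as a small validation function. This is a sketch written from the table, not the check that `Evo2LoRA` actually runs: the function name is hypothetical, and it compares plain short names only (no wildcard handling).

```python
def check_lora_targets(target_modules, share_embeddings_and_output_weights):
    """Raise ValueError for the error rows of the table; return None otherwise."""
    targets = set(target_modules)
    if "word_embeddings" in targets:
        # Rejected whether or not weights are tied.
        raise ValueError(
            "word_embeddings is a VocabParallelEmbedding and does not support LoRA adapters"
        )
    if "output_layer" in targets and share_embeddings_and_output_weights:
        raise ValueError(
            "output_layer.weight is None when weight tying is enabled"
        )


# Valid: untied model with LoRA on the independent output projection.
check_lora_targets(["linear_qkv", "output_layer"], share_embeddings_and_output_weights=False)
```
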
#### `--lora-skip-freeze-modules` and weight tying

| `share_embeddings_and_output_weights` | `--lora-skip-freeze-modules` includes | Behavior |
| :-----------------------------------: | ------------------------------------- | -------- |
| `False` | `word_embeddings` only | Embedding weight is fully trainable. Output projection is frozen unless also listed. |
| `False` | `output_layer` only | Output projection weight is fully trainable. Embedding is frozen unless also listed. |
| `False` | both | Both weights are fully trainable. |
| `True` | `word_embeddings` only | **Error.** Listing only one side of a tied pair breaks the weight-tying invariant. Both must be listed together. |
| `True` | `output_layer` only | **Error.** Listing only one side of a tied pair breaks the weight-tying invariant. Both must be listed together. |
| `True` | both | Accepted. The shared weight (owned by `word_embeddings`) is unfrozen, so both the embedding lookup and the output projection train via the same tensor. **Note:** because `output_layer` allocates no weight of its own, gradient flow through the output projection path back to the shared tensor is a TODO item and may not be fully wired in all pipeline-parallel configurations. |

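
The "both or neither" rule in the `--lora-skip-freeze-modules` table reduces to a short check. As before, this is a sketch derived from the table rather than the real implementation: `check_skip_freeze` and `TIED_PAIR` are hypothetical names, and only plain short names are compared.

```python
TIED_PAIR = {"word_embeddings", "output_layer"}


def check_skip_freeze(skip_freeze_modules, share_embeddings_and_output_weights):
    """Return the vocabulary modules left fully trainable, or raise ValueError."""
    listed = TIED_PAIR & set(skip_freeze_modules)
    if share_embeddings_and_output_weights and len(listed) == 1:
        # Exactly one side of a tied pair was listed: asymmetric, so reject.
        missing = (TIED_PAIR - listed).pop()
        raise ValueError(
            f"weights are tied: also list {missing!r} in --lora-skip-freeze-modules"
        )
    return listed


# Tied model, both sides listed: accepted.
trainable = check_skip_freeze(["word_embeddings", "output_layer"], True)
```

With untied weights any subset is accepted, so `check_skip_freeze(["output_layer"], False)` returns `{"output_layer"}` and the embedding stays frozen.
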
#### Recommendations

- **Default (vocabulary weights frozen, LoRA on inner layers):** omit the
  embedding and output modules from both flags. The default `--lora-target-modules`
  does not touch either layer.
- **Apply LoRA to the output projection (untied models only):** list
  `output_layer` in `--lora-target-modules` and set
  `share_embeddings_and_output_weights=False` in the model config.
- **Fully fine-tune the vocabulary weight alongside LoRA on inner layers:**
  list **both** `word_embeddings` and `output_layer` in
  `--lora-skip-freeze-modules`.
  ```
  --lora-skip-freeze-modules "word_embeddings,output_layer"
  ```
- **Never put `word_embeddings` in `--lora-target-modules`:** `VocabParallelEmbedding`
  does not support LoRA adapters and will raise a `ValueError`.
- **Never list only one of the two tied layers in `--lora-skip-freeze-modules`
  when `share_embeddings_and_output_weights=True`:** the invariant is that tied
  weights are always treated as a unit, and any asymmetric configuration will
  raise an error.

### Running inference on a LoRA checkpoint

A LoRA training checkpoint contains only adapter tensors; the base model weights
are not duplicated. Point `--ckpt-dir` at the LoRA `iter_*` directory as usual:

```bash
torchrun --nproc_per_node 1 --no-python \
  infer_evo2 \
  --ckpt-dir </path/to/lora_run/checkpoints/> \
  --prompt "ATCGATCGATCGATCG" \
  --max-new-tokens 200
```

```bash
torchrun --nproc_per_node 1 --no-python \
  predict_evo2 \
  --fasta <path/to/fasta/sequences> \
  --ckpt-dir </path/to/lora_run/checkpoints/> \
  --output-dir ./predictions
```

When `infer_evo2` / `predict_evo2` detect a `peft` section in the checkpoint's
`run_config.yaml`, they:

1. load dense base weights from `checkpoint.pretrained_checkpoint` (the same
   value that was supplied during LoRA training),
2. apply the stored PEFT config (`run_config["peft"]`) to graft `LoRALinear`
   wrappers onto the base modules,
3. load only the adapter tensors from `--ckpt-dir`.

No merge step is required. The base checkpoint referenced by
`pretrained_checkpoint` must still exist on disk at the path recorded in
`run_config.yaml`.

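
The load order above, together with the requirement that the base checkpoint still exists, can be sketched as a small pre-flight helper. Everything here is illustrative: `plan_lora_load` and the step labels are hypothetical, the function assumes `run_config.yaml` has already been parsed into nested dictionaries with `checkpoint.pretrained_checkpoint` stored as a nested key, and the non-LoRA branch is an assumption about what happens when no `peft` section is present.

```python
import os


def plan_lora_load(run_config, ckpt_dir):
    """Return ordered (step, argument) pairs for loading a checkpoint."""
    peft_cfg = run_config.get("peft")
    if not peft_cfg:
        # Assumed behaviour for a checkpoint without a peft section.
        return [("load_full_weights", ckpt_dir)]

    base_dir = run_config.get("checkpoint", {}).get("pretrained_checkpoint")
    if not base_dir or not os.path.isdir(base_dir):
        # The adapter-only checkpoint is unusable without its base weights.
        raise FileNotFoundError(f"base checkpoint not found: {base_dir!r}")

    return [
        ("load_base_weights", base_dir),      # step 1: dense base weights
        ("apply_peft_config", peft_cfg),      # step 2: graft LoRA wrappers
        ("load_adapter_tensors", ckpt_dir),   # step 3: adapter tensors only
    ]
```

Running a check like this before launching `torchrun` surfaces a moved or deleted base checkpoint early, instead of partway through model construction.
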

## Exporting to Vortex format

Vortex is ARC Institute's inference format for Evo2 Hyena models, used by the
