Skip to content

Commit 76ddb69

Browse files
mnabiancoreyjadams
authored andcommitted
PhysicsNeMo PEFT - LoRA (#1691)
* peft initial commit * docstring enhancement * modify extension * formatting * minor bug fixes * address greptile comments * address review comments * address initialization issue * address docstring and other minor comments * address LinearLike comment * is_compatible check * update tests * update example, fix pickle issue in a test * make te optional * fix lint * add codeowner
1 parent a1777e7 commit 76ddb69

21 files changed

Lines changed: 2913 additions & 0 deletions

File tree

.github/CODEOWNERS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ physicsnemo/models/vfgn/ @mnabian
253253
physicsnemo/experimental/
254254
physicsnemo/experimental/datapipes/healda/ @pzharrington
255255
physicsnemo/experimental/models/globe/ @peterdsharpe
256+
physicsnemo/experimental/peft/ @mnabian
256257

257258
# ==============================================================================
258259
# EXAMPLES - Active Learning
@@ -282,6 +283,7 @@ examples/cfd/external_aerodynamics/figconvnet/ @coreyjadams
282283
examples/cfd/external_aerodynamics/globe/ @peterdsharpe
283284
examples/cfd/external_aerodynamics/moe/ @mnabian
284285
examples/cfd/external_aerodynamics/transformer_models/ @coreyjadams @RishikeshRanade
286+
examples/cfd/external_aerodynamics/transformer_models/src/finetune/ @mnabian
285287
examples/cfd/external_aerodynamics/unified_external_aero_recipe/ @coreyjadams @peterdsharpe
286288
examples/cfd/external_aerodynamics/xaeronet/ @mnabian
287289
examples/cfd/flow_reconstruction_diffusion/
@@ -419,6 +421,7 @@ test/optim/ @peterdsharpe
419421
test/diffusion/ @CharlelieLrt
420422
test/utils/
421423
test/experimental/
424+
test/experimental/peft/ @mnabian
422425

423426
# ==============================================================================
424427
# TESTS - CI

examples/cfd/external_aerodynamics/transformer_models/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ During training, the configuration uses a flat learning rate that decays every 1
5151

5252
The Optimizer for this training is the `Muon` optimizer - available only in `pytorch>=2.9.0`. While not strictly required, we have found the `muon` optimizer performs substantially better on these architectures than standard `AdamW` and a oneCycle schedule.
5353

54+
### Parameter-Efficient Fine-Tuning (LoRA)
55+
56+
To adapt a *pretrained* model to a new dataset cheaply — without retraining all weights — use the LoRA fine-tuning recipe in the [`src/finetune/`](src/finetune/) folder (`src/finetune/finetune.py` and `src/finetune/deploy.py`, with `src/conf/finetune_lora.yaml`). It freezes the base model and trains only small low-rank adapters, producing a compact adapter checkpoint that can be swapped at serve time or merged into the base. See [src/finetune/README.md](src/finetune/README.md) for the full workflow.
57+
5458
### Training Precision
5559

5660
These transformer architectures have support for NVIDIA's [TransformerEngine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html) built in. You can enable/disable the transformer engine path in the model with `model.use_te=[True | False]`. Available precisions for training with `transformer_engine` are `training.precision=["float32" | "float16" | "bfloat16" | "float8" ]`. In `float8` precision, the TransformerEngine Hybrid recipe is used for casting weights and inputs in the forward and backwards passes. For more details on `float8` precision, see the fp8 guide from [TransformerEngine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html). When using fp8, the training script will automatically pad and unpad the input and output, respectively, to use the fp8 hardware correctly.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# ---------------------------------------------------------------------------
18+
# LoRA fine-tuning config (used by src/finetune/finetune.py and src/finetune/deploy.py).
19+
# Reuses the same model/data/training groups as geotransolver_surface.yaml and
20+
# adds the `peft:` block + `init_from`. Keep the `model:` block matching the
21+
# architecture of the base checkpoint you load (the adapter records a base
22+
# fingerprint and load/deploy enforce it).
23+
# ---------------------------------------------------------------------------
24+
25+
defaults:
26+
- training: base
27+
- model: geotransolver
28+
- data: surface
29+
- _self_ # this file's overrides (model:/data:/training:/peft:) apply last
30+
31+
# Pretrained base checkpoint to fine-tune (a GeoTransolver `.mdlus`, e.g. the
32+
# NIM / multi-dataset checkpoint). REQUIRED.
33+
init_from: ???
34+
35+
output_dir: "runs"
36+
run_id: "geotransolver_lora_finetune"
37+
precision: float32
38+
compile: false
39+
40+
# Fine-tuning is short and uses a smaller LR than from-scratch training.
41+
training:
42+
num_epochs: 50
43+
save_interval: 10
44+
optimizer:
45+
lr: 5.0e-4
46+
47+
# Match geotransolver_surface.yaml's model/data so this composes with a
48+
# surface-trained base out of the box.
49+
model:
50+
functional_dim: 6
51+
include_local_features: true
52+
radii: [0.01, 0.05, 0.25, 1.0, 2.5, 5.0]
53+
neighbors_in_radius: [4, 8, 16, 64, 128, 256]
54+
n_hidden_local: 32
55+
56+
data:
57+
include_sdf: false
58+
include_geometry: true
59+
geometry_sampling: 300_000
60+
broadcast_global_features: false
61+
62+
# LoRA configuration. Default targets the GALE attention projections. Set
63+
# `wrap_mlp: true` to also adapt the feed-forward MLP; under Transformer Engine
64+
# that uses the fused te.LayerNormMLP residual adapter.
65+
peft:
66+
_target_: physicsnemo.experimental.peft.LoRAConfig
67+
rank: 16
68+
alpha: 16
69+
target_pattern: 'blocks\.\d+\.Attn\.(in_project_x|in_project_fx|qkv_project|out_linear|cross_[qkv])'
70+
wrap_mlp: false
71+
72+
# Deploy-only (src/finetune/deploy.py): fold the adapter into the base and save a plain .mdlus.
73+
merge: false
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# LoRA fine-tuning (GeoTransolver)
2+
3+
Parameter-efficient fine-tuning of a **pretrained GeoTransolver** (trained with
4+
`src/train.py`, or the NIM / multi-dataset checkpoint) on a small custom
5+
external-aerodynamics dataset, using `physicsnemo.experimental.peft`.
6+
7+
This recipe lives in its own `src/finetune/` folder, separate from the main
8+
training/inference scripts in `src/`. It is a companion to `src/train.py`: same
9+
model, same data pipeline, same `src/conf/` groups — only the entry points
10+
(`src/finetune/finetune.py`, `src/finetune/deploy.py`) and config
11+
(`src/conf/finetune_lora.yaml`) are new. `train.py` is unchanged.
12+
13+
## Why LoRA
14+
15+
- Small adapters (~hundreds of KB) vs full checkpoints (~tens of MB).
16+
- Lower memory (frozen base layers drop saved activations).
17+
- Less overfitting / forgetting in the small-data regime (α=0 = the base exactly).
18+
- One base + N swappable adapters at serve time.
19+
20+
## Workflow
21+
22+
```text
23+
src/finetune/finetune.py src/finetune/deploy.py
24+
base.mdlus ───────────────────▶ adapter.lora ────────────────────▶ serve (swap)
25+
(pretrained) apply_lora + train (~hundreds KB) load_adapter or merge_lora
26+
only the adapters → merged .mdlus
27+
```
28+
29+
1. **Fine-tune** (run from the example root, same as `train.py`):
30+
31+
```bash
32+
python src/finetune/finetune.py init_from=/path/to/base_geotransolver.mdlus
33+
# multi-GPU (single node):
34+
torchrun --nproc_per_node=8 src/finetune/finetune.py init_from=/path/to/base.mdlus
35+
```
36+
37+
2. **Deploy** — adapter-swap, or merge for zero overhead:
38+
39+
```bash
40+
python src/finetune/deploy.py init_from=/path/to/base.mdlus # adapter-swap
41+
python src/finetune/deploy.py init_from=/path/to/base.mdlus merge=true # fold in → *_merged.mdlus
42+
```
43+
44+
## Config (`src/conf/finetune_lora.yaml`)
45+
46+
- `init_from` (**required**): the pretrained base `.mdlus`. The `model:` block
47+
**must match its architecture**`load_adapter`/`deploy.py` enforce a base
48+
fingerprint and refuse a mismatched base.
49+
- `peft.target_pattern`: which layers get adapters (default = GALE attention
50+
projections). `peft.wrap_mlp: true` also adapts the feed-forward MLP.
51+
- `peft.rank` / `peft.alpha`: adapter capacity / scaling. `peft.init` optionally
52+
customizes the `lora_A` initialization (a name or a callable).
53+
- Point the `data` group at your small dataset (see `src/conf/data/{core,surface}.yaml`).
54+
55+
## How it differs from `train.py`
56+
57+
- Only LoRA (+`extras_trainable`) params train; the base is frozen.
58+
- Those params go to **AdamW**, never Muon (Newton-Schulz is degenerate on
59+
rank-`r` factors) — via `split_params_for_optimizer`.
60+
- DDP uses `find_unused_parameters=True` (frozen base params get no grad).
61+
- Multi-GPU shards the dataset per rank via `DistributedSampler` + `set_indices`
62+
(same as `train.py`); launch with `torchrun --nproc_per_node=<N>`.
63+
- **float32 only**: the minimal recipe does not wire the mixed/fp8 path (autocast,
64+
fp8 padding, GradScaler); it errors if `precision != float32`. Use `train.py`
65+
for fp8.
66+
- `finetune.py` keeps a minimal MSE loop for readability; reuse
67+
`train.forward_pass` if you want the full metrics/normalization path.
68+
- Deploy `merge=true` only writes a merged `.mdlus` if all adapters are
69+
mergeable; a fused `te.LayerNormMLP` residual (from `wrap_mlp` under TE) is
70+
left in place and you deploy via `load_adapter` instead.
71+
72+
The PEFT API used here is covered by `test/experimental/peft/`. A full
73+
data-driven run needs a base checkpoint, a dataset, and the PhysicsNeMo container.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Deploy a trained LoRA adapter (companion to ``finetune.py``).
18+
19+
Two modes:
20+
* Adapter-swap: keep the frozen base + small adapter, ``load_adapter`` at
21+
serve time (one base + N adapters, swappable per request).
22+
* Merge: fold the adapter into the base for zero inference overhead and save
23+
a plain ``.mdlus``. (Fused ``te.LayerNormMLP`` residual adapters are not
24+
mergeable and are left in place; deploy those via adapter-swap instead.)
25+
26+
Run from the example root::
27+
28+
python src/finetune/deploy.py init_from=<base.mdlus> # adapter-swap
29+
python src/finetune/deploy.py init_from=<base.mdlus> merge=true # fold in
30+
"""
31+
32+
import logging
33+
34+
import hydra
35+
from omegaconf import DictConfig
36+
37+
from physicsnemo.experimental.peft import is_lora_layer, load_adapter, merge_lora
38+
39+
logger = logging.getLogger("finetune_lora.deploy")
40+
41+
42+
@hydra.main(version_base=None, config_path="../conf", config_name="finetune_lora")
43+
def main(cfg: DictConfig) -> None:
44+
"""Load a trained adapter onto the base for serving (adapter-swap), optionally
45+
merging it into the base weights (``merge=true``) for zero-overhead inference."""
46+
logging.basicConfig(level=logging.INFO)
47+
48+
# Reconstruct the SAME base architecture, then load its pretrained weights.
49+
model = hydra.utils.instantiate(cfg.model, _convert_="partial")
50+
if cfg.get("init_from"):
51+
model.load(str(cfg.init_from))
52+
53+
adapter_path = f"{cfg.output_dir}/{cfg.run_id}.lora"
54+
# load_adapter verifies kind + base fingerprint, re-applies LoRA, loads weights.
55+
load_adapter(model, adapter_path)
56+
logger.info("loaded adapter %s onto base", adapter_path)
57+
58+
if cfg.get("merge", False):
59+
merge_lora(model) # fold mergeable adapters into base weights
60+
remaining = [n for n, m in model.named_modules() if is_lora_layer(m)]
61+
if remaining:
62+
# e.g. te.LayerNormMLP residuals are non-mergeable; saving now would
63+
# write wrapper-prefixed keys that won't reload as the base model.
64+
logger.warning(
65+
"merge requested but %d non-mergeable adapter(s) remain "
66+
"(e.g. te.LayerNormMLP residuals); NOT writing a merged "
67+
"checkpoint. Serve with the adapter via load_adapter instead.",
68+
len(remaining),
69+
)
70+
else:
71+
merged_path = f"{cfg.output_dir}/{cfg.run_id}_merged.mdlus"
72+
model.save(merged_path) # plain full-model .mdlus, no adapter overhead
73+
logger.info("merged and saved %s", merged_path)
74+
75+
model.eval()
76+
logger.info("model ready for inference")
77+
78+
79+
if __name__ == "__main__":
80+
main()

0 commit comments

Comments
 (0)