Skip to content

Commit bad33a2

Browse files
committed
Add codonfm 5b benchmark and update recipe
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent 0d26cdf commit bad33a2

30 files changed

Lines changed: 4985 additions & 594 deletions

bionemo-recipes/recipes/codonfm_ptl_te/Dockerfile

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,18 @@ RUN chown -R ${USERNAME:-vscode}:${USERNAME:-vscode} /workspace/codonfm
6464

6565
# Switch to the non-root user
6666
USER $USERNAME
67+
68+
# ----------------- For benchmarking only -----------------
69+
# Warning: I was only able to build this image in an instance with 2TB of memory.
70+
# Otherwise, a segmentation fault occurs during the build process:
71+
# /bin/bash: line 1: 13517 Segmentation fault (core dumped) ptxas -arch=sm_90 -m64 -v --generate-line-info "/tmp/tmpxft_00002f58_00000000-6_flash_fwd_hdim64_256_fp16_paged_split_sm90.ptx" -o "/tmp/tmpxft_00002f58_00000000-8_flash_fwd_hdim64_256_fp16_paged_split_sm90.cubin" > /tmp/tmpxft_00002f58_00000000-10_2eb7d280_stdout 2> /tmp/tmpxft_00002f58_00000000-10_2eb7d280_stderr
72+
#
73+
# Could have also been caused by CUDA-compatibility issues.
74+
75+
FROM base AS benchmarking
76+
77+
WORKDIR /workspace/codonfm
78+
79+
RUN pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@v0.0.32.post2#egg=xformers --no-deps
80+
81+
COPY . .

bionemo-recipes/recipes/codonfm_ptl_te/README.md

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ The table below summarizes the set of open source pre-trained weights currently
2222
| EnCodon 1B | MLM (random p=0.15) | 2048 | 18 | 16 | 8192 | `mlm/encodon_1b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-1B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-1B-v1) |
2323
| EnCodon 1B (CDSWT) | MLM (codon frequency-weighted) | 2048 | 18 | 16 | 8192 | `cdswt/encodon_1b.sh` | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-Cdwt-1B-v1) | [Link](https://huggingface.co/nvidia/NV-CodonFM-Encodon-TE-Cdwt-1B-v1) |
2424

25+
> Note (May 2026): The EnCodon 5B model checkpoint will be released in the near future.
26+
2527
## Repository Structure
2628

2729
High-level overview (NerdTree-style):
@@ -75,12 +77,29 @@ We also present the ability to utilize a simpler model architecture that directl
7577

7678
<br>
7779

78-
The training step speedups for the 80M Encodon model when both Transformer Engine (TE) and Sequence Packing (THD) are applied compared to the Xformers based model are shown below. We benchmarked on NVIDIA H100 80GB HBM3 GPUs using a micro batch-size is 32. The training step speedups for the 1B Encodon model are on a micro batch-size of 4.
80+
The figure below shows training throughput speedups, derived from `tokens/s/gpu`, for the `80M` and `1B` Encodon models when Transformer Engine (TE) and sequence packing (THD) are applied relative to the Xformers-based baseline.
7981

8082
![xf](assets/images/training_acceleration_plot.png)
8183

82-
For inferencing, we can also demonstrate acceleration when using each models TE counterpart. Thus, a 1.4X speedup in this chart shows how much faster the TE version of the model is over the original baseline PyTorch SDPA model.
83-
![i](assets/images/inference_plot.png)
84+
All training experiments reported here were run on `8 x NVIDIA H100 80GB HBM3` GPUs in `bfloat16` precision. The absolute throughputs used to compute the speedups above are reported below in `tokens/s/gpu`.
85+
86+
| Model | Xformers (`tokens/s/gpu`) | SDPA (`tokens/s/gpu`) | TE-BSHD (`tokens/s/gpu`) | TE-THD (`tokens/s/gpu`) | Speedup over baseline |
87+
| ----- | ------------------------: | --------------------: | -----------------------: | ----------------------: | ----------------------------- |
88+
| 80M | 117119 | 145357 | 419087 | 1028891 | 1.00x / 1.24x / 3.58x / 9.79x |
89+
| 1B | 8698 | 9899 | 26476 | 69300 | 1.00x / 1.14x / 3.04x / 7.97x |
90+
| 5B | 2320 | 2865 | 5112 | 13973 | 1.00x / 1.23x / 2.20x / 6.02x |
91+
92+
For inference, we report both relative speedup and absolute throughput. The figure below compares inference configurations by relative speedup within each model size.
93+
94+
![Inference speedup across model sizes](assets/images/inference_plot.png)
95+
96+
All inference experiments reported here were run on `8 x NVIDIA H100 80GB HBM3` GPUs in `bfloat16` precision. The absolute throughputs used to compute the speedups above are reported below in `tokens/s/gpu`.
97+
98+
| Model | Xformers (`tokens/s/gpu`) | SDPA (`tokens/s/gpu`) | TE-BSHD (`tokens/s/gpu`) | TE-THD (`tokens/s/gpu`) | Speedup over baseline |
99+
| ----- | ------------------------: | --------------------: | -----------------------: | ----------------------: | ------------------------------ |
100+
| 80M | 156819 | 190380 | 542147 | 1875140 | 1.00x / 1.21x / 3.46x / 11.96x |
101+
| 1B | 18655 | 21715 | 46551 | 221110 | 1.00x / 1.16x / 2.50x / 11.85x |
102+
| 5B | 5316 | 5991 | 9996 | 40373 | 1.00x / 1.13x / 1.88x / 7.59x |
84103

85104
## Quickstart
86105

@@ -185,9 +204,11 @@ Optional path overrides:
185204
```bash
186205
--out_dir <dir>
187206
--checkpoints_dir <dir>
188-
--pretrained_ckpt_path <path>
189207
```
190208

209+
- `--out_dir`: Base output directory for logs, metrics, and other artifacts. Defaults to `results/`.
210+
- `--checkpoints_dir`: Directory where training checkpoints are saved. Defaults to `<out_dir>/checkpoints/`. This directory also enables **automatic resumption**: if the runner finds a `last.ckpt` file inside this directory, it will reload the model weights and full trainer state (optimizer, learning-rate schedule, global step, etc.) so training picks up exactly where it left off. This is essential for long pretraining runs on clusters where jobs may be preempted or interrupted. On a fresh run the directory will be empty, so training starts from scratch as expected.
211+
191212
For multi-node execution consider using `torchrun`.
192213

193214
```bash
@@ -255,6 +276,10 @@ python -m src.runner finetune \
255276

256277
```
257278

279+
- `--pretrained_ckpt_path`: Path to a pretrained checkpoint whose **model weights only** are loaded as the starting point for finetuning. The optimizer state, learning-rate schedule, and global step are not restored — training starts fresh from step 0 with the pretrained weights. Accepts a local `.ckpt` file, a local directory containing a `.safetensors` file and `config.json`, or a HuggingFace Hub repo ID (e.g. `nvidia/codon-fm-base`).
280+
- `--checkpoints_dir`: Directory where finetuning checkpoints are saved. Defaults to `<out_dir>/checkpoints/`. If the runner finds a `last.ckpt` here, it resumes the finetuning run (model weights, optimizer, step count) from that checkpoint instead of starting from the pretrained weights. This enables automatic resumption of interrupted finetuning jobs.
281+
- `--resume_trainer_state`: When set, restores the full trainer state (optimizer, scheduler, step count) from the pretrained checkpoint rather than only loading model weights. Useful when continuing a pretraining run as a finetuning job.
282+
258283
#### Evaluation
259284

260285
The publicly available checkpoints can be used to launch scientific evaluation and benchmarking.
-65.6 KB
Loading
-77.3 KB
Loading

bionemo-recipes/recipes/codonfm_ptl_te/codonfm_ckpt_te_conversion.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,39 @@
2626

2727
import argparse
2828
import logging
29+
import os
2930

3031
import torch
32+
from safetensors.torch import save_file as safetensors_save_file
3133

3234
from src.utils.load_checkpoint import load_checkpoint
3335

3436

3537
logger = logging.getLogger(__name__)
3638

39+
ALLOWED_HYPERPARAMETER_KEYS = (
40+
"vocab_size",
41+
"hidden_size",
42+
"num_hidden_layers",
43+
"num_attention_heads",
44+
"intermediate_size",
45+
"hidden_act",
46+
"hidden_dropout_prob",
47+
"attention_probs_dropout_prob",
48+
"initializer_range",
49+
"layer_norm_eps",
50+
"pad_token_id",
51+
"position_embedding_type",
52+
"classifier_dropout",
53+
"rotary_theta",
54+
"ignore_index",
55+
"loss_type",
56+
"lora",
57+
"lora_alpha",
58+
"lora_r",
59+
"lora_dropout",
60+
)
61+
3762
# PYTorch -> TE keymap
3863
PYTORCH_TO_TE_KEYMAP = {
3964
"model.layers.*.pre_attn_layer_norm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
@@ -300,6 +325,11 @@ def convert_state_dict(src: dict, keymap: dict):
300325
return dst_state_dict
301326

302327

328+
def filter_hyper_parameters(hyper_parameters: dict) -> dict:
329+
"""Keep only conversion-compatible hyperparameter keys."""
330+
return {key: value for key, value in hyper_parameters.items() if key in ALLOWED_HYPERPARAMETER_KEYS}
331+
332+
303333
def main():
304334
"""Main function."""
305335
logging.basicConfig(level=logging.INFO)
@@ -325,6 +355,7 @@ def main():
325355
# Load source checkpoint (automatically detects format)
326356
logger.info(f"Loading checkpoint from {args.src}")
327357
src_checkpoint = load_checkpoint(args.src, map_location="cpu")
358+
src_checkpoint["hyper_parameters"] = filter_hyper_parameters(src_checkpoint["hyper_parameters"])
328359

329360
# Perform conversion based on direction
330361
if args.direction == "pytorch2te":
@@ -341,11 +372,19 @@ def main():
341372
dst_state_dict = split_qkv(converted_state_dict, src_checkpoint["hyper_parameters"])
342373

343374
# Prepare final checkpoint
344-
dst_checkpoint = {"state_dict": dst_state_dict, "hyper_parameters": src_checkpoint["hyper_parameters"]}
375+
dst_checkpoint = {
376+
"state_dict": dst_state_dict,
377+
"hyper_parameters": src_checkpoint["hyper_parameters"],
378+
}
345379

346380
# Save the converted checkpoint in pickled format
347381
torch.save(dst_checkpoint, args.dst)
348-
logger.info(f"Successfully converted checkpoint from {args.src} to {args.dst}")
382+
logger.info(f"Successfully converted checkpoint saved to {args.dst}")
383+
384+
# Save the state_dict in safetensors format alongside the .ckpt file
385+
safetensors_path = os.path.splitext(args.dst)[0] + ".safetensors"
386+
safetensors_save_file(dst_state_dict, safetensors_path)
387+
logger.info(f"Successfully saved safetensors checkpoint to {safetensors_path}")
349388

350389

351390
if __name__ == "__main__":

bionemo-recipes/recipes/codonfm_ptl_te/data_scripts/check_codon_frequency.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616

1717
# %%
18+
import argparse
1819
import json
1920
import sys
2021
from pathlib import Path
@@ -23,41 +24,52 @@
2324
from tqdm import tqdm
2425

2526

26-
sys.path.append("/workspace/codon_fm")
27+
sys.path.append("/workspace/codonfm")
2728
from src.tokenizer import Tokenizer
2829

2930

30-
data_path = Path("/data/ncbi/processed_unfiltered")
31-
tax_ids_to_remove = json.load(open("/data/ncbi/taxids_to_remove.json"))
32-
metadata = json.load(open(data_path / "metadata.json"))
33-
tokenizer = Tokenizer()
34-
35-
36-
groups = set([x["file_name"][:-4] for x in metadata["file_metadata"]]) # noqa: C403
37-
counts = {g: np.zeros(tokenizer.vocab_size) for g in groups}
38-
for fm, cm in tqdm(zip(metadata["file_metadata"], metadata["chunks"]), total=len(metadata["file_metadata"])):
39-
group = fm["file_name"][:-4]
40-
if group in tax_ids_to_remove:
41-
curr_taxids_to_remove = set(tax_ids_to_remove[group])
42-
else:
43-
curr_taxids_to_remove = set()
44-
mmap = np.memmap(
45-
data_path / cm["sequences"]["path"],
46-
dtype=cm["sequences"]["dtype"],
47-
mode="r",
48-
shape=tuple(cm["sequences"]["shape"]),
49-
)
50-
idx_mmap = np.memmap(
51-
data_path / cm["index"]["path"], dtype=cm["index"]["dtype"], mode="r", shape=tuple(cm["index"]["shape"])
52-
)
53-
for start, end, taxid in idx_mmap:
54-
if taxid in curr_taxids_to_remove:
55-
continue
56-
seq = mmap[start:end]
57-
idx, count = np.unique(seq, return_counts=True)
58-
counts[group][idx] += count
31+
def main(pretraining_processed_data_dir: Path, data_dir: Path):
32+
"""Check codon frequency."""
33+
tax_ids_to_remove = json.load(open(data_dir / Path("taxids_to_remove.json")))
34+
metadata = json.load(open(pretraining_processed_data_dir / "metadata.json"))
35+
tokenizer = Tokenizer()
5936

60-
# %%
61-
for g in counts:
62-
counts[g] = counts[g].tolist()
63-
json.dump(counts, open("/data/ncbi/codon_counts_nopathogen.json", "w"))
37+
groups = set([x["file_name"][:-4] for x in metadata["file_metadata"]]) # noqa: C403
38+
counts = {g: np.zeros(tokenizer.vocab_size) for g in groups}
39+
for fm, cm in tqdm(zip(metadata["file_metadata"], metadata["chunks"]), total=len(metadata["file_metadata"])):
40+
group = fm["file_name"][:-4]
41+
if group in tax_ids_to_remove:
42+
curr_taxids_to_remove = set(tax_ids_to_remove[group])
43+
else:
44+
curr_taxids_to_remove = set()
45+
mmap = np.memmap(
46+
pretraining_processed_data_dir / cm["sequences"]["path"],
47+
dtype=cm["sequences"]["dtype"],
48+
mode="r",
49+
shape=tuple(cm["sequences"]["shape"]),
50+
)
51+
idx_mmap = np.memmap(
52+
pretraining_processed_data_dir / cm["index"]["path"],
53+
dtype=cm["index"]["dtype"],
54+
mode="r",
55+
shape=tuple(cm["index"]["shape"]),
56+
)
57+
for start, end, taxid in idx_mmap:
58+
if taxid in curr_taxids_to_remove:
59+
continue
60+
seq = mmap[start:end]
61+
idx, count = np.unique(seq, return_counts=True)
62+
counts[group][idx] += count
63+
64+
# %%
65+
for g in counts:
66+
counts[g] = counts[g].tolist()
67+
json.dump(counts, open(data_dir / "codon_counts_nopathogen.json", "w"))
68+
69+
70+
if __name__ == "__main__":
71+
parser = argparse.ArgumentParser(description="Check codon frequency")
72+
parser.add_argument("--pretraining_processed_data_dir", type=str, required=True)
73+
parser.add_argument("--data_dir", type=str, required=True)
74+
args = parser.parse_args()
75+
main(Path(args.pretraining_processed_data_dir), Path(args.data_dir))

bionemo-recipes/recipes/codonfm_ptl_te/data_scripts/ncbi_memmap_dataset_creator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
import argparse
1818
import json
1919
import os
20+
import sys
2021
from multiprocessing import Pool, cpu_count
2122

2223
import numpy as np
2324
import polars as pl
2425
import pyarrow.parquet as pq
2526
from tqdm import tqdm
2627

28+
29+
sys.path.append("/workspace/codonfm")
2730
from src.tokenizer import Tokenizer
2831

2932

0 commit comments

Comments
 (0)