Skip to content

Commit 1942554

Browse files
authored
Adds codon-fm-native-te recipe (#1531)
Summary - Add bionemo-recipes/models/codonfm/ — a HuggingFace-compatible CodonFM model using TransformerEngine, following the ESM2 pattern - Add bionemo-recipes/recipes/codonfm_native_te/ — a self-contained FSDP2 training recipe for CodonFM with FP8/FP4 quantization support - Add golden value regression tests cross-validated against the codonfm_ptl_te non-exact (standard TETransformerLayer) implementation models/codonfm/ (HF-compatible model) The model code in models/codonfm/modeling_codonfm_te.py is the source of truth, auto-synced to the recipe via check_copied_files.py. - CodonFMConfig(PretrainedConfig) — HF-compatible config with 4 presets (200k, 80M, 600M, 1B) - CodonFMPreTrainedModel(PreTrainedModel) — base class with MAGNETO initialization (xavier_normal with scaled gain), meta device support - CodonFMForMaskedLM — returns MaskedLMOutput, supports output_hidden_states, per-layer FP8/FP4 precision via layer_precision config - CodonEmbedding — token + post-LayerNorm embedding - CodonFMEncoder — stack of transformer_engine.pytorch.TransformerLayer with RoPE, BSHD and THD attention formats - CodonFMLMHead — Dense + GELU + LayerNormLinear (quantization disabled for numerical stability) - CodonTokenizer — 3-mer codon tokenizer (69 tokens: 5 special + 64 codons) - dataset.py — BSHD/THD collators, synthetic and parquet dataset classes 73 tests (36 pass, 10 skip, 25 xfail, 2 xpass): - Forward/backward smoke tests (BSHD + THD) - FP8/FP4 quantization tests (DelayedScaling, Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling, NVFP4BlockScaling) - Meta device and CUDA initialization tests - Golden value regression tests — weights generated from codonfm_ptl_te non-exact model, state dict mapped to native_te key format, cross-model logit equivalence verified - BSHD ↔ THD equivalence test — same weights, both formats, outputs compared recipes/codonfm_native_te/ (training recipe) Self-contained FSDP2 training recipe with: - Hydra config (defaults.yaml, L0_sanity.yaml) - train_fsdp2.py — FSDP2 training loop with gradient clipping, LR scheduling - checkpoint.py — FSDP2 checkpoint save/load with save_pretrained support - perf_logger.py — WandB + stdout logging (loss, perplexity, tokens/sec, GPU memory) - quantization.py — FP8/FP4 recipe utilities - Sample train.parquet for testing 63 tests covering model, tokenizer, quantization, and end-to-end training. CI integration - ci/scripts/check_copied_files.py — added entries to sync: - models/codonfm/modeling_codonfm_te.py → recipes/codonfm_native_te/modeling_codonfm_te.py - models/esm2/tests/common/ → models/codonfm/tests/common/ - .gitignore — added negation rule for golden value safetensors test fixtures What's left out / future work - No HF Hub checkpoint yet — the published TE checkpoints (nvidia/NV-CodonFM-Encodon-TE-80M-v1) use the "exact" EncodonTELayer with extra post-attention/post-MLP LayerNorms not present in standard TETransformerLayer. Golden values will be updated once a native_te checkpoint is trained and uploaded. - Conversion tests skipped — CodonFM is natively TE; there is no HF variant to convert to/from. - THD padding tests skipped — pad_to_multiple_of and cu_seq_lens_q_padded not yet implemented for CodonFM's tokenizer. - _do_not_quantize patterns — currently ("lm_head.dense", "lm_head.layer_norm_linear"). May need tuning as quantization recipes evolve. - Dockerfile — recipe does not yet include a Dockerfile for containerized training. Test plan - cd bionemo-recipes/models/codonfm && pytest -v tests/ — 36 pass, 10 skip, 25 xfail, 2 xpass - cd bionemo-recipes/recipes/codonfm_native_te && pytest -v tests/ — 63 pass - python ci/scripts/check_copied_files.py — no diffs - pre-commit run --all-files — clean ---------- END OF DESCRIPTION---- #### Authorizing CI Runs We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources. - If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123) - If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit. #### Triggering Code Rabbit AI Review To trigger a code review from code rabbit, comment on a pull request with one of these commands: - @coderabbitai review - Triggers a standard review - @coderabbitai full review - Triggers a comprehensive review See https://docs.coderabbit.ai/reference/review-commands for a full list of commands. ### Pre-submit Checklist <!--- Ensure all items are completed before submitting --> - [ ] I have tested these changes locally - [ ] I have updated the documentation accordingly - [ ] I have added/updated tests as needed - [ ] All existing tests pass successfully <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** - Added a CodonFM training recipe: model, tokenizer, datasets/collators, dataloaders, distributed training entrypoint, checkpointing, scheduler, perf logging, and configurable FP8/FP4 quantization and debug stats. * **Tests** - Comprehensive test suite and utilities including golden-value generation, conversion/golden regression tests, FP8/THD coverage, and sanity training runs. * **Documentation** - Recipe and shared test-library READMEs added. * **Chores** - .gitignore adjusted to allow tracking of golden_state_dict.safetensors fixtures. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent c301311 commit 1942554

42 files changed

Lines changed: 7184 additions & 0 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ wandb/
205205

206206
# Any model checkpoints
207207
*.safetensors
208+
# Allow golden value test fixtures
209+
!bionemo-recipes/models/*/tests/golden_state_dict.safetensors
208210
checkpoint_export/
209211
checkpoints/
210212

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Dataset and dataloader utilities for CodonFM pretraining."""
17+
18+
import random
19+
20+
import pyarrow.parquet as pq
21+
import torch
22+
from distributed_config import DistributedConfig
23+
from tokenizer import CodonTokenizer
24+
from torch.utils.data import DataLoader, Dataset, DistributedSampler
25+
26+
27+
BASES = "ACGT"
28+
29+
30+
class SyntheticCodonDataset(Dataset):
31+
"""Generates random codon sequences on-the-fly for testing."""
32+
33+
def __init__(self, num_samples: int = 1000, min_codons: int = 30, max_codons: int = 200, seed: int = 42):
34+
"""Initialize.
35+
36+
Args:
37+
num_samples: Number of sequences to generate.
38+
min_codons: Minimum number of codons per sequence.
39+
max_codons: Maximum number of codons per sequence.
40+
seed: Random seed.
41+
"""
42+
self.num_samples = num_samples
43+
self.min_codons = min_codons
44+
self.max_codons = max_codons
45+
self.rng = random.Random(seed)
46+
self.sequences = [self._generate_sequence() for _ in range(num_samples)]
47+
48+
def _generate_sequence(self) -> str:
49+
num_codons = self.rng.randint(self.min_codons, self.max_codons)
50+
return "".join(self.rng.choice(BASES) for _ in range(num_codons * 3))
51+
52+
def __len__(self) -> int: # noqa: D105
53+
return self.num_samples
54+
55+
def __getitem__(self, idx: int) -> dict[str, str]: # noqa: D105
56+
return {"sequence": self.sequences[idx]}
57+
58+
59+
class ParquetCodonDataset(Dataset):
60+
"""Dataset that reads codon sequences from a parquet file using memory-mapped Arrow arrays.
61+
62+
Uses PyArrow memory mapping instead of loading into a pandas DataFrame,
63+
avoiding the pandas copy and letting the OS page data in/out as needed.
64+
"""
65+
66+
def __init__(self, path: str):
67+
"""Initialize.
68+
69+
Args:
70+
path: Path to the parquet file with a 'sequence' column.
71+
"""
72+
self._table = pq.read_table(path, columns=["sequence"], memory_map=True)
73+
self._sequences = self._table.column("sequence")
74+
75+
def __len__(self) -> int: # noqa: D105
76+
return len(self._sequences)
77+
78+
def __getitem__(self, idx: int) -> dict[str, str]: # noqa: D105
79+
return {"sequence": self._sequences[idx].as_py()}
80+
81+
82+
class CodonMLMCollator:
83+
"""Collator that tokenizes sequences and applies MLM masking for BSHD format."""
84+
85+
def __init__(
86+
self,
87+
tokenizer: CodonTokenizer,
88+
max_seq_length: int = 512,
89+
mlm_probability: float = 0.15,
90+
seed: int = 42,
91+
):
92+
"""Initialize.
93+
94+
Args:
95+
tokenizer: CodonTokenizer instance.
96+
max_seq_length: Maximum sequence length (including special tokens).
97+
mlm_probability: Probability of masking a token.
98+
seed: Random seed for reproducible masking.
99+
"""
100+
self.tokenizer = tokenizer
101+
self.max_seq_length = max_seq_length
102+
self.mlm_probability = mlm_probability
103+
self.rng = random.Random(seed)
104+
105+
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
106+
"""Collate a batch of sequences into MLM training inputs.
107+
108+
Args:
109+
batch: List of dicts with 'sequence' key.
110+
111+
Returns:
112+
Dict with input_ids, attention_mask, and labels tensors.
113+
"""
114+
all_input_ids = []
115+
all_attention_masks = []
116+
all_labels = []
117+
118+
for sample in batch:
119+
ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
120+
# Truncate to max_seq_length, preserving trailing SEP token
121+
if len(ids) > self.max_seq_length:
122+
ids = [*ids[: self.max_seq_length - 1], self.tokenizer.sep_token_id]
123+
seq_len = len(ids)
124+
125+
# Create attention mask and pad
126+
attn_mask = [1] * seq_len + [0] * (self.max_seq_length - seq_len)
127+
ids = ids + [self.tokenizer.pad_token_id] * (self.max_seq_length - seq_len)
128+
129+
# Apply MLM masking
130+
labels = [-100] * self.max_seq_length
131+
for i in range(seq_len):
132+
# Skip special tokens (CLS at 0, SEP at end)
133+
if ids[i] in (self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.pad_token_id):
134+
continue
135+
if self.rng.random() < self.mlm_probability:
136+
labels[i] = ids[i]
137+
r = self.rng.random()
138+
if r < 0.8:
139+
ids[i] = self.tokenizer.mask_token_id
140+
elif r < 0.9:
141+
# Random codon token (IDs 5 through 68)
142+
ids[i] = self.rng.randint(5, self.tokenizer.vocab_size - 1)
143+
# else: keep original (10% of the time)
144+
145+
all_input_ids.append(ids)
146+
all_attention_masks.append(attn_mask)
147+
all_labels.append(labels)
148+
149+
return {
150+
"input_ids": torch.tensor(all_input_ids, dtype=torch.long),
151+
"attention_mask": torch.tensor(all_attention_masks, dtype=torch.long),
152+
"labels": torch.tensor(all_labels, dtype=torch.long),
153+
}
154+
155+
156+
class CodonTHDCollator:
157+
"""Collator for THD (packed sequence) format."""
158+
159+
def __init__(
160+
self,
161+
tokenizer: CodonTokenizer,
162+
max_seq_length: int = 512,
163+
mlm_probability: float = 0.15,
164+
seed: int = 42,
165+
):
166+
"""Initialize.
167+
168+
Args:
169+
tokenizer: CodonTokenizer instance.
170+
max_seq_length: Maximum sequence length per sample.
171+
mlm_probability: Probability of masking a token.
172+
seed: Random seed for reproducible masking.
173+
"""
174+
self.tokenizer = tokenizer
175+
self.max_seq_length = max_seq_length
176+
self.mlm_probability = mlm_probability
177+
self.rng = random.Random(seed)
178+
179+
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
180+
"""Collate a batch into THD packed format.
181+
182+
Args:
183+
batch: List of dicts with 'sequence' key.
184+
185+
Returns:
186+
Dict with input_ids, labels (flattened), cu_seq_lens_q/k, max_length_q/k.
187+
"""
188+
all_ids = []
189+
all_labels = []
190+
seq_lengths = []
191+
192+
for sample in batch:
193+
ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
194+
# Truncate to max_seq_length, preserving trailing SEP token
195+
if len(ids) > self.max_seq_length:
196+
ids = [*ids[: self.max_seq_length - 1], self.tokenizer.sep_token_id]
197+
seq_len = len(ids)
198+
199+
# Apply MLM masking
200+
labels = [-100] * seq_len
201+
for i in range(seq_len):
202+
if ids[i] in (self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.pad_token_id):
203+
continue
204+
if self.rng.random() < self.mlm_probability:
205+
labels[i] = ids[i]
206+
r = self.rng.random()
207+
if r < 0.8:
208+
ids[i] = self.tokenizer.mask_token_id
209+
elif r < 0.9:
210+
ids[i] = self.rng.randint(5, self.tokenizer.vocab_size - 1)
211+
212+
all_ids.extend(ids)
213+
all_labels.extend(labels)
214+
seq_lengths.append(seq_len)
215+
216+
cu_seq_lens = torch.zeros(len(seq_lengths) + 1, dtype=torch.int32)
217+
cu_seq_lens[1:] = torch.cumsum(torch.tensor(seq_lengths, dtype=torch.int32), dim=0)
218+
219+
return {
220+
"input_ids": torch.tensor(all_ids, dtype=torch.long).unsqueeze(0),
221+
"labels": torch.tensor(all_labels, dtype=torch.long).unsqueeze(0),
222+
"cu_seq_lens_q": cu_seq_lens,
223+
"cu_seq_lens_k": cu_seq_lens,
224+
"max_length_q": max(seq_lengths),
225+
"max_length_k": max(seq_lengths),
226+
}
227+
228+
229+
def create_bshd_dataloader(
230+
dist_config: DistributedConfig,
231+
data_path: str,
232+
micro_batch_size: int = 2,
233+
max_seq_length: int = 512,
234+
mlm_probability: float = 0.15,
235+
num_workers: int = 1,
236+
seed: int = 42,
237+
) -> tuple[DataLoader, DistributedSampler]:
238+
"""Create a BSHD-format dataloader.
239+
240+
Args:
241+
dist_config: Distributed configuration.
242+
data_path: Path to parquet file or 'synthetic'.
243+
micro_batch_size: Batch size per GPU.
244+
max_seq_length: Maximum sequence length.
245+
mlm_probability: MLM masking probability.
246+
num_workers: Number of dataloader workers.
247+
seed: Random seed.
248+
249+
Returns:
250+
Tuple of (DataLoader, DistributedSampler).
251+
"""
252+
tokenizer = CodonTokenizer()
253+
254+
if data_path == "synthetic":
255+
dataset = SyntheticCodonDataset(num_samples=500, seed=seed)
256+
else:
257+
dataset = ParquetCodonDataset(data_path)
258+
259+
sampler = DistributedSampler(
260+
dataset,
261+
rank=dist_config.rank,
262+
num_replicas=dist_config.world_size,
263+
seed=seed,
264+
)
265+
266+
collator = CodonMLMCollator(
267+
tokenizer=tokenizer,
268+
max_seq_length=max_seq_length,
269+
mlm_probability=mlm_probability,
270+
)
271+
272+
dataloader = DataLoader(
273+
dataset,
274+
sampler=sampler,
275+
batch_size=micro_batch_size,
276+
collate_fn=collator,
277+
num_workers=num_workers,
278+
pin_memory=True,
279+
)
280+
281+
return dataloader, sampler
282+
283+
284+
def create_thd_dataloader(
285+
dist_config: DistributedConfig,
286+
data_path: str,
287+
micro_batch_size: int = 2,
288+
max_seq_length: int = 512,
289+
mlm_probability: float = 0.15,
290+
num_workers: int = 1,
291+
seed: int = 42,
292+
) -> tuple[DataLoader, DistributedSampler]:
293+
"""Create a THD-format (packed sequence) dataloader.
294+
295+
Args:
296+
dist_config: Distributed configuration.
297+
data_path: Path to parquet file or 'synthetic'.
298+
micro_batch_size: Number of sequences to pack per batch.
299+
max_seq_length: Maximum sequence length per sample.
300+
mlm_probability: MLM masking probability.
301+
num_workers: Number of dataloader workers.
302+
seed: Random seed.
303+
304+
Returns:
305+
Tuple of (DataLoader, DistributedSampler).
306+
"""
307+
tokenizer = CodonTokenizer()
308+
309+
if data_path == "synthetic":
310+
dataset = SyntheticCodonDataset(num_samples=500, seed=seed)
311+
else:
312+
dataset = ParquetCodonDataset(data_path)
313+
314+
sampler = DistributedSampler(
315+
dataset,
316+
rank=dist_config.rank,
317+
num_replicas=dist_config.world_size,
318+
seed=seed,
319+
)
320+
321+
collator = CodonTHDCollator(
322+
tokenizer=tokenizer,
323+
max_seq_length=max_seq_length,
324+
mlm_probability=mlm_probability,
325+
)
326+
327+
dataloader = DataLoader(
328+
dataset,
329+
sampler=sampler,
330+
batch_size=micro_batch_size,
331+
collate_fn=collator,
332+
num_workers=num_workers,
333+
pin_memory=True,
334+
)
335+
336+
return dataloader, sampler

0 commit comments

Comments
 (0)