From 1fc218ea3e0f29c8528026ea8e0bb65c20db8831 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Fri, 8 May 2026 01:22:39 +0000 Subject: [PATCH 01/17] Add train ddp to recipe --- .../recipes/codonfm_native_te/Dockerfile | 7 +- .../recipes/codonfm_native_te/checkpoint.py | 88 ++++++ .../recipes/codonfm_native_te/run_1b.sh | 94 +++++++ .../recipes/codonfm_native_te/train_ddp.py | 265 ++++++++++++++++++ 4 files changed, 453 insertions(+), 1 deletion(-) create mode 100755 bionemo-recipes/recipes/codonfm_native_te/run_1b.sh create mode 100644 bionemo-recipes/recipes/codonfm_native_te/train_ddp.py diff --git a/bionemo-recipes/recipes/codonfm_native_te/Dockerfile b/bionemo-recipes/recipes/codonfm_native_te/Dockerfile index b72c36b890..e59e7fe2fd 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/Dockerfile +++ b/bionemo-recipes/recipes/codonfm_native_te/Dockerfile @@ -1,9 +1,14 @@ # syntax=docker/dockerfile:1.4 -FROM nvcr.io/nvidia/pytorch:26.04-py3 +FROM nvcr.io/nvidia/pytorch:26.02-py3 + +RUN apt-get update && apt-get install -y tmux npm RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ PIP_CONSTRAINT= pip install -r /requirements.txt +RUN curl -fsSL https://claude.ai/install.sh | bash # Install Claude CLI tool +RUN npm install -g @openai/codex + WORKDIR /workspace/bionemo COPY . . 
diff --git a/bionemo-recipes/recipes/codonfm_native_te/checkpoint.py b/bionemo-recipes/recipes/codonfm_native_te/checkpoint.py index 66617d0502..cef9bc8fb0 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/checkpoint.py +++ b/bionemo-recipes/recipes/codonfm_native_te/checkpoint.py @@ -221,3 +221,91 @@ def save_final_model_fsdp2( save_file(model_state_dict, os.path.join(save_directory, "model.safetensors")) config.to_json_file(os.path.join(save_directory, "config.json")) logger.info(f"Saved final FSDP2 model to {save_directory}") + + +# ============================================================================ +# DDP Checkpointing +# ============================================================================ + + +def load_checkpoint_ddp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, +) -> CheckpointOutput: + """Load DDP checkpoint.""" + checkpoint_path, _ = get_latest_checkpoint(ckpt_path) + if not checkpoint_path: + logger.info("No DDP checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, 0, 0) + + checkpoint = torch.load( + checkpoint_path / "checkpoint.pt", + map_location=f"cuda:{dist_config.local_rank}", + weights_only=True, + ) + + model.load_state_dict(checkpoint["model"], strict=False) + optimizer.load_state_dict(checkpoint["optimizer"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + + if dist_config.is_main_process(): + logger.info(f"Loaded DDP checkpoint from step {checkpoint['step']}") + + # Increment the step by one to avoid re-running the previous step. 
+ return CheckpointOutput(model, optimizer, scheduler, checkpoint["step"] + 1, checkpoint["epoch"]) + + +def save_checkpoint_ddp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + epoch: int, + dist_config: DistributedConfig, + max_checkpoints: int | None = None, +) -> None: + """Save DDP checkpoint (rank-0 only since the model is replicated).""" + if not dist_config.is_main_process(): + return + + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + torch.save( + { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "step": step, + "epoch": epoch, + }, + checkpoint_path / "checkpoint.pt", + ) + logger.info(f"Saved DDP checkpoint to {checkpoint_path}") + + if max_checkpoints is not None: + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_ddp( + model: torch.nn.Module, + config, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for DDP - only on main process.""" + if not dist_config.is_main_process(): + return + + # Unwrap DDP if wrapped. 
+ underlying_model = model.module if hasattr(model, "module") else model + + os.makedirs(save_directory, exist_ok=True) + save_file(underlying_model.state_dict(), os.path.join(save_directory, "model.safetensors")) + config.to_json_file(os.path.join(save_directory, "config.json")) + logger.info(f"Saved final DDP model to {save_directory}") diff --git a/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh b/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh new file mode 100755 index 0000000000..5ec827e7bc --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash +set -euo pipefail + +export CPATH=/usr/local/cuda/include +export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas + +# Run config +export CONFIG_NAME=encodon_1b +export NPROC_PER_NODE=8 +export DIST_STRATEGY=ddp # fsdp or ddp + +# Training +export NUM_TRAIN_STEPS=100 +export MICRO_BATCH_SIZE=31 +export NUM_WORKERS=1 +export USE_SEQUENCE_PACKING=True +export USE_FP32_MASTER_WEIGHTS=True +export NUM_WARMUP_STEPS=500 + +# Logging / W&B +export LOGGER_FREQUENCY=10 +export WANDB_API_KEY="" +export WANDB_PROJECT=codon-fm-low-precision + +# Checkpointing +export SAVE_FINAL_MODEL=False +export SAVE_EVERY_N_STEPS=100000 +export CKPT_DIR=/tmp +export RESUME_FROM_CHECKPOINT=False + +# Hydra +export HYDRA_RUN_DIR=1b_test + +# Quantization / FP8 +export QUANT_STATS_ENABLED=False +export FP8_ENABLED=True +export FP8_RECIPE=transformer_engine.common.recipe.MXFP8BlockScaling +export FP8_FORMAT=E4M3 + +# Data +export DATASET_DATA_PATH=/data/balvisio/codonfm/reference-dataset/codonfm/processed_unfiltered/ + +# Derived: build wandb run name from model size, batch size, and precision recipe +MODEL_SIZE="${CONFIG_NAME##*_}" +if [ "${FP8_ENABLED}" = "True" ]; then + RECIPE_SHORT="${FP8_RECIPE##*.}" + RECIPE_SHORT="${RECIPE_SHORT%BlockScaling}" + RECIPE_SHORT="${RECIPE_SHORT%Scaling}" + PRECISION_TAG="${RECIPE_SHORT,,}_${FP8_FORMAT,,}" +else + PRECISION_TAG="bf16" +fi +export 
WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}" + +# Pick training script based on distributed strategy. +# DDP can't emulate FSDP's fp32-master / bf16-param split, so force fp32 master weights off. +case "${DIST_STRATEGY}" in + fsdp) + TRAIN_SCRIPT=train_fsdp2.py + ;; + ddp) + TRAIN_SCRIPT=train_ddp.py + if [ "${USE_FP32_MASTER_WEIGHTS}" = "True" ]; then + echo "DIST_STRATEGY=ddp: overriding USE_FP32_MASTER_WEIGHTS=True -> False" >&2 + export USE_FP32_MASTER_WEIGHTS=False + fi + ;; + *) + echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2 + exit 1 + ;; +esac + +torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \ + --config-name ${CONFIG_NAME} \ + quant_stats_config.enabled=${QUANT_STATS_ENABLED} \ + logger.frequency=${LOGGER_FREQUENCY} \ + num_train_steps=${NUM_TRAIN_STEPS} \ + dataset.micro_batch_size=${MICRO_BATCH_SIZE} \ + dataset.num_workers=${NUM_WORKERS} \ + dataset.data_path=${DATASET_DATA_PATH} \ + use_sequence_packing=${USE_SEQUENCE_PACKING} \ + use_fp32_master_weights=${USE_FP32_MASTER_WEIGHTS} \ + lr_scheduler_kwargs.num_warmup_steps=${NUM_WARMUP_STEPS} \ + wandb_init_args.name=${WANDB_RUN_NAME} \ + +wandb_init_args.project=${WANDB_PROJECT} \ + checkpoint.save_final_model=${SAVE_FINAL_MODEL} \ + checkpoint.save_every_n_steps=${SAVE_EVERY_N_STEPS} \ + checkpoint.ckpt_dir=${CKPT_DIR} \ + checkpoint.resume_from_checkpoint=${RESUME_FROM_CHECKPOINT} \ + hydra.run.dir=${HYDRA_RUN_DIR} \ + fp8_config.enabled=${FP8_ENABLED} \ + fp8_config.fp8_recipe=${FP8_RECIPE} \ + fp8_config.fp8_format=${FP8_FORMAT} diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py b/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py new file mode 100644 index 0000000000..9acc153bbb --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DDP training script for CodonFM with TransformerEngine layers.""" + +import logging +from contextlib import nullcontext +from pathlib import Path + +import hydra +import nvdlfw_inspect.api as debug_api +import torch +from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint +from dataset import create_bshd_dataloader, create_thd_dataloader +from distributed_config import DistributedConfig +from modeling_codonfm_te import MODEL_PRESETS, CodonFMConfig, CodonFMForMaskedLM +from omegaconf import DictConfig, OmegaConf +from perf_logger import PerfLogger +from quantization import WandBQuantLogger, initialize_quant_stats_logging, resolve_layer_precision +from scheduler import get_linear_schedule_with_warmup +from torch.distributed.device_mesh import init_device_mesh +from torch.optim import AdamW +from transformer_engine.common.recipe import Format + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") +def main(args: DictConfig) -> float | None: + """Train CodonFM with TE layers using DDP. + + Returns: + float: The minimum loss value seen during training. 
+ """ + logging.getLogger("httpx").setLevel(logging.WARNING) + + # Initialize distributed configuration + dist_config = DistributedConfig() + logger.info("Initializing distributed training: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + # DDP keeps a single param dtype per replica, so it can't emulate FSDP2's + # MixedPrecisionPolicy(param_dtype=bf16, reduce_dtype=fp32) split. Reject up-front. + if args.use_fp32_master_weights: + raise ValueError("FP32 master weights are not supported with DDP. Use train_fsdp2.py instead.") + + perf_logger = None + try: + # Mirrors the FSDP2 device mesh — not strictly required for DDP, but keeps configs symmetric. + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(dist_config.world_size,), + mesh_dim_names=("ddp",), + ) + + # Build model config from preset + preset_overrides = MODEL_PRESETS[args.model_preset] + + # Resolve layer-wise quantization assignments + num_layers = preset_overrides.get("num_hidden_layers", 12) + layer_precision = resolve_layer_precision( + num_layers=num_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + + # Initialize quant stats logging if enabled + if args.quant_stats_config.enabled: + wandb_logger = None + if args.quant_stats_config.log_to_wandb and dist_config.is_main_process(): + wandb_logger = WandBQuantLogger() + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=layer_precision, + statistics_logger=wandb_logger, + ) + + # Create quantization recipes 
+ fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + config = CodonFMConfig( + attn_input_format="thd" if args.use_sequence_packing else "bshd", + max_position_embeddings=args.dataset.max_seq_length, + layer_precision=layer_precision, + **preset_overrides, + ) + + with torch.device("meta") if args.use_meta_device else nullcontext(): + model = CodonFMForMaskedLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + + logger.info("Initialized Model:\n%s", model) + + # Materialize weights. With meta-device init, init_empty_weights() runs the MAGNETO init + # and moves params to CUDA; otherwise the model was constructed eagerly on CPU. + if args.use_meta_device: + model.init_empty_weights() + else: + model = model.to(device) + + # DDP replicates the full model on each GPU. Cast params to bf16 since the optimizer + # update happens in the same dtype as the params (no FP32 master weights here). 
+ model = model.to(dtype=torch.bfloat16) + + # Assign layer names for debug API + if args.quant_stats_config.enabled: + debug_api.infer_and_assign_layer_names(model) + + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + device_mesh=device_mesh["ddp"], + ) + + # Create optimizer and scheduler + optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) + scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + # Create dataloader + dataloader_kwargs = OmegaConf.to_container(args.dataset, resolve=True) + train_dataloader, sampler = ( + create_thd_dataloader(dist_config, **dataloader_kwargs) + if args.use_sequence_packing + else create_bshd_dataloader(dist_config, **dataloader_kwargs) + ) + + # Resume from checkpoint if available + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None + if args.checkpoint.resume_from_checkpoint and ckpt_path: + model, optimizer, scheduler, start_step, epoch = load_checkpoint_ddp( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + dist_config=dist_config, + ) + else: + start_step = 0 + epoch = 0 + + perf_logger = PerfLogger(dist_config, args) + + # Training loop + step = start_step + micro_step = 0 # Gradient accumulation step counter + while step < args.num_train_steps: + batches_in_epoch = 0 + for batch in train_dataloader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 + + micro_step += 1 + + # Skip DDP grad sync on intermediate accumulation micro-steps; the final + # micro-step (when we will call optimizer.step) syncs as usual. 
+ is_accumulation_boundary = micro_step % args.grad_acc_steps == 0 + sync_context = nullcontext() if is_accumulation_boundary else model.no_sync() + + with sync_context: + # Forward pass + outputs = model(**batch) + + # Backward pass - scale loss by grad_acc_steps for proper gradient averaging + loss = outputs.loss / args.grad_acc_steps + loss.backward() + + # Log micro-batch data for accumulation metrics + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + + # Optimizer step only after accumulating grad_acc_steps micro-batches + if is_accumulation_boundary: + micro_step = 0 + + # Grad clip + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() + + # Optimizer step + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + perf_logger.log_step( + step=step, + grad_norm=total_norm, + lr=optimizer.param_groups[0]["lr"], + ) + + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): + save_checkpoint_ddp( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + step=step, + epoch=epoch, + dist_config=dist_config, + max_checkpoints=args.checkpoint.max_checkpoints, + ) + + step += 1 + if step >= args.num_train_steps: + break + + batches_in_epoch += 1 + + if batches_in_epoch == 0: + raise RuntimeError( + f"Dataloader produced zero batches at epoch {epoch}, step {step}/{args.num_train_steps}. " + "This would cause an infinite loop." 
+ ) + + epoch += 1 + sampler.set_epoch(epoch) + + # Save final model + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_ddp( + model=model, + config=config, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + return float(perf_logger.min_loss.item()) + finally: + if perf_logger is not None: + perf_logger.finish() + else: + try: + debug_api.end_debug() + except RuntimeError: + pass + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() From a53e7d6e42d5d6b2fd0449d02b03cbef24b14314 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Tue, 12 May 2026 03:36:13 +0000 Subject: [PATCH 02/17] Added support for train/val/test split CodonMemmapDataset --- .../codonfm_native_te/codon_memmap_dataset.py | 513 ++++++++++++++++++ .../recipes/codonfm_native_te/dataset.py | 113 +++- .../hydra_config/L0_sanity.yaml | 5 + .../hydra_config/defaults.yaml | 21 +- .../hydra_config/encodon_1b.yaml | 2 +- .../hydra_config/encodon_5b.yaml | 2 +- .../recipes/codonfm_native_te/perf_logger.py | 12 + .../recipes/codonfm_native_te/run_1b.sh | 5 +- .../recipes/codonfm_native_te/train_ddp.py | 42 +- .../recipes/codonfm_native_te/train_fsdp2.py | 39 +- 10 files changed, 730 insertions(+), 24 deletions(-) create mode 100644 bionemo-recipes/recipes/codonfm_native_te/codon_memmap_dataset.py diff --git a/bionemo-recipes/recipes/codonfm_native_te/codon_memmap_dataset.py b/bionemo-recipes/recipes/codonfm_native_te/codon_memmap_dataset.py new file mode 100644 index 0000000000..87a239e2b4 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/codon_memmap_dataset.py @@ -0,0 +1,513 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Split-aware memmap dataset for CodonFM pretraining. + +Ports the index construction, sequence-cluster filtering, and proportional cluster split +from the PTL recipe (codonfm_ptl_te) so the native_te recipe sees the same train/val/test +sample sets when pointed at the same data directory and seed. On-disk cache files +(metadata.cache*.npy, train_idx*.npy, val_idx*.npy, test_idx*.npy) use PTL's filename +conventions, so caches generated by either recipe are interchangeable. + +A cache_params*.json sidecar is written on regeneration to detect silent parameter drift. +""" + +import hashlib +import json +import logging +from pathlib import Path + +import numpy as np +import torch.distributed as dist +from torch.utils.data import Dataset +from tqdm import tqdm + + +logger = logging.getLogger(__name__) + + +_VALID_SPLITS = ("train", "validation", "test") + + +def _is_distributed() -> bool: + """True when torch.distributed has been initialized for the current process.""" + return dist.is_available() and dist.is_initialized() + + +def _is_rank_zero() -> bool: + """True on global rank 0 under torch.distributed, or unconditionally in single-process runs.""" + return not _is_distributed() or dist.get_rank() == 0 + + +def _barrier() -> None: + """Synchronize all ranks if distributed is initialized; no-op otherwise.""" + if _is_distributed(): + dist.barrier() + + +def _log_info_rank0(msg: str, *args) -> None: + """Emit logger.info on rank 0 only; no-op on other ranks. 
Prevents N-way log duplication.""" + if _is_rank_zero(): + logger.info(msg, *args) + + +def _sample_clusters_by_size(count_mat, global_keys, used_clusters=None, target_size=30000, seed=42): + """Verbatim port of codonfm_ptl_te.src.data.dataset_utils.sample_clusters_by_size. + + Sample whole sequence-clusters proportionally per organism group, starting from rarest, + until target_size sequences are reached. Used to carve val/test out of global_keys. + """ + if used_clusters is None: + used_clusters = [] + + np.random.seed(seed) # noqa: NPY002 -- verbatim port; PTL uses NumPy global RNG here + + total_per_organism = count_mat.sum(axis=0) + total_sum = total_per_organism.sum() + if total_sum == 0: + raise ValueError("No sequences found in count matrix") + target_proportions = total_per_organism / total_sum + org_indices = np.argsort(total_per_organism) + + selected_clusters = [] + current_counts = np.zeros_like(total_per_organism) + current_total = 0 + + for org_idx in tqdm(org_indices, desc="Sampling clusters"): + target_count = max(1, int(target_size * target_proportions[org_idx])) + valid_clusters = np.where(count_mat[:, org_idx] > 0)[0] + valid_clusters = list(set(valid_clusters) - set(selected_clusters) - set(used_clusters)) + np.random.shuffle(valid_clusters) # noqa: NPY002 + + for cluster_idx in valid_clusters: + if current_counts[org_idx] >= target_count: + break + selected_clusters.append(cluster_idx) + current_counts += count_mat[cluster_idx] + current_total += count_mat[cluster_idx].sum() + + selected_indices = np.where(np.isin(global_keys, selected_clusters))[0] + return np.array(selected_indices), selected_clusters + + +def _load_train_val_test_indices_proportional( + metadata_path, + global_keys, + global_indices, + post_fix="", + train_val_test_ratio=(0.9998, 0.0002, 0.0), + seed=42, + force_recompute=False, +): + """Verbatim port of codonfm_ptl_te.src.data.dataset_utils.load_train_val_test_indices_proportional. 
+ + When force_recompute=True the existing train_idx/val_idx/test_idx files (if any) are ignored + and overwritten with freshly computed splits. Callers must set this to True whenever the + upstream global_indices has just been regenerated, otherwise stale splits would silently + pair with a fresh global array. + """ + train_idx_path = metadata_path.parent / f"train_idx{post_fix}.npy" + val_idx_path = metadata_path.parent / f"val_idx{post_fix}.npy" + test_idx_path = metadata_path.parent / f"test_idx{post_fix}.npy" + + cache_hit = not force_recompute and train_idx_path.exists() and val_idx_path.exists() and test_idx_path.exists() + + # Synchronize the cache-presence decision before any rank writes. Without this barrier a + # straggler rank could check existence AFTER rank 0 wrote the files and disagree on the + # branch, leading to a downstream barrier deadlock (one rank in regen, others in cache-hit). + _barrier() + + if not cache_hit: + if _is_rank_zero(): + _log_info_rank0("Computing train/val/test split indices (post_fix=%r)", post_fix) + with open(metadata_path, "r") as f: + metadata = json.load(f)["file_metadata"] + groups = sorted({metadata[i]["file_name"].split(".")[0] for i in range(len(metadata))}) + count_mat = np.zeros((max(global_keys) + 1, len(groups))) + for idx, cluster_idx in zip(global_indices, global_keys): + group = metadata[idx[0]]["file_name"].split(".")[0] + count_mat[cluster_idx, groups.index(group)] += 1 + + val_size = int(len(global_keys) * train_val_test_ratio[1]) + test_size = int(len(global_keys) * train_val_test_ratio[2]) + val_indices, val_clusters = np.array([]), [] + test_indices, test_clusters = np.array([]), [] + if val_size > 0: + val_indices, val_clusters = _sample_clusters_by_size( + count_mat, global_keys, used_clusters=[], target_size=val_size, seed=seed + ) + if test_size > 0: + test_indices, test_clusters = _sample_clusters_by_size( + count_mat, global_keys, used_clusters=val_clusters, target_size=test_size, seed=seed + ) + 
train_indices = np.setdiff1d(np.arange(len(global_keys)), np.concatenate([val_indices, test_indices])) + np.save(train_idx_path, train_indices) + np.save(val_idx_path, val_indices) + np.save(test_idx_path, test_indices) + # Non-rank-0 ranks (and rank 0 after writing) wait here so everyone loads the same + # on-disk state below. + _barrier() + else: + _log_info_rank0("Loading existing train/val/test split indices from disk (post_fix=%r)", post_fix) + + train_indices = np.load(train_idx_path, mmap_mode="r") + val_indices = np.load(val_idx_path, mmap_mode="r") + test_indices = np.load(test_idx_path, mmap_mode="r") + return train_indices, val_indices, test_indices + + +def _get_taxids_to_exclude(taxid_exclusion_file): + """Load taxids-to-exclude JSON. Verbatim port of PTL's get_taxids_to_exclude.""" + with open(taxid_exclusion_file, "r") as f: + exclusion_content = json.load(f) + taxids = [] + for v in exclusion_content.values(): + taxids.extend(v) + return set(taxids) + + +def _taxid_exclusion_hash(taxids_to_exclude: set | None) -> str | None: + """Stable SHA-256 hash of the sorted excluded-taxid set, or None when no exclusion is in effect. + + Hashes the parsed set rather than the raw file bytes, so the stamp is insensitive to JSON + formatting and key ordering while still catching content changes that affect which sequences + appear in global_indices. JSON-serializes the sorted int-coerced taxids for a deterministic + representation that does not depend on numpy/Python version-specific repr(). 
+ """ + if taxids_to_exclude is None: + return None + payload = json.dumps(sorted(int(t) for t in taxids_to_exclude), separators=(",", ":")) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def _parity_stamp(**args) -> dict: + """Stable JSON-serializable dict of parity-affecting args (used to detect silent cache drift).""" + return {k: args[k] for k in sorted(args)} + + +class CodonMemmapDataset(Dataset): + """Split-aware memmap dataset that reproduces the PTL recipe's sample selection. + + Each instance is bound to a single split ("train", "validation", or "test"). The class reads the + memmap format produced by codonfm_ptl_te's preprocessing (a directory with metadata.json plus + sequences_chunk*.mmap and index_chunk*.mmap files) and exposes the same per-window slices the + PTL recipe trains on, filtered by the same min/max length and taxid rules, and split via the + same proportional cluster sampler. + + Cache files follow PTL's filename conventions, but the strict contract here is that a + cache is only considered valid when the metadata.cache_params{cache_suffix}.json sidecar + is also present. An unstamped cache (e.g. one produced by codonfm_ptl_te, which does not + write stamps) is treated as missing and triggers a full rebuild. The stamp records every + parity-affecting argument and is namespaced with cache_suffix (= "_{split_name_prefix}"), + so it shares a namespace with cache_path/key_cache_path and catches context_overlap drift. + + Args: + data_path: Path to directory containing metadata.json and mmap chunk files. + split: Which split this instance is bound to ("train", "validation", "test"). + max_seq_length: Model context length including special tokens (PTL calls this context_length). + context_overlap: Overlap between consecutive windows of the same source sequence (token count). + pretraining_task: "mlm" (reserves 2 tokens for CLS/SEP) or "next_token_prediction" (reserves 1). 
+ train_val_test_ratio: Three-element tuple summing to <=1; passes through to the cluster sampler. + min_seq_length: Drop source sequences with fewer than this many tokens. + max_filter_seq_length: Drop source sequences with more than this many tokens. + groups_to_use: Optional list of organism group names to restrict to (intersected per split). + taxid_exclusion_file: Optional JSON file listing taxids to drop. + split_name_prefix: Suffix appended to cache filenames for namespacing (required if + taxid_exclusion_file is set, since exclusion changes which windows exist). + seed: Random seed for the proportional cluster split. + force_recompute: If True, ignore any existing cache (the global window cache AND the + per-split index files) and rebuild from scratch. The flag cascades from the global + cache into the split loader so a freshly regenerated global_indices is never paired + with stale on-disk splits. + + Note: + __getitem__ returns {"sequence_tokens": np.ndarray} carrying the raw memmap codon + IDs — NOT a decoded codon string. Returning IDs directly preserves tokens + (ID 2) that the existing MemmapCodonDataset silently drops via a decode -> string -> + re-encode round trip (the codon tokenizer's tokenize() is strict 3-char chunking + that misaligns when "" literals appear in the decoded string). The + CodonMLMCollator / CodonTHDCollator branch on the key name and wrap with CLS/SEP + without re-tokenizing. 
def __init__(
    self,
    data_path: str,
    *,
    split: str = "train",
    max_seq_length: int = 2048,
    context_overlap: int = 0,
    pretraining_task: str = "mlm",
    train_val_test_ratio: tuple[float, float, float] = (0.9998, 0.0002, 0.0),
    min_seq_length: int = 100,
    max_filter_seq_length: int = 150_000,
    groups_to_use: list[str] | None = None,
    taxid_exclusion_file: str | None = None,
    split_name_prefix: str = "",
    seed: int = 123,
    force_recompute: bool = False,
):
    """Open the chunk memmaps, build or load the window cache, and select one split.

    Args:
        data_path: Directory holding ``metadata.json`` and the chunk files it references.
        split: Which split this instance serves; must be one of ``_VALID_SPLITS``.
        max_seq_length: Window length budget, including the special tokens the collator adds.
        context_overlap: Number of tokens by which the sliding-window stride is reduced.
        pretraining_task: Only ``"mlm"`` is accepted; ``"next_token_prediction"`` is rejected.
        train_val_test_ratio: Ratios forwarded to the proportional split loader.
        min_seq_length: Sequences shorter than this (in tokens) are skipped when indexing.
        max_filter_seq_length: Sequences longer than this (in tokens) are skipped when indexing.
        groups_to_use: If given, keep only windows whose chunk's ``file_name`` stem (text
            before the first ``"."``) is in this list.
        taxid_exclusion_file: If given, taxids loaded from it are skipped when indexing.
            Requires ``split_name_prefix``.
        split_name_prefix: Namespaces the on-disk cache and split file names.
        seed: Forwarded to the split loader and recorded in the parity stamp.
        force_recompute: Rebuild the global-index cache even if a stamped cache exists.

    Raises:
        ValueError: On an invalid ``split`` or ``pretraining_task``, a taxid exclusion file
            without ``split_name_prefix``, unknown ``groups_to_use`` entries, or an empty
            train/validation split.
        RuntimeError: If the on-disk parity stamp disagrees with the current arguments, or
            the loaded split indices do not partition the global index.
    """
    if split not in _VALID_SPLITS:
        raise ValueError(f"split must be one of {_VALID_SPLITS}, got {split!r}")

    self.data_path = Path(data_path)
    self.metadata_path = self.data_path / "metadata.json"
    self.split = split
    self.max_seq_length = max_seq_length
    self.context_overlap = context_overlap
    self.pretraining_task = pretraining_task
    self.train_val_test_ratio = tuple(train_val_test_ratio)
    self.min_seq_length = min_seq_length
    self.max_filter_seq_length = max_filter_seq_length
    self.groups_to_use = groups_to_use
    self.taxid_exclusion_file = taxid_exclusion_file
    self.split_name_prefix = split_name_prefix
    self.seed = seed

    # PTL convention: MLM reserves CLS+SEP (2), NTP reserves a single sentinel (1).
    # The collator will reinsert these at batch time; we just leave room in each window.
    if pretraining_task == "mlm":
        self.tok_adjust = 2
    elif pretraining_task == "next_token_prediction":
        # Rejected outright, so no tok_adjust is stored for this task.
        raise ValueError("next_token_prediction is not supported for codonfm_native_te")
    else:
        raise ValueError(f"Invalid pretraining_task '{pretraining_task}'")

    self.taxids_to_exclude = None
    if taxid_exclusion_file:
        if not split_name_prefix:
            raise ValueError("split_name_prefix is required when taxid_exclusion_file is set")
        self.taxids_to_exclude = _get_taxids_to_exclude(taxid_exclusion_file)
        _log_info_rank0("Loaded %d taxids to exclude", len(self.taxids_to_exclude))

    with open(self.metadata_path, "r") as f:
        self.metadata = json.load(f)

    # One read-only memmap per chunk for the token stream and for its per-sequence index.
    self.sequences_mmaps = []
    self.indices_mmaps = []
    for chunk in self.metadata["chunks"]:
        seq_mmap = np.memmap(
            self.data_path / chunk["sequences"]["path"],
            dtype=chunk["sequences"]["dtype"],
            mode="r",
            shape=tuple(chunk["sequences"]["shape"]),
        )
        idx_mmap = np.memmap(
            self.data_path / chunk["index"]["path"],
            dtype=chunk["index"]["dtype"],
            mode="r",
            shape=tuple(chunk["index"]["shape"]),
        )
        self.sequences_mmaps.append(seq_mmap)
        self.indices_mmaps.append(idx_mmap)

    # PTL filename conventions: cache_suffix gates the all-windows cache, post_fix gates the
    # per-split cache. They diverge because PTL only namespaced splits (not the underlying
    # windows) by overlap. We preserve that scheme so PTL-generated caches are readable.
    cache_suffix = f"_{split_name_prefix}" if split_name_prefix else ""
    post_fix = "_" + split_name_prefix if split_name_prefix else ""
    if context_overlap != 0:
        post_fix = post_fix + f"_overlap_{context_overlap}"

    cache_path = self.metadata_path.with_suffix(f".cache{cache_suffix}.npy")
    key_cache_path = self.metadata_path.with_suffix(f".key.cache{cache_suffix}.npy")
    stamp_path = self.metadata_path.with_suffix(f".cache_params{cache_suffix}.json")

    current_stamp = _parity_stamp(
        max_seq_length=max_seq_length,
        context_overlap=context_overlap,
        pretraining_task=pretraining_task,
        train_val_test_ratio=list(train_val_test_ratio),
        min_seq_length=min_seq_length,
        max_filter_seq_length=max_filter_seq_length,
        groups_to_use=sorted(groups_to_use) if groups_to_use else None,
        taxid_exclusion_hash=_taxid_exclusion_hash(self.taxids_to_exclude),
        split_name_prefix=split_name_prefix,
        seed=seed,
    )

    # Strict contract: a cache is only considered present when ALL three files
    # (globals + keys + stamp) exist together. An unstamped cache (e.g. one produced by
    # codonfm_ptl_te) is treated as invalid and triggers a full rebuild.
    cache_present = cache_path.exists() and key_cache_path.exists() and stamp_path.exists()
    regenerated = (not cache_present) or force_recompute

    # Stamp validation is a pure read; every rank reads the same on-disk file and either
    # raises consistently or proceeds. Done before the decision-sync barrier so the
    # mismatch error surfaces immediately on all ranks.
    if cache_present and not force_recompute:
        with open(stamp_path, "r") as f:
            on_disk_stamp = json.load(f)
        if on_disk_stamp != current_stamp:
            raise RuntimeError(
                f"Cache parity-stamp at {stamp_path} disagrees with current arguments. "
                f"Set force_recompute=True to regenerate.\n"
                f"On-disk: {on_disk_stamp}\nCurrent: {current_stamp}"
            )

    # Synchronize the cache-presence decision before any rank writes. Without this barrier a
    # straggler rank could check existence AFTER rank 0 wrote the globals and disagree on the
    # branch, leading to a downstream barrier deadlock (one rank in regen, others in cache-hit).
    _barrier()

    if regenerated:
        if _is_rank_zero():
            # Invalidate the old stamp before touching anything else. A crash anywhere
            # between here and the stamp rewrite at the end leaves "no stamp on disk",
            # which the next run treats as cache invalid -> full rebuild.
            stamp_path.unlink(missing_ok=True)
            _log_info_rank0("Computing global indices for subsequences (this may take a while)")
            global_indices_arr, global_keys_arr = self._build_global_indices()
            np.save(cache_path, global_indices_arr)
            np.save(key_cache_path, global_keys_arr)
        # Non-rank-0 ranks (and rank 0 after writing) wait here so everyone loads the same
        # on-disk state below.
        _barrier()

    # Every rank loads as mmap views so train+val instances share kernel page cache.
    self.global_indices = np.load(cache_path, mmap_mode="r")
    self.global_keys = np.load(key_cache_path, mmap_mode="r")

    # Cascade the recompute into the split loader: whenever globals were rebuilt (explicit
    # force_recompute or cache miss), the on-disk split files may have been written against a
    # different global_indices and must be overwritten — otherwise stale splits silently pair
    # with fresh global arrays. The split loader handles its own rank-0 gating internally.
    train_indices, val_indices, test_indices = _load_train_val_test_indices_proportional(
        metadata_path=self.metadata_path,
        global_keys=self.global_keys,
        global_indices=self.global_indices,
        post_fix=post_fix,
        train_val_test_ratio=self.train_val_test_ratio,
        seed=seed,
        force_recompute=regenerated,
    )

    n_total = len(train_indices) + len(val_indices) + len(test_indices)
    if n_total != len(self.global_indices):
        raise RuntimeError(
            f"Split indices do not partition global_indices (size mismatch): "
            f"{len(train_indices)} + {len(val_indices)} + {len(test_indices)} != {len(self.global_indices)}"
        )
    # Size match alone misses overlaps balanced by gaps; verify uniqueness and range for loaded caches.
    all_indices = np.concatenate([train_indices, val_indices, test_indices])
    if len(np.unique(all_indices)) != n_total:
        raise RuntimeError("Split indices contain duplicates across train/val/test")
    if n_total > 0 and (all_indices.min() < 0 or all_indices.max() >= len(self.global_indices)):
        raise RuntimeError(
            f"Split indices out of range [0, {len(self.global_indices)}): "
            f"min={all_indices.min()}, max={all_indices.max()}"
        )

    # Stamp written only after the partition check passes, so a construction-bug failure
    # leaves no stamp on disk falsely blessing the bad cache.
    if regenerated:
        if _is_rank_zero():
            with open(stamp_path, "w") as f:
                json.dump(current_stamp, f, indent=2, sort_keys=True)
        _barrier()

    if groups_to_use:
        groups = [
            self.metadata["file_metadata"][i]["file_name"].split(".")[0]
            for i in range(len(self.metadata["file_metadata"]))
        ]
        missing = set(groups_to_use) - set(groups)
        if missing:
            raise ValueError(f"Groups not present in dataset: {sorted(missing)}. Available: {groups}")
        # Vectorised chunk-id -> group-name lookup; avoids materialising a Python list with
        # one entry per window.
        global_groups = np.asarray(groups)[self.global_indices[:, 0]]
        indices_to_use = np.where(np.isin(global_groups, groups_to_use))[0]
        train_indices = np.intersect1d(train_indices, indices_to_use)
        val_indices = np.intersect1d(val_indices, indices_to_use)
        test_indices = np.intersect1d(test_indices, indices_to_use)

    if len(train_indices) == 0 or len(val_indices) == 0:
        raise ValueError("Empty train or validation split after filtering")

    split_map = {"train": train_indices, "validation": val_indices, "test": test_indices}
    self.split_indices = split_map[split]

    _log_info_rank0(
        "CodonMemmapDataset(split=%s) ready: %d windows in this split / %d total",
        split,
        len(self.split_indices),
        len(self.global_indices),
    )


def _build_global_indices(self) -> tuple[np.ndarray, np.ndarray]:
    """Build the full per-window index from the chunk memmaps.

    Port of codonfm_ptl_te.src.data.codon_memmap_dataset.CodonMemmapDataset's
    inner loop, including filtering by taxid / min_seq_length / max_filter_seq_length and
    the floor-division windowing (which can drop a partial tail window vs. the existing
    MemmapCodonDataset's ceiling-division windowing).

    Returns:
        (global_indices, global_keys) — int64 arrays of shape (num_windows, 3) and (num_windows,).
        Each global_indices row is (chunk_id, start_token_idx, end_token_idx); the matching
        global_keys entry is the cluster id of the sequence the window was cut from.

    Raises:
        ValueError: If the window stride is not positive.
        RuntimeError: If allSeqClusterIdx.npy does not have one entry per sequence.
    """
    window_len = self.max_seq_length - self.tok_adjust
    step_size = window_len - self.context_overlap
    if step_size <= 0:
        # A zero stride would divide by zero below; a negative one would silently collapse
        # every sequence to a single window.
        raise ValueError(
            "max_seq_length - reserved tokens - context_overlap must be positive, got "
            f"{self.max_seq_length} - {self.tok_adjust} - {self.context_overlap} = {step_size}"
        )

    global_indices_list = []
    global_keys = []
    total_sequences = sum(len(idx_mmap) for idx_mmap in self.indices_mmaps)

    seq_cluster_path = self.data_path / "allSeqClusterIdx.npy"
    if seq_cluster_path.exists():
        # NOTE(review): allow_pickle=True executes pickled payloads on load; only safe if the
        # dataset directory is trusted. Confirm whether the file is ever an object array
        # before tightening this.
        seq_clusters = np.load(seq_cluster_path, allow_pickle=True)
        if seq_clusters.shape[0] != total_sequences:
            raise RuntimeError(
                f"Mismatch: allSeqClusterIdx.npy has {seq_clusters.shape[0]} entries, "
                f"but total_sequences across chunks is {total_sequences}"
            )
    else:
        # Without a cluster file every sequence is its own cluster; the proportional
        # sampler then degrades to per-sequence sampling.
        seq_clusters = np.arange(total_sequences, dtype=int)

    global_seq_idx = 0
    with tqdm(total=total_sequences, desc="Processing sequences") as pbar:
        for chunk_id, idx_mmap in enumerate(self.indices_mmaps):
            for seq_idx in range(len(idx_mmap)):
                # The cluster cursor advances for every sequence, including skipped ones,
                # so it stays aligned with the flat order of allSeqClusterIdx.npy.
                seq_cluster_idx = seq_clusters[global_seq_idx]
                global_seq_idx += 1
                seq_start, seq_end, taxid = idx_mmap[seq_idx]
                if self.taxids_to_exclude and taxid in self.taxids_to_exclude:
                    pbar.update(1)
                    continue
                seq_len_tokens = seq_end - seq_start
                if seq_len_tokens < self.min_seq_length or seq_len_tokens > self.max_filter_seq_length:
                    pbar.update(1)
                    continue
                num_subsequences = max(1, (seq_len_tokens - self.context_overlap) // step_size)
                for sub_seq_idx in range(num_subsequences):
                    start_token_idx = seq_start + sub_seq_idx * step_size
                    end_token_idx = min(start_token_idx + window_len, seq_end)
                    if end_token_idx > start_token_idx:
                        global_indices_list.append([chunk_id, start_token_idx, end_token_idx])
                        global_keys.append(seq_cluster_idx)
                pbar.update(1)

    return (
        # reshape keeps the documented (num_windows, 3) shape even when no window survived.
        np.array(global_indices_list, dtype=np.int64).reshape(-1, 3),
        np.array(global_keys, dtype=np.int64),
    )


def __len__(self) -> int:
    """Return the number of windows in the selected split."""
    return len(self.split_indices)


def __getitem__(self, idx: int) -> dict[str, np.ndarray]:
    """Return the token ids of the ``idx``-th window of the selected split.

    The slice is copied out of the read-only memmap into a regular array under the
    ``"sequence_tokens"`` key.
    """
    global_idx = self.split_indices[idx]
    chunk_id, start, end = self.global_indices[global_idx]
    return {"sequence_tokens": np.array(self.sequences_mmaps[chunk_id][start:end])}
import DistributedConfig from tokenizer import CodonTokenizer from torch.utils.data import DataLoader, Dataset, DistributedSampler @@ -162,6 +163,11 @@ def __len__(self) -> int: # noqa: D105 def __getitem__(self, idx: int) -> dict[str, str]: # noqa: D105 chunk_id, start, end = self.global_indices[idx] token_ids = self.sequences_mmaps[chunk_id][start:end] + # Note: decode(skip_special_tokens=True) silently drops tokens (ID 2). The codon + # tokenizer's tokenize() is strict 3-char chunking that cannot reparse the "" + # literal in a decoded string, so any window containing ambiguous-base codons loses + # those positions when round-tripped. Use CodonMemmapDataset (returns sequence_tokens + # directly) for PTL-parity behavior. sequence = self.tokenizer.decode(token_ids.tolist(), skip_special_tokens=True) return {"sequence": sequence} @@ -203,7 +209,10 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: all_labels = [] for sample in batch: - ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True) + if "sequence_tokens" in sample: + ids = [self.tokenizer.cls_token_id, *sample["sequence_tokens"].tolist(), self.tokenizer.sep_token_id] + else: + ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True) # Truncate to max_seq_length, preserving trailing SEP token if len(ids) > self.max_seq_length: ids = [*ids[: self.max_seq_length - 1], self.tokenizer.sep_token_id] @@ -281,7 +290,10 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: seq_lengths = [] for sample in batch: - ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True) + if "sequence_tokens" in sample: + ids = [self.tokenizer.cls_token_id, *sample["sequence_tokens"].tolist(), self.tokenizer.sep_token_id] + else: + ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True) # Truncate to max_seq_length, preserving trailing SEP token if len(ids) > self.max_seq_length: ids = [*ids[: self.max_seq_length - 
1], self.tokenizer.sep_token_id] @@ -344,13 +356,26 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: } -def _create_dataset(data_path: str, max_seq_length: int, seed: int) -> Dataset: +def _create_dataset( + data_path: str, + max_seq_length: int, + seed: int, + split: str | None = None, + split_kwargs: dict | None = None, +) -> Dataset: """Create the appropriate dataset based on data_path format. Args: data_path: 'synthetic', path to a parquet file, or path to a memmap directory. max_seq_length: Maximum sequence length (used for memmap sliding windows). seed: Random seed. + split: If set ("train" / "validation" / "test"), construct the split-aware + CodonMemmapDataset (port of the PTL dataset) instead of MemmapCodonDataset. + Only meaningful when data_path is a memmap directory; ignored otherwise. + split_kwargs: Extra keyword arguments forwarded to CodonMemmapDataset + (train_val_test_ratio, context_overlap, pretraining_task, min_seq_length, + max_filter_seq_length, groups_to_use, taxid_exclusion_file, split_name_prefix, + force_recompute). Only used when split is set. Returns: A Dataset instance. @@ -359,6 +384,14 @@ def _create_dataset(data_path: str, max_seq_length: int, seed: int) -> Dataset: return SyntheticCodonDataset(num_samples=500, seed=seed) data_dir = Path(data_path) if data_dir.is_dir() and (data_dir / "metadata.json").exists(): + if split is not None: + return CodonMemmapDataset( + data_path, + split=split, + max_seq_length=max_seq_length, + seed=seed, + **(split_kwargs or {}), + ) return MemmapCodonDataset(data_path, max_seq_length=max_seq_length) return ParquetCodonDataset(data_path) @@ -372,6 +405,8 @@ def create_bshd_dataloader( num_workers: int = 1, seed: int = 42, pad_to_multiple_of: int | None = None, + split: str | None = None, + split_kwargs: dict | None = None, ) -> tuple[DataLoader, DistributedSampler]: """Create a BSHD-format dataloader. 
@@ -384,25 +419,30 @@ def create_bshd_dataloader( num_workers: Number of dataloader workers. seed: Random seed. pad_to_multiple_of: Unused in BSHD mode (only applies to THD). + split: If set, use the split-aware CodonMemmapDataset for memmap dirs. + split_kwargs: Extra arguments forwarded to CodonMemmapDataset when split is set. Returns: Tuple of (DataLoader, DistributedSampler). """ tokenizer = CodonTokenizer() - dataset = _create_dataset(data_path, max_seq_length, seed) + dataset = _create_dataset(data_path, max_seq_length, seed, split=split, split_kwargs=split_kwargs) + sampler_kwargs = {"shuffle": False} if split == "validation" else {} sampler = DistributedSampler( dataset, rank=dist_config.rank, num_replicas=dist_config.world_size, seed=seed, + **sampler_kwargs, ) collator = CodonMLMCollator( tokenizer=tokenizer, max_seq_length=max_seq_length, mlm_probability=mlm_probability, + seed=seed, ) dataloader = DataLoader( @@ -426,6 +466,8 @@ def create_thd_dataloader( num_workers: int = 1, seed: int = 42, pad_to_multiple_of: int | None = None, + split: str | None = None, + split_kwargs: dict | None = None, ) -> tuple[DataLoader, DistributedSampler]: """Create a THD-format (packed sequence) dataloader. @@ -440,6 +482,8 @@ def create_thd_dataloader( pad_to_multiple_of: If set, pad total tokens to a multiple of this value. If None, defaults to micro_batch_size * max_seq_length for consistent tensor shapes (matching ESM2's approach). Set to 0 to disable padding. + split: If set, use the split-aware CodonMemmapDataset for memmap dirs. + split_kwargs: Extra arguments forwarded to CodonMemmapDataset when split is set. Returns: Tuple of (DataLoader, DistributedSampler). 
def create_dataloaders(
    dist_config: DistributedConfig,
    *,
    use_sequence_packing: bool,
    build_validation: bool,
    use_split_dataset: bool = True,
    split_kwargs: dict | None = None,
    **factory_kwargs,
) -> tuple[DataLoader, DataLoader | None, DistributedSampler]:
    """Create the train dataloader and, on request, a validation dataloader.

    Both loaders come from the same low-level factory (THD when sequence packing is on,
    BSHD otherwise) and the same ``factory_kwargs``. With ``use_split_dataset`` the factory
    is asked for the ``"train"`` and ``"validation"`` splits respectively; without it both
    calls pass ``split=None``, so the validation loader is built from the same data as the
    train loader (placeholder behavior).

    ``force_recompute`` in ``split_kwargs`` only reaches the train call. The validation call
    always receives ``force_recompute=False`` so it reuses the cache the train call just
    wrote. The caller's ``split_kwargs`` dict is never mutated.

    Args:
        dist_config: Distributed configuration, passed through to the factory.
        use_sequence_packing: True selects ``create_thd_dataloader``, False selects
            ``create_bshd_dataloader``.
        build_validation: When False, no validation loader is built and None is returned
            in its place.
        use_split_dataset: When True (default), request split-aware datasets from the factory.
        split_kwargs: Extra keyword arguments forwarded to the factory as ``split_kwargs``.
        **factory_kwargs: Remaining keyword arguments forwarded unchanged to the factory
            (data_path, micro_batch_size, max_seq_length, mlm_probability, num_workers, seed,
            pad_to_multiple_of).

    Returns:
        Tuple of (train_dataloader, val_dataloader or None, train DistributedSampler).
    """
    if use_sequence_packing:
        build_loader = create_thd_dataloader
    else:
        build_loader = create_bshd_dataloader

    if use_split_dataset:
        train_split, val_split = "train", "validation"
    else:
        train_split = val_split = None

    train_dataloader, train_sampler = build_loader(
        dist_config,
        split=train_split,
        split_kwargs=split_kwargs,
        **factory_kwargs,
    )

    if not build_validation:
        return train_dataloader, None, train_sampler

    # Work on a copy so the caller's dict keeps its original force_recompute value.
    if split_kwargs is None:
        val_split_kwargs = None
    else:
        val_split_kwargs = dict(split_kwargs)
        val_split_kwargs["force_recompute"] = False

    val_dataloader, _ = build_loader(
        dist_config,
        split=val_split,
        split_kwargs=val_split_kwargs,
        **factory_kwargs,
    )
    return train_dataloader, val_dataloader, train_sampler
+ use_split_dataset: true + split_kwargs: + train_val_test_ratio: [0.9998, 0.0002, 0.0] + context_overlap: 0 + pretraining_task: mlm + min_seq_length: 100 + max_filter_seq_length: 150_000 + groups_to_use: null + taxid_exclusion_file: null + split_name_prefix: "" + force_recompute: false # WandB config wandb_init_args: @@ -35,7 +47,7 @@ fp4_config: adamw_kwargs: lr: 4e-4 fused: true - betas: [0.9, 0.98] + betas: [0.9, 0.999] eps: 1e-8 weight_decay: 0.01 @@ -55,6 +67,11 @@ checkpoint: logger: frequency: 100 +validation: + enabled: false + eval_interval: 500 + num_batches: 10 + quant_stats_config: enabled: false quant_stats_file: ./fp8_debugging_stats.yaml diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml index 02884ae8ed..d00b85302f 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml @@ -11,7 +11,7 @@ dataset: data_path: ??? micro_batch_size: 4 num_workers: 1 - max_seq_length: 512 + max_seq_length: 2048 # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml index 2477353300..10053d0571 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml @@ -11,7 +11,7 @@ dataset: data_path: ??? 
def log_validation(self, step: int, val_metrics: dict):
    """Report validation metrics to wandb and the module logger from the main process.

    Ranks for which ``self._dist_config.is_main_process()`` is false return without
    logging anything.

    Args:
        step: The current optimizer step; used as the wandb ``step``.
        val_metrics: Dict of metric name -> scalar value (already reduced across ranks).
            Each name is logged to wandb under a ``val/`` prefix.
    """
    if self._dist_config.is_main_process():
        prefixed = {}
        for name, value in val_metrics.items():
            prefixed[f"val/{name}"] = value
        wandb.log(prefixed, step=step)

        # Format the summary only after the wandb call, matching the original ordering.
        summary = ", ".join([f"{name}: {value:.4g}" for name, value in val_metrics.items()])
        logger.info("[VAL step=%d] %s", step, summary)
fp8_config.fp8_format=${FP8_FORMAT} + fp8_config.fp8_format=${FP8_FORMAT} \ + dataset.pad_to_multiple_of=32 diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py b/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py index 9acc153bbb..b0afe4fe7c 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py @@ -17,13 +17,14 @@ import logging from contextlib import nullcontext +from datetime import timedelta from pathlib import Path import hydra import nvdlfw_inspect.api as debug_api import torch from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint -from dataset import create_bshd_dataloader, create_thd_dataloader +from dataset import create_dataloaders from distributed_config import DistributedConfig from modeling_codonfm_te import MODEL_PRESETS, CodonFMConfig, CodonFMForMaskedLM from omegaconf import DictConfig, OmegaConf @@ -52,7 +53,7 @@ def main(args: DictConfig) -> float | None: dist_config = DistributedConfig() logger.info("Initializing distributed training: %s", dist_config) device = torch.device(f"cuda:{dist_config.local_rank}") - torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.distributed.init_process_group(backend="nccl", device_id=device, timeout=timedelta(hours=1)) torch.cuda.set_device(dist_config.local_rank) # DDP keeps a single param dtype per replica, so it can't emulate FSDP2's @@ -145,12 +146,16 @@ def main(args: DictConfig) -> float | None: optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) - # Create dataloader dataloader_kwargs = OmegaConf.to_container(args.dataset, resolve=True) - train_dataloader, sampler = ( - create_thd_dataloader(dist_config, **dataloader_kwargs) - if args.use_sequence_packing - else create_bshd_dataloader(dist_config, 
**dataloader_kwargs) + use_split_dataset = dataloader_kwargs.pop("use_split_dataset", False) + split_kwargs = dataloader_kwargs.pop("split_kwargs", None) + train_dataloader, val_dataloader, sampler = create_dataloaders( + dist_config, + use_sequence_packing=args.use_sequence_packing, + build_validation=args.validation.enabled, + use_split_dataset=use_split_dataset, + split_kwargs=split_kwargs, + **dataloader_kwargs, ) # Resume from checkpoint if available @@ -225,6 +230,29 @@ def main(args: DictConfig) -> float | None: max_checkpoints=args.checkpoint.max_checkpoints, ) + if val_dataloader is not None and step > 0 and step % args.validation.eval_interval == 0: + model.eval() + val_loss_sum = torch.zeros((), device=device) + val_batches_seen = torch.zeros((), device=device) + val_iter = iter(val_dataloader) + with torch.no_grad(): + for _ in range(args.validation.num_batches): + try: + val_batch = next(val_iter) + except StopIteration: + break + val_batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in val_batch.items() + } + val_outputs = model(**val_batch) + val_loss_sum += val_outputs.loss.detach() + val_batches_seen += 1 + torch.distributed.all_reduce(val_loss_sum) + torch.distributed.all_reduce(val_batches_seen) + avg_val_loss = (val_loss_sum / val_batches_seen.clamp(min=1)).item() + perf_logger.log_validation(step, {"loss": avg_val_loss}) + model.train() + step += 1 if step >= args.num_train_steps: break diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py index 8b07f8954e..ab2d945fbc 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py @@ -23,7 +23,7 @@ import nvdlfw_inspect.api as debug_api import torch from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint -from dataset import create_bshd_dataloader, create_thd_dataloader 
+from dataset import create_dataloaders from distributed_config import DistributedConfig from modeling_codonfm_te import MODEL_PRESETS, CodonFMConfig, CodonFMForMaskedLM from omegaconf import DictConfig, OmegaConf @@ -141,12 +141,16 @@ def main(args: DictConfig) -> float | None: optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) - # Create dataloader dataloader_kwargs = OmegaConf.to_container(args.dataset, resolve=True) - train_dataloader, sampler = ( - create_thd_dataloader(dist_config, **dataloader_kwargs) - if args.use_sequence_packing - else create_bshd_dataloader(dist_config, **dataloader_kwargs) + use_split_dataset = dataloader_kwargs.pop("use_split_dataset", False) + split_kwargs = dataloader_kwargs.pop("split_kwargs", None) + train_dataloader, val_dataloader, sampler = create_dataloaders( + dist_config, + use_sequence_packing=args.use_sequence_packing, + build_validation=args.validation.enabled, + use_split_dataset=use_split_dataset, + split_kwargs=split_kwargs, + **dataloader_kwargs, ) # Resume from checkpoint if available @@ -215,6 +219,29 @@ def main(args: DictConfig) -> float | None: max_checkpoints=args.checkpoint.max_checkpoints, ) + if val_dataloader is not None and step > 0 and step % args.validation.eval_interval == 0: + model.eval() + val_loss_sum = torch.zeros((), device=device) + val_batches_seen = torch.zeros((), device=device) + val_iter = iter(val_dataloader) + with torch.no_grad(): + for _ in range(args.validation.num_batches): + try: + val_batch = next(val_iter) + except StopIteration: + break + val_batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in val_batch.items() + } + val_outputs = model(**val_batch) + val_loss_sum += val_outputs.loss.detach() + val_batches_seen += 1 + torch.distributed.all_reduce(val_loss_sum) + torch.distributed.all_reduce(val_batches_seen) + avg_val_loss = 
(val_loss_sum / val_batches_seen.clamp(min=1)).item() + perf_logger.log_validation(step, {"loss": avg_val_loss}) + model.train() + step += 1 if step >= args.num_train_steps: break From 45910f4cd7367ae1887a88c556d040952986dcc0 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Fri, 15 May 2026 19:22:37 +0000 Subject: [PATCH 03/17] Added support for different precision modes --- .../hydra_config/L0_sanity.yaml | 2 + .../hydra_config/defaults.yaml | 13 +- .../hydra_config/encodon_1b.yaml | 2 + .../hydra_config/encodon_5b.yaml | 2 + .../recipes/codonfm_native_te/run_1b.sh | 17 +- .../recipes/codonfm_native_te/slurm/1b.sh | 195 ++++++++++++++++++ .../codonfm_native_te/tests/test_train.py | 22 +- .../recipes/codonfm_native_te/train_ddp.py | 41 ++-- .../recipes/codonfm_native_te/train_fsdp2.py | 61 +++++- 9 files changed, 325 insertions(+), 30 deletions(-) create mode 100755 bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity.yaml index 75ab6ddcc6..52a0046df3 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity.yaml @@ -6,6 +6,8 @@ defaults: model_preset: encodon_200k num_train_steps: 250 +precision: fp32 + use_sequence_packing: false dataset: data_path: train.parquet diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml index 3f1542e4f0..16333758a2 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml @@ -81,4 +81,15 @@ quant_stats_config: # Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. fp8_layers: null fp4_layers: null -use_fp32_master_weights: null + +# Precision mode. 
One of: +# fp32 - params, compute, grads, and optimizer state all in fp32. +# bf16 - params, compute, grads, and optimizer state all in bf16 (pure bf16). +# bf16-mixed - fp32 master weights + bf16 compute (via autocast in DDP, via FSDP2 +# MixedPrecisionPolicy.param_dtype=bf16 in FSDP2). +precision: ??? + +# Gradient reduce dtype for FSDP2 when precision=bf16-mixed. One of: fp32, bf16. +# fp32 (default) is more conservative than PTL FSDP bf16-mixed (which reduces in bf16). +# Ignored for other precision modes and for DDP. +grad_reduce_type: fp32 diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml index d00b85302f..230e773658 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml @@ -6,6 +6,8 @@ defaults: model_preset: encodon_1b num_train_steps: 500_000 +precision: bf16-mixed + use_sequence_packing: true dataset: data_path: ??? diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml index 10053d0571..951dd6df2c 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml @@ -6,6 +6,8 @@ defaults: model_preset: encodon_5b num_train_steps: 500_000 +precision: bf16-mixed + use_sequence_packing: true dataset: data_path: ??? 
diff --git a/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh b/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh index 83338f3f97..11172df035 100755 --- a/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh +++ b/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh @@ -14,7 +14,10 @@ export NUM_TRAIN_STEPS=100 export MICRO_BATCH_SIZE=31 export NUM_WORKERS=1 export USE_SEQUENCE_PACKING=True -export USE_FP32_MASTER_WEIGHTS=True +# Precision mode: one of fp32, bf16, bf16-mixed. bf16-mixed matches the reference codonfm `--bf16`. +export PRECISION=bf16-mixed +# Only used for FSDP2 + bf16-mixed. One of fp32, bf16. +export GRAD_REDUCE_TYPE=fp32 export NUM_WARMUP_STEPS=500 # Logging / W&B @@ -46,24 +49,19 @@ if [ "${FP8_ENABLED}" = "True" ]; then RECIPE_SHORT="${FP8_RECIPE##*.}" RECIPE_SHORT="${RECIPE_SHORT%BlockScaling}" RECIPE_SHORT="${RECIPE_SHORT%Scaling}" - PRECISION_TAG="${RECIPE_SHORT,,}_${FP8_FORMAT,,}" + PRECISION_TAG="${PRECISION}_${RECIPE_SHORT,,}_${FP8_FORMAT,,}" else - PRECISION_TAG="bf16" + PRECISION_TAG="${PRECISION}" fi export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}" # Pick training script based on distributed strategy. -# DDP can't emulate FSDP's fp32-master / bf16-param split, so force fp32 master weights off. 
case "${DIST_STRATEGY}" in fsdp) TRAIN_SCRIPT=train_fsdp2.py ;; ddp) TRAIN_SCRIPT=train_ddp.py - if [ "${USE_FP32_MASTER_WEIGHTS}" = "True" ]; then - echo "DIST_STRATEGY=ddp: overriding USE_FP32_MASTER_WEIGHTS=True -> False" >&2 - export USE_FP32_MASTER_WEIGHTS=False - fi ;; *) echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2 @@ -80,7 +78,8 @@ torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \ dataset.num_workers=${NUM_WORKERS} \ dataset.data_path=${DATASET_DATA_PATH} \ use_sequence_packing=${USE_SEQUENCE_PACKING} \ - use_fp32_master_weights=${USE_FP32_MASTER_WEIGHTS} \ + precision=${PRECISION} \ + grad_reduce_type=${GRAD_REDUCE_TYPE} \ lr_scheduler_kwargs.num_warmup_steps=${NUM_WARMUP_STEPS} \ wandb_init_args.name=${WANDB_RUN_NAME} \ wandb_init_args.project=${WANDB_PROJECT} \ diff --git a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh b/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh new file mode 100755 index 0000000000..69189ec452 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh @@ -0,0 +1,195 @@ +#!/bin/bash +#SBATCH --account= +#SBATCH --nodes=1 +#SBATCH --partition= +#SBATCH --ntasks-per-node=1 +#SBATCH --time=03:55:00 +#SBATCH --mem=0 +#SBATCH --job-name= +#SBATCH --mail-type=FAIL +#SBATCH --overcommit +#SBATCH --exclusive +set -euxo pipefail + +# ============================================================================ +# Codon 1B +# ============================================================================ + +BASE_DIR="" +CONTAINER="" +DATA_DIR="${BASE_DIR}/data" +CODE_MOUNT="/workspace/bionemo" + + +: "${WANDB_API_KEY:?Set WANDB_API_KEY in ~/.bash_profile}" +: "${HUGGING_FACE_HUB_TOKEN:?Set HUGGING_FACE_HUB_TOKEN in ~/.bash_profile}" +: "${CLUSTER_NAME:?Set CLUSTER_NAME in ~/.bash_profile}" + +# Experiment parameters +export CONFIG_NAME=encodon_1b +export NPROC_PER_NODE=8 +export DIST_STRATEGY=ddp # fsdp or ddp + +# Training +export NUM_TRAIN_STEPS=1000 +export 
MICRO_BATCH_SIZE=31 +export LEARNING_RATE=7.5e-5 +export NUM_WORKERS=1 +export USE_SEQUENCE_PACKING=False +# Precision mode: one of fp32, bf16, bf16-mixed. bf16-mixed matches the reference codonfm `--bf16`. +export PRECISION=bf16-mixed +# Only used for FSDP2 + bf16-mixed. One of fp32, bf16. +export GRAD_REDUCE_TYPE=fp32 +export NUM_WARMUP_STEPS=50 + +# Logging / W&B +export LOGGER_FREQUENCY=10 +export WANDB_PROJECT= + +# Checkpointing +export SAVE_FINAL_MODEL=True +export SAVE_EVERY_N_STEPS=100000 +export RESUME_FROM_CHECKPOINT=True + +# Hydra +export HYDRA_RUN_DIR=1b_test + +# Quantization / FP8 +export QUANT_STATS_ENABLED=False +export FP8_ENABLED=False +export FP8_RECIPE=transformer_engine.common.recipe.MXFP8BlockScaling +export FP8_FORMAT=E4M3 + +# Derived: build wandb run name from model size, batch size, and precision recipe +MODEL_SIZE="${CONFIG_NAME##*_}" +if [ "${FP8_ENABLED}" = "True" ]; then + RECIPE_SHORT="${FP8_RECIPE##*.}" + RECIPE_SHORT="${RECIPE_SHORT%BlockScaling}" + RECIPE_SHORT="${RECIPE_SHORT%Scaling}" + PRECISION_TAG="${PRECISION}_${RECIPE_SHORT,,}_${FP8_FORMAT,,}" +else + PRECISION_TAG="${PRECISION}" +fi + +if [ "${USE_SEQUENCE_PACKING}" = "True" ]; then + BATCH_TYPE_TAG="thd" +else + BATCH_TYPE_TAG="bshd" +fi + +export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}" + +# Mounts +RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}" +CKPT_DIR="${BASE_DIR}/checkpoints/${WANDB_RUN_NAME}" + +mkdir -p "${RESULTS_DIR}" "${CKPT_DIR}" + +MOUNTS="${DATA_DIR}:${CODE_MOUNT}/data,${RESULTS_DIR}:${CODE_MOUNT}/results,${CKPT_DIR}:${CODE_MOUNT}/checkpoints" + + +read -r -d '' COMMAND <<'OUTER_EOF' || true +set -euxo pipefail + +echo "=========================================" +echo "CodonFM ${CONFIG_NAME} - STRATEGY: ${DIST_STRATEGY} - PRECISION: ${PRECISION_TAG} - CLUSTER: ${CLUSTER_NAME}" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Nodes: ${SLURM_JOB_NUM_NODES}" 
+echo "=========================================" + +# Pick training script based on distributed strategy. +case "${DIST_STRATEGY}" in + fsdp) + TRAIN_SCRIPT=train_fsdp2.py + ;; + ddp) + TRAIN_SCRIPT=train_ddp.py + ;; + *) + echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2 + exit 1 + ;; +esac + +torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \ + --config-name ${CONFIG_NAME} \ + quant_stats_config.enabled=${QUANT_STATS_ENABLED} \ + logger.frequency=${LOGGER_FREQUENCY} \ + num_train_steps=${NUM_TRAIN_STEPS} \ + dataset.micro_batch_size=${MICRO_BATCH_SIZE} \ + adamw_kwargs.lr=${LEARNING_RATE} \ + dataset.num_workers=${NUM_WORKERS} \ + dataset.data_path=/workspace/bionemo/data/processed_unfiltered/ \ + use_sequence_packing=${USE_SEQUENCE_PACKING} \ + precision=${PRECISION} \ + grad_reduce_type=${GRAD_REDUCE_TYPE} \ + lr_scheduler_kwargs.num_warmup_steps=${NUM_WARMUP_STEPS} \ + wandb_init_args.name=${WANDB_RUN_NAME} \ + +wandb_init_args.id=${WANDB_RUN_NAME} \ + +wandb_init_args.project=${WANDB_PROJECT} \ + checkpoint.save_final_model=${SAVE_FINAL_MODEL} \ + checkpoint.save_every_n_steps=${SAVE_EVERY_N_STEPS} \ + checkpoint.ckpt_dir=/workspace/bionemo/checkpoints \ + checkpoint.resume_from_checkpoint=${RESUME_FROM_CHECKPOINT} \ + hydra.run.dir=${HYDRA_RUN_DIR} \ + fp8_config.enabled=${FP8_ENABLED} \ + fp8_config.fp8_recipe=${FP8_RECIPE} \ + fp8_config.fp8_format=${FP8_FORMAT} \ + +dataset.pad_to_multiple_of=32 + +echo "=========================================" +echo "Training complete!" +echo "=========================================" +OUTER_EOF + +# Inject environment variables into the command. 
+COMMAND="export DIST_STRATEGY=\"${DIST_STRATEGY}\"; ${COMMAND}" +COMMAND="export PRECISION_TAG=\"${PRECISION_TAG}\"; ${COMMAND}" +COMMAND="export CLUSTER_NAME=\"${CLUSTER_NAME}\"; ${COMMAND}" +COMMAND="export NPROC_PER_NODE=\"${NPROC_PER_NODE}\"; ${COMMAND}" +COMMAND="export CONFIG_NAME=\"${CONFIG_NAME}\"; ${COMMAND}" +COMMAND="export QUANT_STATS_ENABLED=\"${QUANT_STATS_ENABLED}\"; ${COMMAND}" +COMMAND="export LOGGER_FREQUENCY=\"${LOGGER_FREQUENCY}\"; ${COMMAND}" +COMMAND="export NUM_TRAIN_STEPS=\"${NUM_TRAIN_STEPS}\"; ${COMMAND}" +COMMAND="export MICRO_BATCH_SIZE=\"${MICRO_BATCH_SIZE}\"; ${COMMAND}" +COMMAND="export LEARNING_RATE=\"${LEARNING_RATE}\"; ${COMMAND}" +COMMAND="export NUM_WORKERS=\"${NUM_WORKERS}\"; ${COMMAND}" +COMMAND="export USE_SEQUENCE_PACKING=\"${USE_SEQUENCE_PACKING}\"; ${COMMAND}" +COMMAND="export PRECISION=\"${PRECISION}\"; ${COMMAND}" +COMMAND="export GRAD_REDUCE_TYPE=\"${GRAD_REDUCE_TYPE}\"; ${COMMAND}" +COMMAND="export NUM_WARMUP_STEPS=\"${NUM_WARMUP_STEPS}\"; ${COMMAND}" +COMMAND="export WANDB_RUN_NAME=\"${WANDB_RUN_NAME}\"; ${COMMAND}" +COMMAND="export WANDB_PROJECT=\"${WANDB_PROJECT}\"; ${COMMAND}" +COMMAND="export SAVE_FINAL_MODEL=\"${SAVE_FINAL_MODEL}\"; ${COMMAND}" +COMMAND="export SAVE_EVERY_N_STEPS=\"${SAVE_EVERY_N_STEPS}\"; ${COMMAND}" +COMMAND="export RESUME_FROM_CHECKPOINT=\"${RESUME_FROM_CHECKPOINT}\"; ${COMMAND}" +COMMAND="export HYDRA_RUN_DIR=\"${HYDRA_RUN_DIR}\"; ${COMMAND}" +COMMAND="export FP8_ENABLED=\"${FP8_ENABLED}\"; ${COMMAND}" +COMMAND="export FP8_RECIPE=\"${FP8_RECIPE}\"; ${COMMAND}" +COMMAND="export FP8_FORMAT=\"${FP8_FORMAT}\"; ${COMMAND}" + +COMMAND="export WANDB_API_KEY=\"${WANDB_API_KEY}\"; ${COMMAND}" +COMMAND="export HUGGING_FACE_HUB_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}" +COMMAND="export HF_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}" + +echo "Launching: ${WANDB_RUN_NAME}" + +# AUTO-CHAIN: resubmit on timeout. +trap ' + rc=$? 
+ if [ "$rc" -eq 143 ] || [ "$rc" -eq 137 ]; then + echo "Killed by signal (rc=$rc) — assuming SLURM timeout, resubmitting..." + sbatch --dependency=singleton "${BASH_SOURCE[0]}" + elif [ "$rc" -eq 0 ]; then + echo "Clean exit — training finished, NOT resubmitting." + else + echo "Error exit (rc=$rc) — NOT resubmitting; investigate ${RESULTS_DIR}" + fi + ' EXIT + +srun \ + --output "${RESULTS_DIR}/slurm-%j-%n.out" \ + --error "${RESULTS_DIR}/error-%j-%n.out" \ + --container-image "${CONTAINER}" \ + --container-mounts "${MOUNTS}" \ + bash -c "${COMMAND}" diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_train.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_train.py index 45cf6501a2..45118458d1 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_train.py @@ -130,15 +130,31 @@ def test_sanity_convergence_fsdp2_fp8(tmp_path, recipe_path): assert final_loss < 5.0, f"Final loss {final_loss} is too high" -def test_sanity_convergence_fsdp2_fp32_master_weights(tmp_path, recipe_path): - """Test CodonFM with FP32 master weights.""" +def test_sanity_convergence_fsdp2_bf16_mixed(tmp_path, recipe_path): + """Test CodonFM with bf16-mixed precision (fp32 master weights + bf16 compute).""" with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): sanity_config = compose( config_name="L0_sanity", overrides=[ f"+wandb_init_args.dir={tmp_path}", f"checkpoint.ckpt_dir={tmp_path}", - "use_fp32_master_weights=true", + "precision=bf16-mixed", + ], + ) + + final_loss = main_fsdp2(sanity_config) + assert final_loss < 5.0, f"Final loss {final_loss} is too high" + + +def test_sanity_convergence_fsdp2_bf16(tmp_path, recipe_path): + """Test CodonFM with pure bf16 precision.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + 
f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "precision=bf16", ], ) diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py b/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py index b0afe4fe7c..0f970984f2 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py @@ -40,6 +40,20 @@ logger.setLevel(logging.INFO) +_VALID_PRECISIONS = ("fp32", "bf16", "bf16-mixed") + + +def precision_context(precision: str): + """Return a fresh autocast context for the given precision mode. + + For `bf16-mixed`, wraps forward in `torch.autocast(cuda, bf16)`. For `fp32` and `bf16`, + returns a nullcontext — params are already in the target dtype, no autocast needed. + """ + if precision == "bf16-mixed": + return torch.autocast(device_type="cuda", dtype=torch.bfloat16) + return nullcontext() + + @hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") def main(args: DictConfig) -> float | None: """Train CodonFM with TE layers using DDP. @@ -49,6 +63,9 @@ def main(args: DictConfig) -> float | None: """ logging.getLogger("httpx").setLevel(logging.WARNING) + if args.precision not in _VALID_PRECISIONS: + raise ValueError(f"precision must be one of {_VALID_PRECISIONS}, got {args.precision!r}") + # Initialize distributed configuration dist_config = DistributedConfig() logger.info("Initializing distributed training: %s", dist_config) @@ -56,11 +73,6 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device, timeout=timedelta(hours=1)) torch.cuda.set_device(dist_config.local_rank) - # DDP keeps a single param dtype per replica, so it can't emulate FSDP2's - # MixedPrecisionPolicy(param_dtype=bf16, reduce_dtype=fp32) split. Reject up-front. - if args.use_fp32_master_weights: - raise ValueError("FP32 master weights are not supported with DDP. 
Use train_fsdp2.py instead.") - perf_logger = None try: # Mirrors the FSDP2 device mesh — not strictly required for DDP, but keeps configs symmetric. @@ -127,9 +139,8 @@ def main(args: DictConfig) -> float | None: else: model = model.to(device) - # DDP replicates the full model on each GPU. Cast params to bf16 since the optimizer - # update happens in the same dtype as the params (no FP32 master weights here). - model = model.to(dtype=torch.bfloat16) + if args.precision == "bf16": + model = model.to(dtype=torch.bfloat16) # Assign layer names for debug API if args.quant_stats_config.enabled: @@ -190,11 +201,12 @@ def main(args: DictConfig) -> float | None: sync_context = nullcontext() if is_accumulation_boundary else model.no_sync() with sync_context: - # Forward pass - outputs = model(**batch) - - # Backward pass - scale loss by grad_acc_steps for proper gradient averaging - loss = outputs.loss / args.grad_acc_steps + # Forward pass under the precision-specific autocast context. + # backward inherits the cached autocast state — no need to wrap it. 
+ with precision_context(args.precision): + outputs = model(**batch) + # Scale loss by grad_acc_steps for proper gradient averaging + loss = outputs.loss / args.grad_acc_steps loss.backward() # Log micro-batch data for accumulation metrics @@ -244,7 +256,8 @@ def main(args: DictConfig) -> float | None: val_batch = { k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in val_batch.items() } - val_outputs = model(**val_batch) + with precision_context(args.precision): + val_outputs = model(**val_batch) val_loss_sum += val_outputs.loss.detach() val_batches_seen += 1 torch.distributed.all_reduce(val_loss_sum) diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py index ab2d945fbc..340b453a51 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py @@ -40,6 +40,39 @@ logger.setLevel(logging.INFO) +_VALID_PRECISIONS = ("fp32", "bf16", "bf16-mixed") +_VALID_REDUCE_TYPES = ("fp32", "bf16") +_PRECISION_TO_STORAGE_DTYPE = { + "fp32": torch.float32, + "bf16": torch.bfloat16, + "bf16-mixed": torch.float32, # master shards stored fp32; MP policy casts to bf16 at compute time +} + + +def _assert_fsdp_param_dtypes(model: torch.nn.Module, precision: str) -> None: + """Verify FSDP2 produced the expected param storage dtypes.""" + expected = _PRECISION_TO_STORAGE_DTYPE[precision] + for name, param in model.named_parameters(): + if param.dtype != expected: + raise RuntimeError( + f"FSDP2 precision={precision}: expected param storage {expected}, got {param.dtype} for {name}" + ) + logger.info("FSDP2 param dtype check OK (precision=%s, storage=%s)", precision, expected) + + +def _assert_fsdp_optimizer_state_dtypes(optimizer: torch.optim.Optimizer, precision: str) -> None: + """Verify optimizer moments are in the expected dtype after the first step.""" + expected = _PRECISION_TO_STORAGE_DTYPE[precision] + for param_state in 
optimizer.state.values(): + for key in ("exp_avg", "exp_avg_sq"): + t = param_state.get(key) + if t is not None and t.dtype != expected: + raise RuntimeError( + f"FSDP2 precision={precision}: optimizer state {key} dtype={t.dtype}, expected {expected}" + ) + logger.info("FSDP2 optimizer state dtype check OK (precision=%s, dtype=%s)", precision, expected) + + @hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") def main(args: DictConfig) -> float | None: """Train CodonFM with TE layers using FSDP2. @@ -49,6 +82,11 @@ def main(args: DictConfig) -> float | None: """ logging.getLogger("httpx").setLevel(logging.WARNING) + if args.precision not in _VALID_PRECISIONS: + raise ValueError(f"precision must be one of {_VALID_PRECISIONS}, got {args.precision!r}") + if args.grad_reduce_type not in _VALID_REDUCE_TYPES: + raise ValueError(f"grad_reduce_type must be one of {_VALID_REDUCE_TYPES}, got {args.grad_reduce_type!r}") + # Initialize distributed configuration dist_config = DistributedConfig() logger.info("Initializing distributed training: %s", dist_config) @@ -115,11 +153,17 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) - # Apply FSDP2 sharding with optional mixed precision policy - if args.use_fp32_master_weights: + # Apply FSDP2 sharding with precision-specific mixed precision policy. + # fp32 - no MP overrides; sharded params stay fp32. + # bf16 - no MP overrides; params get cast to bf16 after init_empty_weights below. + # bf16-mixed - fp32 master shards, MP policy casts to bf16 at compute time. grad_reduce_type + # controls reduce-scatter dtype (default fp32 is more conservative than PTL FSDP + # bf16-mixed, which reduces in bf16). 
+ if args.precision == "bf16-mixed": + reduce_dtype = torch.float32 if args.grad_reduce_type == "fp32" else torch.bfloat16 mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, - reduce_dtype=torch.float32, + reduce_dtype=reduce_dtype, output_dtype=torch.bfloat16, cast_forward_inputs=False, ) @@ -133,6 +177,13 @@ def main(args: DictConfig) -> float | None: if args.use_meta_device: model.init_empty_weights() + # Pure bf16: cast sharded params (and downstream optimizer state) to bf16. Must happen + # after init_empty_weights so the cast catches the freshly-initialized values, not metadata. + if args.precision == "bf16": + model = model.to(dtype=torch.bfloat16) + + _assert_fsdp_param_dtypes(model, args.precision) + # Assign layer names for debug API if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) @@ -172,6 +223,7 @@ def main(args: DictConfig) -> float | None: # Training loop step = start_step micro_step = 0 # Gradient accumulation step counter + optimizer_state_asserted = False while step < args.num_train_steps: batches_in_epoch = 0 for batch in train_dataloader: @@ -199,6 +251,9 @@ def main(args: DictConfig) -> float | None: # Optimizer step optimizer.step() scheduler.step() + if not optimizer_state_asserted: + _assert_fsdp_optimizer_state_dtypes(optimizer, args.precision) + optimizer_state_asserted = True optimizer.zero_grad() perf_logger.log_step( From 8af15a1715f1c315475cb10d3c1097088b0c16c0 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Fri, 15 May 2026 20:12:41 +0000 Subject: [PATCH 04/17] Modify scripts to use GBS --- .../recipes/codonfm_native_te/run_1b.sh | 17 +++++++++++++++-- .../recipes/codonfm_native_te/slurm/1b.sh | 19 +++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh b/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh index 11172df035..a19231ced6 100755 --- a/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh +++ 
b/bionemo-recipes/recipes/codonfm_native_te/run_1b.sh @@ -4,6 +4,9 @@ set -euo pipefail export CPATH=/usr/local/cuda/include export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas +export GLOBAL_BATCH_SIZE=1536 +export MICRO_BATCH_SIZE=4 + # Run config export CONFIG_NAME=encodon_1b export NPROC_PER_NODE=8 @@ -11,7 +14,6 @@ export DIST_STRATEGY=ddp # fsdp or ddp # Training export NUM_TRAIN_STEPS=100 -export MICRO_BATCH_SIZE=31 export NUM_WORKERS=1 export USE_SEQUENCE_PACKING=True # Precision mode: one of fp32, bf16, bf16-mixed. bf16-mixed matches the reference codonfm `--bf16`. @@ -53,7 +55,17 @@ if [ "${FP8_ENABLED}" = "True" ]; then else PRECISION_TAG="${PRECISION}" fi -export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}" +# Derive grad accumulation from GBS / (MBS * GPUs). Single-node run. +NUM_NODES=1 +TOTAL_PER_STEP=$(( MICRO_BATCH_SIZE * NPROC_PER_NODE * NUM_NODES )) +if [ "${TOTAL_PER_STEP}" -eq 0 ] || [ "$(( GLOBAL_BATCH_SIZE % TOTAL_PER_STEP ))" -ne 0 ]; then + echo "ERROR: GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} must be a positive multiple of MICRO_BATCH_SIZE*NPROC_PER_NODE*NODES=${TOTAL_PER_STEP}" >&2 + exit 1 +fi +export GRAD_ACC_STEPS=$(( GLOBAL_BATCH_SIZE / TOTAL_PER_STEP )) +echo "Batch sizing: GBS=${GLOBAL_BATCH_SIZE}, MBS=${MICRO_BATCH_SIZE}, NPROC=${NPROC_PER_NODE}, NODES=${NUM_NODES}, GRAD_ACC=${GRAD_ACC_STEPS}" + +export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}" # Pick training script based on distributed strategy. 
case "${DIST_STRATEGY}" in @@ -75,6 +87,7 @@ torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \ logger.frequency=${LOGGER_FREQUENCY} \ num_train_steps=${NUM_TRAIN_STEPS} \ dataset.micro_batch_size=${MICRO_BATCH_SIZE} \ + grad_acc_steps=${GRAD_ACC_STEPS} \ dataset.num_workers=${NUM_WORKERS} \ dataset.data_path=${DATASET_DATA_PATH} \ use_sequence_packing=${USE_SEQUENCE_PACKING} \ diff --git a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh b/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh index 69189ec452..c05d346992 100755 --- a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh +++ b/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh @@ -25,6 +25,9 @@ CODE_MOUNT="/workspace/bionemo" : "${HUGGING_FACE_HUB_TOKEN:?Set HUGGING_FACE_HUB_TOKEN in ~/.bash_profile}" : "${CLUSTER_NAME:?Set CLUSTER_NAME in ~/.bash_profile}" +export GLOBAL_BATCH_SIZE=1536 +export MICRO_BATCH_SIZE=4 + # Experiment parameters export CONFIG_NAME=encodon_1b export NPROC_PER_NODE=8 @@ -32,7 +35,6 @@ export DIST_STRATEGY=ddp # fsdp or ddp # Training export NUM_TRAIN_STEPS=1000 -export MICRO_BATCH_SIZE=31 export LEARNING_RATE=7.5e-5 export NUM_WORKERS=1 export USE_SEQUENCE_PACKING=False @@ -77,7 +79,17 @@ else BATCH_TYPE_TAG="bshd" fi -export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}" +# Derive grad accumulation from GBS / (MBS * GPUs). 
+TOTAL_GPUS=$(( NPROC_PER_NODE * SLURM_JOB_NUM_NODES )) +TOTAL_PER_STEP=$(( MICRO_BATCH_SIZE * TOTAL_GPUS )) +if [ "${TOTAL_PER_STEP}" -eq 0 ] || [ "$(( GLOBAL_BATCH_SIZE % TOTAL_PER_STEP ))" -ne 0 ]; then + echo "ERROR: GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} must be a positive multiple of MICRO_BATCH_SIZE*NPROC_PER_NODE*NODES=${TOTAL_PER_STEP}" >&2 + exit 1 +fi +export GRAD_ACC_STEPS=$(( GLOBAL_BATCH_SIZE / TOTAL_PER_STEP )) +echo "Batch sizing: GBS=${GLOBAL_BATCH_SIZE}, MBS=${MICRO_BATCH_SIZE}, NPROC=${NPROC_PER_NODE}, NODES=${SLURM_JOB_NUM_NODES}, GRAD_ACC=${GRAD_ACC_STEPS}" + +export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}" # Mounts RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}" @@ -117,6 +129,7 @@ torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \ logger.frequency=${LOGGER_FREQUENCY} \ num_train_steps=${NUM_TRAIN_STEPS} \ dataset.micro_batch_size=${MICRO_BATCH_SIZE} \ + grad_acc_steps=${GRAD_ACC_STEPS} \ adamw_kwargs.lr=${LEARNING_RATE} \ dataset.num_workers=${NUM_WORKERS} \ dataset.data_path=/workspace/bionemo/data/processed_unfiltered/ \ @@ -151,7 +164,9 @@ COMMAND="export CONFIG_NAME=\"${CONFIG_NAME}\"; ${COMMAND}" COMMAND="export QUANT_STATS_ENABLED=\"${QUANT_STATS_ENABLED}\"; ${COMMAND}" COMMAND="export LOGGER_FREQUENCY=\"${LOGGER_FREQUENCY}\"; ${COMMAND}" COMMAND="export NUM_TRAIN_STEPS=\"${NUM_TRAIN_STEPS}\"; ${COMMAND}" +COMMAND="export GLOBAL_BATCH_SIZE=\"${GLOBAL_BATCH_SIZE}\"; ${COMMAND}" COMMAND="export MICRO_BATCH_SIZE=\"${MICRO_BATCH_SIZE}\"; ${COMMAND}" +COMMAND="export GRAD_ACC_STEPS=\"${GRAD_ACC_STEPS}\"; ${COMMAND}" COMMAND="export LEARNING_RATE=\"${LEARNING_RATE}\"; ${COMMAND}" COMMAND="export NUM_WORKERS=\"${NUM_WORKERS}\"; ${COMMAND}" COMMAND="export USE_SEQUENCE_PACKING=\"${USE_SEQUENCE_PACKING}\"; ${COMMAND}" From 30e06a89fa7b8ab1503401680e9615096ea23b49 Mon Sep 17 
00:00:00 2001 From: Bruno Alvisio Date: Fri, 15 May 2026 21:00:32 +0000 Subject: [PATCH 05/17] make script run in multi-node --- .../recipes/codonfm_native_te/slurm/1b.sh | 39 +++++++++++++++---- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh b/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh index c05d346992..2b0e23248d 100755 --- a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh +++ b/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh @@ -11,6 +11,14 @@ #SBATCH --exclusive set -euxo pipefail +# Establish or inherit chain ID: manual launch picks SLURM_JOB_ID; trap-resubmit inherits via --export. +if [ -z "${CHAIN_ID:-}" ]; then + export CHAIN_ID="${SLURM_JOB_ID}" + echo "Starting NEW chain: CHAIN_ID=${CHAIN_ID}" +else + echo "Continuing chain ${CHAIN_ID} (current job ${SLURM_JOB_ID})" +fi + # ============================================================================ # Codon 1B # ============================================================================ @@ -26,7 +34,7 @@ CODE_MOUNT="/workspace/bionemo" : "${CLUSTER_NAME:?Set CLUSTER_NAME in ~/.bash_profile}" export GLOBAL_BATCH_SIZE=1536 -export MICRO_BATCH_SIZE=4 +export MICRO_BATCH_SIZE=96 # Experiment parameters export CONFIG_NAME=encodon_1b @@ -34,7 +42,7 @@ export NPROC_PER_NODE=8 export DIST_STRATEGY=ddp # fsdp or ddp # Training -export NUM_TRAIN_STEPS=1000 +export NUM_TRAIN_STEPS=100 export LEARNING_RATE=7.5e-5 export NUM_WORKERS=1 export USE_SEQUENCE_PACKING=False @@ -89,7 +97,7 @@ fi export GRAD_ACC_STEPS=$(( GLOBAL_BATCH_SIZE / TOTAL_PER_STEP )) echo "Batch sizing: GBS=${GLOBAL_BATCH_SIZE}, MBS=${MICRO_BATCH_SIZE}, NPROC=${NPROC_PER_NODE}, NODES=${SLURM_JOB_NUM_NODES}, GRAD_ACC=${GRAD_ACC_STEPS}" -export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}" +export 
WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}_chain_${CHAIN_ID}" # Mounts RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}" @@ -99,6 +107,10 @@ mkdir -p "${RESULTS_DIR}" "${CKPT_DIR}" MOUNTS="${DATA_DIR}:${CODE_MOUNT}/data,${RESULTS_DIR}:${CODE_MOUNT}/results,${CKPT_DIR}:${CODE_MOUNT}/checkpoints" +# Resolve head node on the host (scontrol is not available inside the container). +MASTER_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) +MASTER_PORT=29500 + read -r -d '' COMMAND <<'OUTER_EOF' || true set -euxo pipefail @@ -123,7 +135,14 @@ case "${DIST_STRATEGY}" in ;; esac -torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \ +torchrun \ + --nproc_per_node=${NPROC_PER_NODE} \ + --rdzv_id=${SLURM_JOB_ID} \ + --rdzv_backend=c10d \ + --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \ + --nnodes=${SLURM_JOB_NUM_NODES} \ + --node-rank=${SLURM_NODEID} \ + ${TRAIN_SCRIPT} \ --config-name ${CONFIG_NAME} \ quant_stats_config.enabled=${QUANT_STATS_ENABLED} \ logger.frequency=${LOGGER_FREQUENCY} \ @@ -182,6 +201,8 @@ COMMAND="export HYDRA_RUN_DIR=\"${HYDRA_RUN_DIR}\"; ${COMMAND}" COMMAND="export FP8_ENABLED=\"${FP8_ENABLED}\"; ${COMMAND}" COMMAND="export FP8_RECIPE=\"${FP8_RECIPE}\"; ${COMMAND}" COMMAND="export FP8_FORMAT=\"${FP8_FORMAT}\"; ${COMMAND}" +COMMAND="export MASTER_ADDR=\"${MASTER_ADDR}\"; ${COMMAND}" +COMMAND="export MASTER_PORT=\"${MASTER_PORT}\"; ${COMMAND}" COMMAND="export WANDB_API_KEY=\"${WANDB_API_KEY}\"; ${COMMAND}" COMMAND="export HUGGING_FACE_HUB_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}" @@ -193,12 +214,14 @@ echo "Launching: ${WANDB_RUN_NAME}" trap ' rc=$? if [ "$rc" -eq 143 ] || [ "$rc" -eq 137 ]; then - echo "Killed by signal (rc=$rc) — assuming SLURM timeout, resubmitting..." 
- sbatch --dependency=singleton "${BASH_SOURCE[0]}" + echo "Timed out (rc=$rc) — resubmitting chain ${CHAIN_ID}." + sbatch --dependency=singleton \ + --export=ALL,CHAIN_ID="${CHAIN_ID}" \ + "${BASH_SOURCE[0]}" elif [ "$rc" -eq 0 ]; then - echo "Clean exit — training finished, NOT resubmitting." + echo "Training finished cleanly — chain ${CHAIN_ID} ends." else - echo "Error exit (rc=$rc) — NOT resubmitting; investigate ${RESULTS_DIR}" + echo "Real error (rc=$rc) — chain ${CHAIN_ID} ends so you can investigate." fi ' EXIT From adfc4d75bc28dc087cb20be62363dbdda35bb3e9 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Sun, 17 May 2026 20:11:00 +0000 Subject: [PATCH 06/17] Revert back base image --- bionemo-recipes/recipes/codonfm_native_te/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/Dockerfile b/bionemo-recipes/recipes/codonfm_native_te/Dockerfile index e59e7fe2fd..bfd688de36 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/Dockerfile +++ b/bionemo-recipes/recipes/codonfm_native_te/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1.4 -FROM nvcr.io/nvidia/pytorch:26.02-py3 +FROM nvcr.io/nvidia/pytorch:26.04-py3 RUN apt-get update && apt-get install -y tmux npm From f3be34a46f07e70e02d9ae92cc91bfee7f3283d4 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Sun, 17 May 2026 21:15:12 +0000 Subject: [PATCH 07/17] Disable Fused attn when thd is true --- bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh b/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh index 2b0e23248d..c9aa6eb82f 100755 --- a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh +++ b/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh @@ -121,6 +121,11 @@ echo "Job ID: ${SLURM_JOB_ID}" echo "Nodes: ${SLURM_JOB_NUM_NODES}" echo "=========================================" +# cuDNN fused-attn sub-backend 1 
OOMs on Blackwell (sm_103) with THD+padding (TE 2.12 / cuDNN 9.19); force flash-attn varlen. +if [ "${USE_SEQUENCE_PACKING}" = "True" ]; then + export NVTE_FUSED_ATTN=0 +fi + # Pick training script based on distributed strategy. case "${DIST_STRATEGY}" in fsdp) From 52a66382cc046792dd467e4f892da0397a965b5e Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Sun, 17 May 2026 23:34:44 +0000 Subject: [PATCH 08/17] Add encodon 10b configuration --- .../models/codonfm/modeling_codonfm_te.py | 6 ++++ .../hydra_config/encodon_10b.yaml | 33 +++++++++++++++++++ .../codonfm_native_te/modeling_codonfm_te.py | 6 ++++ 3 files changed, 45 insertions(+) create mode 100644 bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_10b.yaml diff --git a/bionemo-recipes/models/codonfm/modeling_codonfm_te.py b/bionemo-recipes/models/codonfm/modeling_codonfm_te.py index 64618dcd62..baa13219ae 100644 --- a/bionemo-recipes/models/codonfm/modeling_codonfm_te.py +++ b/bionemo-recipes/models/codonfm/modeling_codonfm_te.py @@ -156,6 +156,12 @@ def __init__( "num_attention_heads": 32, "num_hidden_layers": 24, }, + "encodon_10b": { + "hidden_size": 5120, + "intermediate_size": 20480, + "num_attention_heads": 40, + "num_hidden_layers": 34, + }, } diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_10b.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_10b.yaml new file mode 100644 index 0000000000..d444aee252 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_10b.yaml @@ -0,0 +1,33 @@ +defaults: + - defaults + - _self_ + +# Training config +model_preset: encodon_10b +num_train_steps: 500_000 + +precision: bf16-mixed + +use_sequence_packing: true +dataset: + data_path: ??? 
+ micro_batch_size: 4 + num_workers: 1 + max_seq_length: 2048 + +# WandB config +wandb_init_args: + name: "codonfm_native_te_10b" + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 2_000 + num_training_steps: 500_000 + +checkpoint: + ckpt_dir: ??? + resume_from_checkpoint: true + save_every_n_steps: 1_000 + +logger: + frequency: 100 diff --git a/bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py b/bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py index 5b29b1254b..c2b80dab8f 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py +++ b/bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py @@ -162,6 +162,12 @@ def __init__( "num_attention_heads": 32, "num_hidden_layers": 24, }, + "encodon_10b": { + "hidden_size": 5120, + "intermediate_size": 20480, + "num_attention_heads": 40, + "num_hidden_layers": 34, + }, } From 6d810c4da5a24dbb8d672ccae3cfc3925cdb6017 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Mon, 18 May 2026 18:50:06 +0000 Subject: [PATCH 09/17] Add support for saving intermediate models --- .../recipes/codonfm_native_te/hydra_config/defaults.yaml | 1 + bionemo-recipes/recipes/codonfm_native_te/train_ddp.py | 7 +++++++ bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py | 7 +++++++ 3 files changed, 15 insertions(+) diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml index 16333758a2..d7686d768a 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml @@ -60,6 +60,7 @@ lr_scheduler_kwargs: checkpoint: ckpt_dir: ??? 
save_final_model: true + save_final_model_with_checkpoint: false resume_from_checkpoint: true save_every_n_steps: 1_000 max_checkpoints: 5 diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py b/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py index 0f970984f2..3ad59cbd2d 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_ddp.py @@ -241,6 +241,13 @@ def main(args: DictConfig) -> float | None: dist_config=dist_config, max_checkpoints=args.checkpoint.max_checkpoints, ) + if args.checkpoint.save_final_model_with_checkpoint: + save_final_model_ddp( + model=model, + config=config, + save_directory=ckpt_path / f"step_{step}" / "final_model", + dist_config=dist_config, + ) if val_dataloader is not None and step > 0 and step % args.validation.eval_interval == 0: model.eval() diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py index 340b453a51..6a99a6bb8e 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py @@ -273,6 +273,13 @@ def main(args: DictConfig) -> float | None: dist_config=dist_config, max_checkpoints=args.checkpoint.max_checkpoints, ) + if args.checkpoint.save_final_model_with_checkpoint: + save_final_model_fsdp2( + model=model, + config=config, + save_directory=ckpt_path / f"step_{step}" / "final_model", + dist_config=dist_config, + ) if val_dataloader is not None and step > 0 and step % args.validation.eval_interval == 0: model.eval() From 51a0125a5f70cb0e273ce6005d2717f825bb3f20 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Mon, 18 May 2026 20:24:17 +0000 Subject: [PATCH 10/17] Add support for extract embeddings --- .../codonfm_native_te/extract_embeddings.py | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 bionemo-recipes/recipes/codonfm_native_te/extract_embeddings.py diff --git 
a/bionemo-recipes/recipes/codonfm_native_te/extract_embeddings.py b/bionemo-recipes/recipes/codonfm_native_te/extract_embeddings.py new file mode 100644 index 0000000000..d5f283e49d --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/extract_embeddings.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extract CLS-token embeddings from a pretrained CodonFM model. 
+ +Usage: + python extract_embeddings.py \ + --model-name-or-path nvidia/NV-CodonFM-Encodon-TE-Cdwt-1B-v1 \ + --input seqs.fasta \ + --output embeddings.npz +""" + +import argparse +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import torch +from modeling_codonfm_te import CodonFMForMaskedLM +from tokenizer import CodonTokenizer + + +@dataclass +class EmbeddingOutput: + """Container for extracted embeddings and the corresponding record ids.""" + + embeddings: np.ndarray + ids: np.ndarray | None = None + + +def read_fasta(path: Path) -> list[tuple[str, str]]: + """Parse a FASTA file into a list of (id, sequence) tuples.""" + records: list[tuple[str, str]] = [] + seq_id: str | None = None + seq_parts: list[str] = [] + with open(path) as f: + for raw_line in f: + line = raw_line.strip() + if not line: + continue + if line.startswith(">"): + if seq_id is not None: + records.append((seq_id, "".join(seq_parts))) + seq_id = line[1:].split()[0] if len(line) > 1 else "" + seq_parts = [] + else: + seq_parts.append(line) + if seq_id is not None: + records.append((seq_id, "".join(seq_parts))) + return records + + +def _tokenize_batch( + tokenizer: CodonTokenizer, + sequences: list[str], + max_seq_length: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """Encode and right-pad a batch of DNA sequences to the longest seq in the batch.""" + encoded: list[list[int]] = [] + for s in sequences: + ids = tokenizer.encode(s, add_special_tokens=True) + if len(ids) > max_seq_length: + ids = [*ids[: max_seq_length - 1], tokenizer.sep_token_id] + encoded.append(ids) + + pad_to = max(len(ids) for ids in encoded) + input_ids = np.full((len(encoded), pad_to), tokenizer.pad_token_id, dtype=np.int64) + attention_mask = np.zeros((len(encoded), pad_to), dtype=np.int64) + for i, ids in enumerate(encoded): + input_ids[i, : len(ids)] = ids + attention_mask[i, : len(ids)] = 1 + + return ( + torch.from_numpy(input_ids).to(device), + 
torch.from_numpy(attention_mask).to(device), + ) + + +def extract_embeddings( + model: CodonFMForMaskedLM, + tokenizer: CodonTokenizer, + records: list[tuple[str, str]], + batch_size: int, + max_seq_length: int, + device: torch.device | str = "cuda", +) -> EmbeddingOutput: + """Return CLS-token embeddings from the final hidden layer for each record.""" + device = torch.device(device) + ids = [r[0] for r in records] + seqs = [r[1] for r in records] + + all_embeds: list[np.ndarray] = [] + for i in range(0, len(seqs), batch_size): + batch_seqs = seqs[i : i + batch_size] + input_ids, attention_mask = _tokenize_batch(tokenizer, batch_seqs, max_seq_length, device) + + with torch.no_grad(): + output = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + cls = output.hidden_states[-1][:, 0, :] + if cls.dtype != torch.float32: + cls = cls.float() + all_embeds.append(cls.cpu().numpy()) + + embeddings = np.concatenate(all_embeds, axis=0) if all_embeds else np.zeros((0, 0), dtype=np.float32) + return EmbeddingOutput(embeddings=embeddings, ids=np.array(ids)) + + +def main() -> None: + """CLI entrypoint.""" + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--model-name-or-path", + required=True, + help="Hugging Face Hub tag or local directory with a CodonFM checkpoint.", + ) + parser.add_argument("--input", required=True, type=Path, help="Path to a FASTA file with DNA sequences.") + parser.add_argument("--output", type=Path, default=None, help="Optional .npz path to save embeddings and ids.") + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--max-seq-length", type=int, default=2048) + parser.add_argument("--device", default="cuda") + args = parser.parse_args() + + model = CodonFMForMaskedLM.from_pretrained(args.model_name_or_path).to(args.device).eval() + tokenizer = CodonTokenizer() + + records = read_fasta(args.input) + if not records: + raise 
ValueError(f"No FASTA records found in {args.input}") + + out = extract_embeddings( + model, + tokenizer, + records, + batch_size=args.batch_size, + max_seq_length=args.max_seq_length, + device=args.device, + ) + + if args.output is not None: + np.savez(args.output, embeddings=out.embeddings, ids=out.ids) + print(f"Saved {out.embeddings.shape[0]} embeddings of dim {out.embeddings.shape[1]} to {args.output}") + else: + print(f"Extracted {out.embeddings.shape[0]} embeddings of dim {out.embeddings.shape[1]}") + + +if __name__ == "__main__": + main() From 8ada8680b1fd2ea4ebd8a6b8664111e694f661bd Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Tue, 19 May 2026 16:25:05 +0000 Subject: [PATCH 11/17] Added evaluation data generation ribonn --- .../evaluation/ribonn/preprocess.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/preprocess.py diff --git a/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/preprocess.py b/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/preprocess.py new file mode 100644 index 0000000000..a0fbc89c16 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/preprocess.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Download and preprocess the RiboNN translation efficiency dataset. + +Extracted verbatim from notebooks/4-EnCodon-Downstream-Task-riboNN.ipynb (section 3). +""" + +import os +import urllib.request +from pathlib import Path + +import polars as pl + + +# Configurable dataset path +data_path = "/data/validation/processed/data_with_human_TE_cellline_all_NA_plain.csv" + +# Source URL for the TE dataset +te_dataset_url = "https://raw.githubusercontent.com/CenikLab/TE_classic_ML/refs/heads/main/data/data_with_human_TE_cellline_all_NA_plain.csv" + +# Ensure parent directory exists +Path(os.path.dirname(data_path)).mkdir(parents=True, exist_ok=True) + +# Download if missing +if not os.path.exists(data_path): + print(f"Downloading TE dataset to {data_path} ...") + urllib.request.urlretrieve(te_dataset_url, data_path) + print("Download complete.") +else: + print(f"Found existing dataset at {data_path}.") + + +# Slice the transcript sequence into CDS / 5'UTR / 3'UTR using utr5_size and cds_size, +# and add a row index column 'id'. +data = pl.read_csv(data_path, separator="\t") +data = data.with_columns( + [ + pl.struct(["utr5_size", "cds_size", "tx_sequence"]) + .map_elements( + lambda row: row["tx_sequence"][row["utr5_size"] : row["utr5_size"] + row["cds_size"]], return_dtype=pl.Utf8 + ) + .alias("cds_sequence"), + pl.struct(["utr5_size", "tx_sequence"]) + .map_elements(lambda row: row["tx_sequence"][: row["utr5_size"]], return_dtype=pl.Utf8) + .alias("utr5_sequence"), + pl.struct(["utr5_size", "cds_size", "tx_sequence"]) + .map_elements(lambda row: row["tx_sequence"][row["utr5_size"] + row["cds_size"] :], return_dtype=pl.Utf8) + .alias("utr3_sequence"), + ] +).with_row_index("id") +output_path = data_path[:-4] + ".processed.csv" +data.write_csv(output_path) + + +# Load processed RiboNN dataset and report basic statistics on the mean_te target. 
+data_loaded = False +if os.path.exists(output_path): + try: + data = pl.read_csv(output_path) + print(f"✅ Loaded {len(data)} sequences from: {output_path}") + print(f"Shape: {data.shape}") + print(f"Key columns: {[col for col in ['id', 'cds_sequence', 'mean_te', 'fold'] if col in data.columns]}") + + data_loaded = True + except Exception as e: + print(f"Failed to load {output_path}: {e}") + + # Show basic statistics + te_stats = data.select( + [ + pl.col("mean_te").mean().alias("mean"), + pl.col("mean_te").std().alias("std"), + pl.col("mean_te").min().alias("min"), + pl.col("mean_te").max().alias("max"), + ] + ) + print("\nTranslation Efficiency stats:") + print(f" Mean: {te_stats['mean'][0]:.4f}") + print(f" Range: [{te_stats['min'][0]:.4f}, {te_stats['max'][0]:.4f}]") + data_loaded = True From e30de367d2dc7c1007e478305eb93d27de02e538 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Tue, 19 May 2026 16:30:32 +0000 Subject: [PATCH 12/17] add fixes to preprocess.py --- .../evaluation/ribonn/README.md | 71 +++++ .../evaluation/ribonn/evaluate_rf.py | 258 ++++++++++++++++++ .../evaluation/ribonn/preprocess.py | 101 ++++--- 3 files changed, 378 insertions(+), 52 deletions(-) create mode 100644 bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/README.md create mode 100644 bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/evaluate_rf.py diff --git a/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/README.md b/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/README.md new file mode 100644 index 0000000000..ef65e3f54d --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/README.md @@ -0,0 +1,71 @@ +# RiboNN Translation Efficiency Evaluation + +This directory contains scripts that reproduce the RiboNN translation-efficiency (TE) +downstream evaluation from the EnCodon paper, using a Hugging Face CodonFM checkpoint +from this recipe. 
+ +The task: given a coding sequence (CDS) from a human mRNA, predict its translation +efficiency. We do this by extracting CLS-token embeddings from a frozen pretrained +CodonFM model and training a `RandomForestRegressor` on top, evaluated with +leave-one-fold-out cross-validation against the public RiboNN labels. + +Data source: [`CenikLab/TE_classic_ML`](https://github.com/CenikLab/TE_classic_ML) +(`data_with_human_TE_cellline_all_NA_plain.csv`), ~11k transcripts with `mean_te` +labels and a precomputed 10-fold split. + +Reference: + +> Zheng, Dinghai, et al. *Predicting the translation efficiency of messenger RNA in +> mammalian cells.* Nature Biotechnology (2025): 1-14. + +## Scripts + +- `preprocess.py` — Downloads the RiboNN TSV, slices the transcript into + CDS / 5'UTR / 3'UTR using `utr5_size` and `cds_size`, adds a row-index `id`, and + writes `ribonn_cds.parquet` containing only the columns the downstream evaluation + needs (`id`, `cds_sequence`, `mean_te`, `fold`). +- `evaluate_rf.py` — Loads `ribonn_cds.parquet`, extracts CLS embeddings from a HF + CodonFM checkpoint (reusing `extract_embeddings.py` from the recipe root), runs + leave-one-fold-out CV with a `RandomForestRegressor`, and writes per-fold metrics + to `metrics.csv` and aggregate stats to `metrics_summary.csv`. Embeddings are + cached to a `.npz` and validated on load (ids, targets, folds, sequence hash, and + `max_seq_length` must all match), so re-running only re-trains the RF unless an + input actually changed. + +## Usage + +Run from inside this directory. + +### 1. Preprocess the dataset + +```bash +python preprocess.py +``` + +Downloads `data_with_human_TE_cellline_all_NA_plain.csv` next to the script if not +already present and writes `ribonn_cds.parquet`. + +### 2. 
Run the evaluation + +```bash +python evaluate_rf.py --model-name-or-path nvidia/NV-CodonFM-Encodon-TE-Cdwt-1B-v1 +``` + +Useful flags: + +- `--demo-size N` — stratified-sample `N` rows by `fold` for a quick smoke run + (the original notebook uses `--demo-size 500`). Omit for the full ~11k dataset. +- `--batch-size`, `--device` — passed through to the embedding extractor + (defaults: `16`, `cuda`). +- `--output-dir` — where to write the embeddings cache and metrics CSVs + (default: this directory). +- `--force-extract` — re-extract embeddings even if a valid cache exists. +- `--seed` — RNG seed for subsampling and the Random Forest (default: `42`). + +### Outputs + +- `embeddings_<model>_n<N>.npz` — cached CLS embeddings + the inputs they + were extracted from (used to validate the cache on subsequent runs). +- `metrics.csv` — one row per fold with `fold, r2, pearson_r, mse, rmse`. +- `metrics_summary.csv` — single-row aggregate: `mean_r2, std_r2, mean_pearson_r, std_pearson_r, mean_rmse` (mean RMSE follows the notebook convention, + `sqrt(mean(MSE))`, not `mean(sqrt(MSE))`). diff --git a/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/evaluate_rf.py b/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/evaluate_rf.py new file mode 100644 index 0000000000..abcd2f3bb7 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/evaluate_rf.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reproduce the RiboNN Random Forest TE-prediction eval against a HF CodonFM model. + +Loads CDS sequences + labels from `ribonn_cds.parquet` (produced by `preprocess.py`), +extracts CLS-token embeddings from a Hugging Face CodonFM checkpoint, runs leave-one-fold-out +cross-validation with a RandomForestRegressor, and writes per-fold metrics to a CSV. + +Embeddings are cached to disk and validated on load against the current ids, targets, folds, +sequences, and max_seq_length, so re-running the script only re-tunes the RF without +re-extracting embeddings — unless any of those inputs change, in which case the cache is +treated as stale and re-extracted. + +Usage: + python evaluate_rf.py \ + --model-name-or-path nvidia/NV-CodonFM-Encodon-TE-Cdwt-1B-v1 \ + --demo-size 500 +""" + +import argparse +import hashlib +import sys +from pathlib import Path + +import numpy as np +import polars as pl +from scipy.stats import pearsonr +from sklearn.ensemble import RandomForestRegressor +from sklearn.metrics import mean_squared_error, r2_score +from sklearn.model_selection import train_test_split + + +# `extract_embeddings`, `CodonFMForMaskedLM`, and `CodonTokenizer` live at the recipe root. 
+RECIPE_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(RECIPE_ROOT)) + +from extract_embeddings import extract_embeddings # noqa: E402 +from modeling_codonfm_te import CodonFMConfig, CodonFMForMaskedLM # noqa: E402 +from tokenizer import CodonTokenizer # noqa: E402 + + +SCRIPT_DIR = Path(__file__).parent + + +def _slugify(s: str) -> str: + """Make a model name/path safe to embed in a filename.""" + return s.replace("/", "__").replace(":", "_") + + +def load_or_extract_embeddings( + df: pl.DataFrame, + model_name_or_path: str, + cache_path: Path, + batch_size: int, + device: str, + force_extract: bool, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Return (embeddings, ids, targets, folds), loading from cache when it is still valid.""" + df_pd = df.to_pandas() + ids = df_pd["id"].to_numpy() + targets = df_pd["mean_te"].to_numpy() + folds = df_pd["fold"].to_numpy() + raw_sequences = df_pd["cds_sequence"].tolist() + seqs_hash = hashlib.sha256("\n".join(raw_sequences).encode()).hexdigest() + + max_seq_length = CodonFMConfig.from_pretrained(model_name_or_path).max_position_embeddings + + if cache_path.exists() and not force_extract: + z = np.load(cache_path, allow_pickle=False) + cache_valid = ( + "max_seq_length" in z.files + and int(z["max_seq_length"]) == max_seq_length + and "seqs_hash" in z.files + and str(z["seqs_hash"].item()) == seqs_hash + and np.array_equal(z["ids"], ids) + and np.array_equal(z["targets"], targets) + and np.array_equal(z["folds"], folds) + ) + if cache_valid: + print(f"Loading cached embeddings from {cache_path}") + return z["embeddings"], z["ids"], z["targets"], z["folds"] + print(f"⚠️ Cache at {cache_path} is stale; re-extracting.") + + # CodonTokenizer (DNA mode) does not normalise 'U' — unhandled 'U' codons would tokenize + # to the unknown token. Match the notebook by uppercasing and replacing U->T before encoding.
+ sequences = [s.upper().replace("U", "T") for s in raw_sequences] + records = list(zip([str(i) for i in ids], sequences)) + + print(f"Loading model from {model_name_or_path}") + model = CodonFMForMaskedLM.from_pretrained(model_name_or_path).to(device).eval() + tokenizer = CodonTokenizer() + + print(f"Extracting embeddings for {len(records)} sequences...") + out = extract_embeddings( + model, + tokenizer, + records, + batch_size=batch_size, + max_seq_length=max_seq_length, + device=device, + ) + + cache_path.parent.mkdir(parents=True, exist_ok=True) + np.savez( + cache_path, + embeddings=out.embeddings, + ids=ids, + targets=targets, + folds=folds, + max_seq_length=np.array(max_seq_length), + seqs_hash=np.array(seqs_hash), + ) + print(f"✅ Cached embeddings to {cache_path}") + return out.embeddings, ids, targets, folds + + +def cross_validate( + embeddings: np.ndarray, + targets: np.ndarray, + folds: np.ndarray, + seed: int, +) -> list[dict]: + """Run leave-one-fold-out CV with RandomForestRegressor; return per-fold metrics.""" + rows: list[dict] = [] + for fold in np.unique(folds): + train_mask = folds != fold + test_mask = ~train_mask + x_train, x_test = embeddings[train_mask], embeddings[test_mask] + y_train, y_test = targets[train_mask], targets[test_mask] + + rf = RandomForestRegressor( + n_estimators=500, + max_depth=15, + min_samples_split=2, + random_state=seed, + n_jobs=-1, + ) + rf.fit(x_train, y_train) + y_pred = rf.predict(x_test) + + r2 = r2_score(y_test, y_pred) + pearson_r, _ = pearsonr(y_test, y_pred) + mse = mean_squared_error(y_test, y_pred) + rmse = float(np.sqrt(mse)) + + print(f"Fold {fold}: R² = {r2:.4f}, r = {pearson_r:.4f}, RMSE = {rmse:.4f}") + rows.append( + {"fold": int(fold), "r2": float(r2), "pearson_r": float(pearson_r), "mse": float(mse), "rmse": rmse} + ) + + return rows + + +def main() -> None: + """CLI entrypoint for the RiboNN RF evaluation.""" + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + 
parser.add_argument( + "--model-name-or-path", + required=True, + help="Hugging Face Hub tag or local directory with a CodonFM checkpoint.", + ) + parser.add_argument( + "--data-path", + type=Path, + default=SCRIPT_DIR / "ribonn_cds.parquet", + help="Parquet file produced by preprocess.py (default: ribonn_cds.parquet next to this script).", + ) + parser.add_argument( + "--demo-size", + type=int, + default=None, + help="If set, stratified-sample this many rows by 'fold'. Notebook uses 500.", + ) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--device", default="cuda") + parser.add_argument("--output-dir", type=Path, default=SCRIPT_DIR) + parser.add_argument( + "--force-extract", + action="store_true", + help="Re-extract embeddings even if a cached file exists.", + ) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + + df = pl.read_parquet(args.data_path) + print(f"Loaded {len(df)} rows from {args.data_path}") + + # Subsample stratified by fold (mirrors notebook section 4). 
+ if args.demo_size is not None and args.demo_size < len(df): + print(f"=== SUBSAMPLING DATA to {args.demo_size} rows ===") + sample_fraction = args.demo_size / len(df) + _, sampled_pd = train_test_split( + df.to_pandas(), + test_size=sample_fraction, + stratify=df["fold"].to_numpy(), + random_state=args.seed, + ) + df = pl.from_pandas(sampled_pd) + + n_tag = f"n{len(df)}" + cache_path = args.output_dir / f"embeddings_{_slugify(args.model_name_or_path)}_{n_tag}.npz" + + embeddings, _ids, targets, folds = load_or_extract_embeddings( + df=df, + model_name_or_path=args.model_name_or_path, + cache_path=cache_path, + batch_size=args.batch_size, + device=args.device, + force_extract=args.force_extract, + ) + print(f"Embeddings shape: {embeddings.shape}") + + print("\n=== TRAINING RANDOM FOREST ===") + rows = cross_validate(embeddings, targets, folds, seed=args.seed) + + metrics_path = args.output_dir / "metrics.csv" + pl.DataFrame(rows).write_csv(metrics_path) + print(f"\n✅ Wrote per-fold metrics to {metrics_path}") + + # Summary stats — mirrors notebook section 5: Mean RMSE uses sqrt(mean(MSE)), + # not mean(sqrt(MSE)). 
+ r2 = np.array([r["r2"] for r in rows]) + pr = np.array([r["pearson_r"] for r in rows]) + mse = np.array([r["mse"] for r in rows]) + summary = { + "mean_r2": float(r2.mean()), + "std_r2": float(r2.std()), + "mean_pearson_r": float(pr.mean()), + "std_pearson_r": float(pr.std()), + "mean_rmse": float(np.sqrt(mse.mean())), + } + summary_path = args.output_dir / "metrics_summary.csv" + pl.DataFrame([summary]).write_csv(summary_path) + + print("\n=== CROSS-VALIDATION RESULTS ===") + print(f"Mean R²: {summary['mean_r2']:.4f} ± {summary['std_r2']:.4f}") + print(f"Mean Pearson r: {summary['mean_pearson_r']:.4f} ± {summary['std_pearson_r']:.4f}") + print(f"Mean RMSE: {summary['mean_rmse']:.4f}") + print(f"✅ Wrote summary metrics to {summary_path}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/preprocess.py b/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/preprocess.py index a0fbc89c16..f8db7e8de9 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/preprocess.py +++ b/bionemo-recipes/recipes/codonfm_native_te/evaluation/ribonn/preprocess.py @@ -18,65 +18,54 @@ Extracted verbatim from notebooks/4-EnCodon-Downstream-Task-riboNN.ipynb (section 3). 
""" -import os import urllib.request from pathlib import Path import polars as pl -# Configurable dataset path -data_path = "/data/validation/processed/data_with_human_TE_cellline_all_NA_plain.csv" +SCRIPT_DIR = Path(__file__).parent # Source URL for the TE dataset -te_dataset_url = "https://raw.githubusercontent.com/CenikLab/TE_classic_ML/refs/heads/main/data/data_with_human_TE_cellline_all_NA_plain.csv" - -# Ensure parent directory exists -Path(os.path.dirname(data_path)).mkdir(parents=True, exist_ok=True) - -# Download if missing -if not os.path.exists(data_path): - print(f"Downloading TE dataset to {data_path} ...") - urllib.request.urlretrieve(te_dataset_url, data_path) - print("Download complete.") -else: - print(f"Found existing dataset at {data_path}.") - - -# Slice the transcript sequence into CDS / 5'UTR / 3'UTR using utr5_size and cds_size, -# and add a row index column 'id'. -data = pl.read_csv(data_path, separator="\t") -data = data.with_columns( - [ - pl.struct(["utr5_size", "cds_size", "tx_sequence"]) - .map_elements( - lambda row: row["tx_sequence"][row["utr5_size"] : row["utr5_size"] + row["cds_size"]], return_dtype=pl.Utf8 - ) - .alias("cds_sequence"), - pl.struct(["utr5_size", "tx_sequence"]) - .map_elements(lambda row: row["tx_sequence"][: row["utr5_size"]], return_dtype=pl.Utf8) - .alias("utr5_sequence"), - pl.struct(["utr5_size", "cds_size", "tx_sequence"]) - .map_elements(lambda row: row["tx_sequence"][row["utr5_size"] + row["cds_size"] :], return_dtype=pl.Utf8) - .alias("utr3_sequence"), - ] -).with_row_index("id") -output_path = data_path[:-4] + ".processed.csv" -data.write_csv(output_path) - - -# Load processed RiboNN dataset and report basic statistics on the mean_te target. 
-data_loaded = False -if os.path.exists(output_path): - try: - data = pl.read_csv(output_path) - print(f"✅ Loaded {len(data)} sequences from: {output_path}") - print(f"Shape: {data.shape}") - print(f"Key columns: {[col for col in ['id', 'cds_sequence', 'mean_te', 'fold'] if col in data.columns]}") - - data_loaded = True - except Exception as e: - print(f"Failed to load {output_path}: {e}") +TE_DATASET_URL = "https://raw.githubusercontent.com/CenikLab/TE_classic_ML/refs/heads/main/data/data_with_human_TE_cellline_all_NA_plain.csv" + + +def main() -> None: + """Download the RiboNN TE dataset, slice CDS/UTR regions, and write a parquet handoff.""" + # Configurable dataset path + data_path = SCRIPT_DIR / "data_with_human_TE_cellline_all_NA_plain.csv" + + # Download if missing + if not data_path.exists(): + print(f"Downloading TE dataset to {data_path} ...") + urllib.request.urlretrieve(TE_DATASET_URL, data_path) + print("Download complete.") + else: + print(f"Found existing dataset at {data_path}.") + + # Slice the transcript sequence into CDS / 5'UTR / 3'UTR using utr5_size and cds_size, + # and add a row index column 'id'. 
+ data = pl.read_csv(data_path, separator="\t") + data = data.with_columns( + [ + pl.struct(["utr5_size", "cds_size", "tx_sequence"]) + .map_elements( + lambda row: row["tx_sequence"][row["utr5_size"] : row["utr5_size"] + row["cds_size"]], + return_dtype=pl.Utf8, + ) + .alias("cds_sequence"), + pl.struct(["utr5_size", "tx_sequence"]) + .map_elements(lambda row: row["tx_sequence"][: row["utr5_size"]], return_dtype=pl.Utf8) + .alias("utr5_sequence"), + pl.struct(["utr5_size", "cds_size", "tx_sequence"]) + .map_elements(lambda row: row["tx_sequence"][row["utr5_size"] + row["cds_size"] :], return_dtype=pl.Utf8) + .alias("utr3_sequence"), + ] + ).with_row_index("id") + + print(f"✅ Loaded {len(data)} sequences") + print(f"Shape: {data.shape}") + print(f"Key columns: {[col for col in ['id', 'cds_sequence', 'mean_te', 'fold'] if col in data.columns]}") # Show basic statistics te_stats = data.select( @@ -90,4 +79,12 @@ print("\nTranslation Efficiency stats:") print(f" Mean: {te_stats['mean'][0]:.4f}") print(f" Range: [{te_stats['min'][0]:.4f}, {te_stats['max'][0]:.4f}]") - data_loaded = True + + # Write only the columns needed by the downstream embedding-extraction script. 
+ output_path = SCRIPT_DIR / "ribonn_cds.parquet" + data.select(["id", "cds_sequence", "mean_te", "fold"]).write_parquet(output_path) + print(f"\n✅ Wrote {output_path}") + + +if __name__ == "__main__": + main() From 6a9c7eae87bff06d4a98230ed5c83b1d050f9c53 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Tue, 19 May 2026 19:57:43 +0000 Subject: [PATCH 13/17] Add mrfp script --- .../evaluation/mrfp/preprocess.py | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 bionemo-recipes/recipes/codonfm_native_te/evaluation/mrfp/preprocess.py diff --git a/bionemo-recipes/recipes/codonfm_native_te/evaluation/mrfp/preprocess.py b/bionemo-recipes/recipes/codonfm_native_te/evaluation/mrfp/preprocess.py new file mode 100644 index 0000000000..c3baf95f25 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/evaluation/mrfp/preprocess.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download and preprocess the mRFP Expression dataset. + +Extracted verbatim from notebooks/5-EnCodon-Downstream-Task-mRFP-expression.ipynb (section 3), +with the column-normalization step (normally done by data_scripts/preprocess_validation.py) +inlined so this script is self-contained. 
+""" + +import re +import urllib.request +from pathlib import Path + +import polars as pl + + +SCRIPT_DIR = Path(__file__).parent + +MRFP_DATASET_URL = ( + "https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/master/" + "benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv" +) + +RAW_REF_SEQ_COL = "Sequence" + + +def _camel_to_snake(name: str) -> str: + """Convert a column name to snake_case, matching preprocess_validation.py.""" + name = name.replace(" ", "_") + return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() + + +def main() -> None: + """Download the mRFP Expression CSV, normalize columns, and write a parquet handoff.""" + raw_path = SCRIPT_DIR / "mRFP_Expression.csv" + + if not raw_path.exists(): + print(f"Downloading mRFP Expression dataset to {raw_path} ...") + urllib.request.urlretrieve(MRFP_DATASET_URL, raw_path) + print("Download complete.") + else: + print(f"Found existing dataset at {raw_path}.") + + data = pl.read_csv(raw_path) + + rename_map = {col: _camel_to_snake(col) for col in data.columns} + rename_map[RAW_REF_SEQ_COL] = "ref_seq" + data = data.rename(rename_map) + + data = data.with_row_index("id") + data = data.with_columns(pl.col("id").cast(pl.Utf8)) + + before = len(data) + data = data.filter(pl.col("ref_seq").str.len_chars() % 3 == 0) + dropped = before - len(data) + if dropped: + print(f"Dropped {dropped} rows whose ref_seq length is not divisible by 3.") + + if data.is_empty(): + raise ValueError("Output dataframe is empty after filtering.") + if data["ref_seq"].null_count() > 0: + raise ValueError("ref_seq column contains nulls.") + + print(f"Loaded {len(data)} sequences") + print(f"Shape: {data.shape}") + print(f"Columns: {data.columns}") + print(f"Split counts: {data['split'].value_counts().to_dict(as_series=False)}") + + value_stats = data.select( + [ + pl.col("value").mean().alias("mean"), + pl.col("value").std().alias("std"), + pl.col("value").min().alias("min"), + pl.col("value").max().alias("max"), + ] + ) + print("\nmRFP expression stats:") + print(f" Mean: 
{value_stats['mean'][0]:.4f}") + print(f" Range: [{value_stats['min'][0]:.4f}, {value_stats['max'][0]:.4f}]") + + output_path = SCRIPT_DIR / "mrfp_expression.parquet" + data.select(["id", "ref_seq", "value", "dataset", "split"]).write_parquet(output_path) + print(f"\nWrote {output_path}") + + +if __name__ == "__main__": + main() From baceec33ff70d382a9b125a3fdc9a1c628538b6a Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Wed, 20 May 2026 02:03:15 +0000 Subject: [PATCH 14/17] Add script to plot benchmark numbers --- .../codonfm_native_te/benchmark_plots.py | 193 ++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 bionemo-recipes/recipes/codonfm_native_te/benchmark_plots.py diff --git a/bionemo-recipes/recipes/codonfm_native_te/benchmark_plots.py b/bionemo-recipes/recipes/codonfm_native_te/benchmark_plots.py new file mode 100644 index 0000000000..216aba87f7 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/benchmark_plots.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark plots for distributed training step times. + +Fits Amdahl's-law model t(N) = a + b/N to each configuration, +extrapolates to 32 and 64 nodes, and plots: + 1. Step time vs nodes (measured + extrapolated), per config + 2. 
Bar chart of total training time (1M steps) per (config, node count) + +Run with the project's venv: + /Users/balvisio/.venv/bin/python benchmark_plots.py +""" + +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np +from scipy.optimize import curve_fit + + +TOTAL_STEPS = 1_000_000 +SECONDS_PER_DAY = 86_400 + +# Measured step times (seconds). Keys = config label. +DATA: dict[str, dict[int, float]] = { + "1B - BF16 - BSHD": {1: 2.05, 2: 1.06, 4: 0.57, 8: 0.34, 16: 0.21}, + "1B - MXFP8 - THD": {1: 0.71, 2: 0.34, 4: 0.23, 8: 0.23, 16: 0.19}, + "5B - MXFP8 - THD": {1: 1.42, 2: 0.77, 4: 0.46, 8: 0.31, 16: 0.27}, + "10B - THD": {1: 4.18, 2: 2.22, 4: 1.34, 8: 0.75, 16: 0.53}, + "10B - MXFP8 - THD": {1: 2.88, 2: 1.59, 4: 0.90, 8: 0.60, 16: 0.44}, + "10B - BSHD": {1: 13.19, 2: 6.56, 4: 3.29, 8: 1.67, 16: 0.90, 32: 0.57}, +} + +EXTRAPOLATE_NODES = [32, 64] +ALL_NODES = [1, 2, 4, 8, 16, 32, 64] + + +def amdahl(n, a, b): # noqa: D103 + return a + b / n + + +def power_law(n, a, b): # noqa: D103 + return a * np.power(n, -b) + + +def fit_models(nodes: np.ndarray, times: np.ndarray): # noqa: D103 + (a_amd, b_amd), _ = curve_fit(amdahl, nodes, times, p0=[0.1, times[0]]) + (a_pow, b_pow), _ = curve_fit(power_law, nodes, times, p0=[times[0], 0.8]) + return (a_amd, b_amd), (a_pow, b_pow) + + +def days(step_time_s: float) -> float: # noqa: D103 + return step_time_s * TOTAL_STEPS / SECONDS_PER_DAY + + +def main() -> None: # noqa: D103 + fits = {} + print(f"{'Config':<22} {'Amdahl a (floor)':>18} {'Amdahl b':>12} {'Power a':>10} {'Power b':>10}") + print("-" * 78) + for label, points in DATA.items(): + nodes = np.array(sorted(points.keys()), dtype=float) + times = np.array([points[int(n)] for n in nodes], dtype=float) + (a_amd, b_amd), (a_pow, b_pow) = fit_models(nodes, times) + fits[label] = { + "amdahl": (a_amd, b_amd), + "power": (a_pow, b_pow), + "nodes": nodes, + "times": times, + } + print(f"{label:<22} {a_amd:>18.4f} {b_amd:>12.4f} 
{a_pow:>10.4f} {b_pow:>10.4f}") + + print() + print(f"Extrapolated step times (Amdahl fit) and total days for {TOTAL_STEPS:,} steps:") + header = f"{'Config':<22} " + " ".join(f"{'N=' + str(n):>11}" for n in ALL_NODES) + print(header) + print("-" * len(header)) + extrap_table = {} + for label, f in fits.items(): + a, b = f["amdahl"] + row = [] + extrap_table[label] = {} + for n in ALL_NODES: + if n in DATA[label]: + t = DATA[label][n] + tag = "" + else: + t = amdahl(n, a, b) + tag = "*" + extrap_table[label][n] = t + row.append(f"{t:>7.3f}s{tag:<3}") + print(f"{label:<22} " + " ".join(row)) + print(" * = extrapolated") + + print() + print(f"Total training time (days) for {TOTAL_STEPS:,} steps:") + print(header) + print("-" * len(header)) + for label, by_n in extrap_table.items(): + row = " ".join(f"{days(by_n[n]):>10.2f}d" for n in ALL_NODES) + print(f"{label:<22} " + row) + + # ---------- Plot 1: step time vs nodes (one panel per config) ---------- + fig, axes = plt.subplots(2, 3, figsize=(15, 9)) + axes = axes.flatten() + smooth_n = np.linspace(1, 64, 200) + + for ax, (label, f) in zip(axes, fits.items()): + a, b = f["amdahl"] + ax.plot(smooth_n, amdahl(smooth_n, a, b), "-", color="C0", label=f"Amdahl: {a:.3f} + {b:.3f}/N") + ax.plot(f["nodes"], f["times"], "o", color="C0", markersize=8, label="Measured") + for n_meas, t_meas in zip(f["nodes"], f["times"]): + ax.annotate(f"{t_meas:.3f}s", (n_meas, t_meas), xytext=(6, 6), textcoords="offset points", fontsize=9) + extrap_nodes = [n for n in EXTRAPOLATE_NODES if n not in DATA[label]] + if extrap_nodes: + extrap_x = np.array(extrap_nodes, dtype=float) + ax.plot(extrap_x, amdahl(extrap_x, a, b), "s", color="C3", markersize=9, label="Extrapolated (Amdahl)") + for n in extrap_nodes: + t = amdahl(n, a, b) + ax.annotate(f"{t:.3f}s", (n, t), xytext=(6, 6), textcoords="offset points", fontsize=9) + ax.set_xscale("log", base=2) + ax.set_yscale("log") + ax.set_xticks(ALL_NODES) + ax.set_xticklabels(ALL_NODES) + 
ax.set_xlabel("# nodes") + ax.set_ylabel("Step time (s)") + ax.set_title(label) + ax.grid(True, which="both", alpha=0.3) + ax.legend(fontsize=8, loc="upper right") + + # Hide unused 6th subplot + for ax in axes[len(fits) :]: + ax.axis("off") + fig.suptitle( + "Step time vs # nodes — measured + Amdahl extrapolation to 32/64\nHardware: NVIDIA B300 GPUs", fontsize=13 + ) + fig.tight_layout(rect=[0, 0, 1, 0.97]) + fig.savefig("step_time_vs_nodes.png", dpi=140) + print("\nWrote step_time_vs_nodes.png") + + # ---------- Plot 2: bar chart of total days ---------- + fig2, ax2 = plt.subplots(figsize=(14, 7)) + configs = list(extrap_table.keys()) + n_configs = len(configs) + bar_width = 0.85 / n_configs + x = np.arange(len(ALL_NODES)) + colors = plt.cm.viridis(np.linspace(0.1, 0.85, n_configs)) + + for i, label in enumerate(configs): + days_per_n = [days(extrap_table[label][n]) for n in ALL_NODES] + offset = (i - (n_configs - 1) / 2) * bar_width + bars = ax2.bar(x + offset, days_per_n, bar_width, label=label, color=colors[i]) + for bar, d, n in zip(bars, days_per_n, ALL_NODES): + tag = "*" if n not in DATA[label] else "" + ax2.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.3, + f"{d:.1f}{tag}", + ha="center", + va="bottom", + fontsize=7, + rotation=90, + ) + + ax2.set_xticks(x) + ax2.set_xticklabels([f"{n} nodes" for n in ALL_NODES]) + ax2.set_ylabel(f"Total training time (days) — {TOTAL_STEPS:,} steps") + ax2.set_title( + "Total training time per configuration — Hardware: NVIDIA B300 GPUs\n(* = extrapolated step time, Amdahl fit)" + ) + ax2.legend(loc="upper right") + ax2.grid(True, axis="y", alpha=0.3) + fig2.tight_layout() + fig2.savefig("total_days_per_config.png", dpi=140) + print("Wrote total_days_per_config.png") + + +if __name__ == "__main__": + main() From 77b25f8c47d9e505a8405c9bff2e255bffebae0f Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Wed, 20 May 2026 02:59:22 +0000 Subject: [PATCH 15/17] Add scripts and results to evaluate PTL based 
model This commit should be deleted eventually --- .../evaluation/mrfp/evaluate_rf.py | 324 ++++++++++++++++ .../evaluation/mrfp/preprocess.py | 100 +++++ .../evaluation/ribonn/TE-1B/metrics.csv | 11 + .../ribonn/TE-1B/metrics_summary.csv | 2 + .../evaluation/ribonn/TE-5B/metrics.csv | 11 + .../ribonn/TE-5B/metrics_summary.csv | 2 + .../evaluation/ribonn/TE-80M/metrics.csv | 11 + .../ribonn/TE-80M/metrics_summary.csv | 2 + .../evaluation/ribonn/evaluate_rf.py | 346 ++++++++++++++++++ .../evaluation/ribonn/preprocess.py | 90 +++++ 10 files changed, 899 insertions(+) create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/mrfp/evaluate_rf.py create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/mrfp/preprocess.py create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-1B/metrics.csv create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-1B/metrics_summary.csv create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-5B/metrics.csv create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-5B/metrics_summary.csv create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-80M/metrics.csv create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-80M/metrics_summary.csv create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/evaluate_rf.py create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/preprocess.py diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/mrfp/evaluate_rf.py b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/mrfp/evaluate_rf.py new file mode 100644 index 0000000000..dcef44920f --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/mrfp/evaluate_rf.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reproduce the mRFP Expression Random Forest eval against a PTL CodonFM model. + +Mirror of `codonfm_native_te/evaluation/mrfp/evaluate_rf.py`, but using the PyTorch-Lightning +`EncodonInference` wrapper from this recipe (which is what the published +`nvidia/NV-CodonFM-Encodon-TE-*` checkpoints on Hugging Face Hub were trained with). + +Loads CDS sequences + labels + train/val/test splits from `mrfp_expression.parquet` (produced +by `preprocess.py` in the native_te recipe — the parquet schema is identical), extracts +CLS-token embeddings, tunes a RandomForestRegressor with GridSearchCV on a predefined +train/val split, refits on train only, and writes per-split metrics to a CSV. + +Embeddings are cached to disk and validated on load against the current ids, targets, splits, +sequence hash, and use_transformer_engine flag, so re-running the script only re-tunes the +RF unless any of those inputs change. 
+ +Usage: + python evaluate_rf.py \ + --model-name-or-path nvidia/NV-CodonFM-Encodon-TE-80M-v1 +""" + +import argparse +import hashlib +import sys +from pathlib import Path + +import numpy as np +import polars as pl +import torch +from scipy.stats import spearmanr +from sklearn.ensemble import RandomForestRegressor +from sklearn.metrics import r2_score +from sklearn.model_selection import GridSearchCV +from tqdm import tqdm + + +# The `src.*` PTL inference modules live at the recipe root. +RECIPE_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(RECIPE_ROOT)) + +from src.data.metadata import MetadataFields # noqa: E402 +from src.inference.encodon import EncodonInference # noqa: E402 +from src.inference.task_types import TaskTypes # noqa: E402 +from src.utils.load_checkpoint import download_checkpoint # noqa: E402 + + +SCRIPT_DIR = Path(__file__).parent +DEFAULT_CHECKPOINT_CACHE = SCRIPT_DIR / "checkpoints" + +# Mirrors notebook section 5 param_grid (single-point grid). +RF_PARAM_GRID = { + "n_estimators": [1000], + "max_depth": [10], + "min_samples_split": [25], + "min_samples_leaf": [2], +} + + +def _slugify(s: str) -> str: + """Make a model name/path safe to embed in a filename.""" + return s.replace("/", "__").replace(":", "_") + + +def _resolve_model_path(model_name_or_path: str) -> str: + """Return a local checkpoint dir, downloading from HF Hub if `model_name_or_path` isn't local.""" + p = Path(model_name_or_path) + if p.is_dir(): + return str(p) + local_dir = DEFAULT_CHECKPOINT_CACHE / p.name + print(f"Downloading checkpoint {model_name_or_path} -> {local_dir}") + return download_checkpoint(repo_id=model_name_or_path, local_dir=str(local_dir)) + + +def extract_embeddings( + encodon_model: EncodonInference, + sequences: list[str], + batch_size: int, +) -> np.ndarray: + """Return CLS embeddings for `sequences`, looping verbatim from the mRFP notebook.""" + all_embeddings: list[np.ndarray] = [] + + for i in tqdm(range(0, len(sequences), 
batch_size)): + batch_seqs = sequences[i : i + batch_size] + + batch_items = [] + for raw_seq in batch_seqs: + seq = raw_seq.upper().replace("U", "T") + tokens = encodon_model.tokenizer.tokenize(seq) + input_ids = encodon_model.tokenizer.convert_tokens_to_ids(tokens) + + if len(input_ids) > encodon_model.model.hparams.max_position_embeddings - 2: + input_ids = input_ids[: encodon_model.model.hparams.max_position_embeddings - 2] + + input_ids = [encodon_model.tokenizer.cls_token_id, *input_ids, encodon_model.tokenizer.sep_token_id] + attention_mask = [1] * len(input_ids) + + batch_items.append( + { + MetadataFields.INPUT_IDS: input_ids, + MetadataFields.ATTENTION_MASK: attention_mask, + } + ) + + max_len = encodon_model.model.hparams.max_position_embeddings + + padded_input_ids = [] + padded_attention_masks = [] + + for item in batch_items: + input_ids = item[MetadataFields.INPUT_IDS] + attention_mask = item[MetadataFields.ATTENTION_MASK] + + pad_len = max_len - len(input_ids) + input_ids.extend([encodon_model.tokenizer.pad_token_id] * pad_len) + attention_mask.extend([0] * pad_len) + + padded_input_ids.append(input_ids) + padded_attention_masks.append(attention_mask) + + batch = { + MetadataFields.INPUT_IDS: torch.tensor(padded_input_ids, dtype=torch.long).to(encodon_model.device), + MetadataFields.ATTENTION_MASK: torch.tensor(padded_attention_masks, dtype=torch.long).to( + encodon_model.device + ), + } + + output = encodon_model.extract_embeddings(batch) + all_embeddings.append(output.embeddings) + + return np.vstack(all_embeddings) + + +def load_or_extract_embeddings( + df: pl.DataFrame, + model_name_or_path: str, + cache_path: Path, + batch_size: int, + device: str, + use_transformer_engine: bool, + force_extract: bool, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Return (embeddings, ids, targets, splits), loading from cache when it is still valid.""" + df_pd = df.to_pandas() + ids = df_pd["id"].to_numpy() + targets = 
df_pd["value"].to_numpy() + splits = df_pd["split"].to_numpy() + raw_sequences = df_pd["ref_seq"].tolist() + seqs_hash = hashlib.sha256("\n".join(raw_sequences).encode()).hexdigest() + + if cache_path.exists() and not force_extract: + z = np.load(cache_path, allow_pickle=False) + cache_valid = ( + "use_transformer_engine" in z.files + and bool(z["use_transformer_engine"]) == use_transformer_engine + and "seqs_hash" in z.files + and str(z["seqs_hash"].item()) == seqs_hash + and np.array_equal(z["ids"], ids) + and np.array_equal(z["targets"], targets) + and np.array_equal(z["splits"], splits) + ) + if cache_valid: + print(f"Loading cached embeddings from {cache_path}") + return z["embeddings"], z["ids"], z["targets"], z["splits"] + print(f"Cache at {cache_path} is stale; re-extracting.") + + checkpoint_path = _resolve_model_path(model_name_or_path) + print(f"Loading PTL model from {checkpoint_path} (use_transformer_engine={use_transformer_engine})") + encodon_model = EncodonInference( + model_path=checkpoint_path, + task_type=TaskTypes.EMBEDDING_PREDICTION, + use_transformer_engine=use_transformer_engine, + ) + encodon_model.configure_model() + encodon_model.to(device) + encodon_model.eval() + + print(f"Extracting embeddings for {len(raw_sequences)} sequences...") + embeddings = extract_embeddings(encodon_model, raw_sequences, batch_size=batch_size) + + cache_path.parent.mkdir(parents=True, exist_ok=True) + np.savez( + cache_path, + embeddings=embeddings, + ids=ids, + targets=targets, + splits=splits, + seqs_hash=np.array(seqs_hash), + use_transformer_engine=np.array(use_transformer_engine), + ) + print(f"Cached embeddings to {cache_path}") + return embeddings, ids, targets, splits + + +def train_and_evaluate( + embeddings: np.ndarray, + targets: np.ndarray, + splits: np.ndarray, + seed: int, +) -> tuple[list[dict], dict]: + """Tune RF via GridSearchCV on train/val, refit on train, return per-split metrics + best params.""" + train_mask = splits == "train" + val_mask 
= splits == "val" + test_mask = splits == "test" + + x_train, y_train = embeddings[train_mask], targets[train_mask] + x_val, y_val = embeddings[val_mask], targets[val_mask] + x_test, y_test = embeddings[test_mask], targets[test_mask] + print(f"Train: {len(y_train)}, Val: {len(y_val)}, Test: {len(y_test)}") + + x_train_val = np.vstack([x_train, x_val]) + y_train_val = np.concatenate([y_train, y_val]) + train_indices = list(range(len(x_train))) + val_indices = list(range(len(x_train), len(x_train_val))) + cv_splits = [(train_indices, val_indices)] + + rf_base = RandomForestRegressor(random_state=seed, n_jobs=-1) + print("Performing hyperparameter tuning...") + grid_search = GridSearchCV( + estimator=rf_base, + param_grid=RF_PARAM_GRID, + cv=cv_splits, + scoring="r2", + n_jobs=-1, + verbose=1, + ) + grid_search.fit(x_train_val, y_train_val) + rf = grid_search.best_estimator_ + + print("\n=== BEST PARAMETERS ===") + for param, value in grid_search.best_params_.items(): + print(f"{param}: {value}") + print(f"Best validation R²: {grid_search.best_score_:.4f}") + + rf.fit(x_train, y_train) + + rows: list[dict] = [] + for name, x, y in [("train", x_train, y_train), ("val", x_val, y_val), ("test", x_test, y_test)]: + y_pred = rf.predict(x) + r2 = float(r2_score(y, y_pred)) + spearman_r, _ = spearmanr(y, y_pred) + spearman_r = float(spearman_r) + print(f"{name.capitalize():<5} R² = {r2:.4f} | Spearman r = {spearman_r:.4f}") + rows.append({"split": name, "r2": r2, "spearman_r": spearman_r}) + + return rows, grid_search.best_params_ + + +def main() -> None: + """CLI entrypoint for the PTL mRFP RF evaluation.""" + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--model-name-or-path", + required=True, + help=( + "Local checkpoint directory or HF Hub repo id (e.g. nvidia/NV-CodonFM-Encodon-TE-80M-v1). " + "If not a local directory, the checkpoint is downloaded to ./checkpoints/ next to this script." 
+ ), + ) + parser.add_argument( + "--data-path", + type=Path, + default=SCRIPT_DIR / "mrfp_expression.parquet", + help="Parquet file produced by preprocess.py (default: mrfp_expression.parquet next to this script).", + ) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--device", default="cuda") + parser.add_argument("--output-dir", type=Path, default=SCRIPT_DIR) + parser.add_argument( + "--no-te", + action="store_true", + help="Disable TransformerEngine in EncodonInference (default: TE enabled, matching the notebook).", + ) + parser.add_argument( + "--force-extract", + action="store_true", + help="Re-extract embeddings even if a cached file exists.", + ) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + use_transformer_engine = not args.no_te + + df = pl.read_parquet(args.data_path) + print(f"Loaded {len(df)} rows from {args.data_path}") + + n_tag = f"n{len(df)}" + cache_path = args.output_dir / f"embeddings_{_slugify(args.model_name_or_path)}_{n_tag}.npz" + + embeddings, _ids, targets, splits = load_or_extract_embeddings( + df=df, + model_name_or_path=args.model_name_or_path, + cache_path=cache_path, + batch_size=args.batch_size, + device=args.device, + use_transformer_engine=use_transformer_engine, + force_extract=args.force_extract, + ) + print(f"Embeddings shape: {embeddings.shape}") + + print("\n=== TRAINING RANDOM FOREST ===") + rows, _best_params = train_and_evaluate(embeddings, targets, splits, seed=args.seed) + + metrics_path = args.output_dir / "metrics.csv" + pl.DataFrame(rows).write_csv(metrics_path) + print(f"\nWrote per-split metrics to {metrics_path}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/mrfp/preprocess.py b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/mrfp/preprocess.py new file mode 100644 index 0000000000..c3baf95f25 --- /dev/null +++ 
b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/mrfp/preprocess.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download and preprocess the mRFP Expression dataset. + +Extracted verbatim from notebooks/5-EnCodon-Downstream-Task-mRFP-expression.ipynb (section 3), +with the column-normalization step (normally done by data_scripts/preprocess_validation.py) +inlined so this script is self-contained. +""" + +import re +import urllib.request +from pathlib import Path + +import polars as pl + + +SCRIPT_DIR = Path(__file__).parent + +MRFP_DATASET_URL = ( + "https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/master/" + "benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv" +) + +RAW_REF_SEQ_COL = "Sequence" + + +def _camel_to_snake(name: str) -> str: + """Convert a column name to snake_case, matching preprocess_validation.py.""" + name = name.replace(" ", "_") + return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() + + +def main() -> 
None: + """Download the mRFP Expression CSV, normalize columns, and write a parquet handoff.""" + raw_path = SCRIPT_DIR / "mRFP_Expression.csv" + + if not raw_path.exists(): + print(f"Downloading mRFP Expression dataset to {raw_path} ...") + urllib.request.urlretrieve(MRFP_DATASET_URL, raw_path) + print("Download complete.") + else: + print(f"Found existing dataset at {raw_path}.") + + data = pl.read_csv(raw_path) + + rename_map = {col: _camel_to_snake(col) for col in data.columns} + rename_map[RAW_REF_SEQ_COL] = "ref_seq" + data = data.rename(rename_map) + + data = data.with_row_index("id") + data = data.with_columns(pl.col("id").cast(pl.Utf8)) + + before = len(data) + data = data.filter(pl.col("ref_seq").str.len_chars() % 3 == 0) + dropped = before - len(data) + if dropped: + print(f"Dropped {dropped} rows whose ref_seq length is not divisible by 3.") + + if data.is_empty(): + raise ValueError("Output dataframe is empty after filtering.") + if data["ref_seq"].null_count() > 0: + raise ValueError("ref_seq column contains nulls.") + + print(f"Loaded {len(data)} sequences") + print(f"Shape: {data.shape}") + print(f"Columns: {data.columns}") + print(f"Split counts: {data['split'].value_counts().to_dict(as_series=False)}") + + value_stats = data.select( + [ + pl.col("value").mean().alias("mean"), + pl.col("value").std().alias("std"), + pl.col("value").min().alias("min"), + pl.col("value").max().alias("max"), + ] + ) + print("\nmRFP expression stats:") + print(f" Mean: {value_stats['mean'][0]:.4f}") + print(f" Range: [{value_stats['min'][0]:.4f}, {value_stats['max'][0]:.4f}]") + + output_path = SCRIPT_DIR / "mrfp_expression.parquet" + data.select(["id", "ref_seq", "value", "dataset", "split"]).write_parquet(output_path) + print(f"\nWrote {output_path}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-1B/metrics.csv b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-1B/metrics.csv new file mode 
100644 index 0000000000..d3c4814998 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-1B/metrics.csv @@ -0,0 +1,11 @@ +fold,r2,pearson_r,mse,rmse +0,0.4955365464728886,0.7071278471506246,0.23847743897702264,0.48834151879296794 +1,0.5336473141998186,0.7369823654429968,0.2205653618910011,0.4696438670854769 +2,0.4927357198187039,0.7044137508532713,0.24958914940553403,0.4995889804684787 +3,0.5524338994512115,0.753821564015981,0.21669001946047656,0.46549975237423763 +4,0.5439863866689088,0.7465496747905151,0.22101799798656677,0.4701255130138831 +5,0.5245497270666517,0.7276394563415014,0.2229086390889875,0.472132014471575 +6,0.5492112441486938,0.7471346572764762,0.21787474961790235,0.4667705535034342 +7,0.48575018004592707,0.6994166661654686,0.2317518899220776,0.48140615899890354 +8,0.5092752705640416,0.7142507763427919,0.21491118777913248,0.463585146202003 +9,0.5078336749784663,0.7151159541185469,0.23244195165783943,0.4821223409652776 diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-1B/metrics_summary.csv b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-1B/metrics_summary.csv new file mode 100644 index 0000000000..0dd2ba63d1 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-1B/metrics_summary.csv @@ -0,0 +1,2 @@ +mean_r2,std_r2,mean_pearson_r,std_pearson_r,mean_rmse +0.5194959963415313,0.023372850389493077,0.7252452712498173,0.018830589379758456,0.47604919764521614 diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-5B/metrics.csv b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-5B/metrics.csv new file mode 100644 index 0000000000..74adfc879f --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-5B/metrics.csv @@ -0,0 +1,11 @@ +fold,r2,pearson_r,mse,rmse +0,0.4998078443342312,0.7109200312209408,0.23645824775918697,0.48626972737276886 +1,0.5118285155831124,0.7176781006143864,0.23088474325076175,0.48050467557637955 
+2,0.48752652218295545,0.7001025769870611,0.25215222994912584,0.5021476176873946 +3,0.5636104706174017,0.7590740369438052,0.21127886025844234,0.45965080252126433 +4,0.5330031987626446,0.7371841682459039,0.22634126494963064,0.475753365673466 +5,0.5107191778397523,0.7166889860027131,0.2293929111181122,0.47894980020677763 +6,0.532024566817876,0.7338281449169276,0.22618139651539235,0.47558531991157205 +7,0.4904093535700814,0.7026037869959859,0.2296521861831553,0.47922039416447554 +8,0.49961372683715,0.7076476770227154,0.21914242723697053,0.4681265077273135 +9,0.5108825164907425,0.7168515772282419,0.231002034631054,0.480626710276337 diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-5B/metrics_summary.csv b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-5B/metrics_summary.csv new file mode 100644 index 0000000000..58c3620c6c --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-5B/metrics_summary.csv @@ -0,0 +1,2 @@ +mean_r2,std_r2,mean_pearson_r,std_pearson_r,mean_rmse +0.5139425893035947,0.021995717796660087,0.7202579086178682,0.01724081322466779,0.47879915432797415 diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-80M/metrics.csv b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-80M/metrics.csv new file mode 100644 index 0000000000..cee5eee74e --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-80M/metrics.csv @@ -0,0 +1,11 @@ +fold,r2,pearson_r,mse,rmse +0,0.4357315323337414,0.6815745696655744,0.2667493514618028,0.5164778324979715 +1,0.4350455498274717,0.6756935154376349,0.26719988229600716,0.5169138054801856 +2,0.4205079376486369,0.6620494362796442,0.2851273716293269,0.5339731937366584 +3,0.4410580162058856,0.6875341625364932,0.27061287527610955,0.5202046474956846 +4,0.44994010236681425,0.6967706358911174,0.26659979832512964,0.5163330304417195 +5,0.46259637226738337,0.7029817002160667,0.25195465881277485,0.5019508529854042 
+6,0.4525194912038706,0.6958287267216597,0.26460770643974096,0.5144003367414731 +7,0.4090697944359645,0.6483720379942055,0.2663086823515812,0.5160510462653682 +8,0.41571097340467045,0.652089153372791,0.2558873461629833,0.505853087529357 +9,0.4349917178592828,0.6745878003467807,0.2668439938427022,0.5165694472601938 diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-80M/metrics_summary.csv b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-80M/metrics_summary.csv new file mode 100644 index 0000000000..84726cc89f --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/TE-80M/metrics_summary.csv @@ -0,0 +1,2 @@ +mean_r2,std_r2,mean_pearson_r,std_pearson_r,mean_rmse +0.4357171487553721,0.0160569448893339,0.6777481738461968,0.01793850835345407,0.5159352349469998 diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/evaluate_rf.py b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/evaluate_rf.py new file mode 100644 index 0000000000..e9a6f6b6ea --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/evaluate_rf.py @@ -0,0 +1,346 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reproduce the RiboNN Random Forest TE-prediction eval against a PTL CodonFM model. 
+ +Mirror of `codonfm_native_te/evaluation/ribonn/evaluate_rf.py`, but using the PyTorch-Lightning +`EncodonInference` wrapper from this recipe (which is what the published +`nvidia/NV-CodonFM-Encodon-TE-*` checkpoints on Hugging Face Hub were trained with). + +Loads CDS sequences + labels from `ribonn_cds.parquet` (produced by `preprocess.py` in the +native_te recipe — the parquet schema is identical), extracts CLS-token embeddings, runs +leave-one-fold-out cross-validation with a RandomForestRegressor, and writes per-fold and +aggregate metrics to CSV. + +Embeddings are cached to disk and validated on load against the current ids, targets, folds, +sequence hash, and use_transformer_engine flag, so re-running the script only re-tunes the +RF unless any of those inputs change. + +Usage: + python evaluate_rf.py \ + --model-name-or-path nvidia/NV-CodonFM-Encodon-TE-80M-v1 \ + --demo-size 500 +""" + +import argparse +import hashlib +import sys +from pathlib import Path + +import numpy as np +import polars as pl +import torch +from scipy.stats import pearsonr +from sklearn.ensemble import RandomForestRegressor +from sklearn.metrics import mean_squared_error, r2_score +from sklearn.model_selection import train_test_split +from tqdm import tqdm + + +# The `src.*` PTL inference modules live at the recipe root. 
+RECIPE_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(RECIPE_ROOT)) + +from src.data.metadata import MetadataFields # noqa: E402 +from src.inference.encodon import EncodonInference # noqa: E402 +from src.inference.task_types import TaskTypes # noqa: E402 +from src.utils.load_checkpoint import download_checkpoint # noqa: E402 + + +SCRIPT_DIR = Path(__file__).parent +DEFAULT_CHECKPOINT_CACHE = SCRIPT_DIR / "checkpoints" + + +def _slugify(s: str) -> str: + """Make a model name/path safe to embed in a filename.""" + return s.replace("/", "__").replace(":", "_") + + +def _resolve_model_path(model_name_or_path: str) -> str: + """Return a local checkpoint dir, downloading from HF Hub if `model_name_or_path` isn't local.""" + p = Path(model_name_or_path) + if p.is_dir(): + return str(p) + local_dir = DEFAULT_CHECKPOINT_CACHE / p.name + print(f"Downloading checkpoint {model_name_or_path} -> {local_dir}") + return download_checkpoint(repo_id=model_name_or_path, local_dir=str(local_dir)) + + +def extract_embeddings( + encodon_model: EncodonInference, + sequences: list[str], + batch_size: int, +) -> np.ndarray: + """Return CLS embeddings for `sequences`, looping verbatim from the RiboNN notebook.""" + all_embeddings: list[np.ndarray] = [] + + for i in tqdm(range(0, len(sequences), batch_size)): + batch_seqs = sequences[i : i + batch_size] + + # Prepare batch + batch_items = [] + for raw_seq in batch_seqs: + seq = raw_seq.upper().replace("U", "T") + tokens = encodon_model.tokenizer.tokenize(seq) + input_ids = encodon_model.tokenizer.convert_tokens_to_ids(tokens) + + # Truncate if needed + if len(input_ids) > encodon_model.model.hparams.max_position_embeddings - 2: # Leave room for CLS/SEP + input_ids = input_ids[: encodon_model.model.hparams.max_position_embeddings - 2] + + # Add special tokens + input_ids = [encodon_model.tokenizer.cls_token_id, *input_ids, encodon_model.tokenizer.sep_token_id] + attention_mask = [1] * len(input_ids) + + 
batch_items.append( + { + MetadataFields.INPUT_IDS: input_ids, + MetadataFields.ATTENTION_MASK: attention_mask, + } + ) + + # Pad batch + max_len = encodon_model.model.hparams.max_position_embeddings + + padded_input_ids = [] + padded_attention_masks = [] + + for item in batch_items: + input_ids = item[MetadataFields.INPUT_IDS] + attention_mask = item[MetadataFields.ATTENTION_MASK] + + # Pad + pad_len = max_len - len(input_ids) + input_ids.extend([encodon_model.tokenizer.pad_token_id] * pad_len) + attention_mask.extend([0] * pad_len) + + padded_input_ids.append(input_ids) + padded_attention_masks.append(attention_mask) + + # Create batch tensor + batch = { + MetadataFields.INPUT_IDS: torch.tensor(padded_input_ids, dtype=torch.long).to(encodon_model.device), + MetadataFields.ATTENTION_MASK: torch.tensor(padded_attention_masks, dtype=torch.long).to( + encodon_model.device + ), + } + + # Extract embeddings + output = encodon_model.extract_embeddings(batch) + all_embeddings.append(output.embeddings) + + return np.vstack(all_embeddings) + + +def load_or_extract_embeddings( + df: pl.DataFrame, + model_name_or_path: str, + cache_path: Path, + batch_size: int, + device: str, + use_transformer_engine: bool, + force_extract: bool, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Return (embeddings, ids, targets, folds), loading from cache when it is still valid.""" + df_pd = df.to_pandas() + ids = df_pd["id"].to_numpy() + targets = df_pd["mean_te"].to_numpy() + folds = df_pd["fold"].to_numpy() + raw_sequences = df_pd["cds_sequence"].tolist() + seqs_hash = hashlib.sha256("\n".join(raw_sequences).encode()).hexdigest() + + if cache_path.exists() and not force_extract: + z = np.load(cache_path, allow_pickle=False) + cache_valid = ( + "use_transformer_engine" in z.files + and bool(z["use_transformer_engine"]) == use_transformer_engine + and "seqs_hash" in z.files + and str(z["seqs_hash"].item()) == seqs_hash + and np.array_equal(z["ids"], ids) + and 
np.array_equal(z["targets"], targets) + and np.array_equal(z["folds"], folds) + ) + if cache_valid: + print(f"Loading cached embeddings from {cache_path}") + return z["embeddings"], z["ids"], z["targets"], z["folds"] + print(f"⚠️ Cache at {cache_path} is stale; re-extracting.") + + checkpoint_path = _resolve_model_path(model_name_or_path) + print(f"Loading PTL model from {checkpoint_path} (use_transformer_engine={use_transformer_engine})") + encodon_model = EncodonInference( + model_path=checkpoint_path, + task_type=TaskTypes.EMBEDDING_PREDICTION, + use_transformer_engine=use_transformer_engine, + ) + encodon_model.configure_model() + encodon_model.to(device) + encodon_model.eval() + + print(f"Extracting embeddings for {len(raw_sequences)} sequences...") + embeddings = extract_embeddings(encodon_model, raw_sequences, batch_size=batch_size) + + cache_path.parent.mkdir(parents=True, exist_ok=True) + np.savez( + cache_path, + embeddings=embeddings, + ids=ids, + targets=targets, + folds=folds, + seqs_hash=np.array(seqs_hash), + use_transformer_engine=np.array(use_transformer_engine), + ) + print(f"✅ Cached embeddings to {cache_path}") + return embeddings, ids, targets, folds + + +def cross_validate( + embeddings: np.ndarray, + targets: np.ndarray, + folds: np.ndarray, + seed: int, +) -> list[dict]: + """Run leave-one-fold-out CV with RandomForestRegressor; return per-fold metrics.""" + rows: list[dict] = [] + for fold in np.unique(folds): + train_mask = folds != fold + test_mask = ~train_mask + x_train, x_test = embeddings[train_mask], embeddings[test_mask] + y_train, y_test = targets[train_mask], targets[test_mask] + + rf = RandomForestRegressor( + n_estimators=500, + max_depth=15, + min_samples_split=2, + random_state=seed, + n_jobs=-1, + ) + rf.fit(x_train, y_train) + y_pred = rf.predict(x_test) + + r2 = r2_score(y_test, y_pred) + pearson_r, _ = pearsonr(y_test, y_pred) + mse = mean_squared_error(y_test, y_pred) + rmse = float(np.sqrt(mse)) + + print(f"Fold {fold}: 
R² = {r2:.4f}, r = {pearson_r:.4f}, RMSE = {rmse:.4f}") + rows.append( + {"fold": int(fold), "r2": float(r2), "pearson_r": float(pearson_r), "mse": float(mse), "rmse": rmse} + ) + + return rows + + +def main() -> None: + """CLI entrypoint for the PTL RiboNN RF evaluation.""" + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--model-name-or-path", + required=True, + help=( + "Local checkpoint directory or HF Hub repo id (e.g. nvidia/NV-CodonFM-Encodon-TE-80M-v1). " + "If not a local directory, the checkpoint is downloaded to ./checkpoints/ next to this script." + ), + ) + parser.add_argument( + "--data-path", + type=Path, + default=SCRIPT_DIR / "ribonn_cds.parquet", + help="Parquet file produced by preprocess.py (default: ribonn_cds.parquet next to this script).", + ) + parser.add_argument( + "--demo-size", + type=int, + default=None, + help="If set, stratified-sample this many rows by 'fold'. Notebook uses 500.", + ) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--device", default="cuda") + parser.add_argument("--output-dir", type=Path, default=SCRIPT_DIR) + parser.add_argument( + "--no-te", + action="store_true", + help="Disable TransformerEngine in EncodonInference (default: TE enabled, matching the notebook).", + ) + parser.add_argument( + "--force-extract", + action="store_true", + help="Re-extract embeddings even if a cached file exists.", + ) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + use_transformer_engine = not args.no_te + + df = pl.read_parquet(args.data_path) + print(f"Loaded {len(df)} rows from {args.data_path}") + + # Subsample stratified by fold (mirrors notebook section 4). 
+ if args.demo_size is not None and args.demo_size < len(df): + print(f"=== SUBSAMPLING DATA to {args.demo_size} rows ===") + sample_fraction = args.demo_size / len(df) + _, sampled_pd = train_test_split( + df.to_pandas(), + test_size=sample_fraction, + stratify=df["fold"].to_numpy(), + random_state=args.seed, + ) + df = pl.from_pandas(sampled_pd) + + n_tag = f"n{len(df)}" + cache_path = args.output_dir / f"embeddings_{_slugify(args.model_name_or_path)}_{n_tag}.npz" + + embeddings, _ids, targets, folds = load_or_extract_embeddings( + df=df, + model_name_or_path=args.model_name_or_path, + cache_path=cache_path, + batch_size=args.batch_size, + device=args.device, + use_transformer_engine=use_transformer_engine, + force_extract=args.force_extract, + ) + print(f"Embeddings shape: {embeddings.shape}") + + print("\n=== TRAINING RANDOM FOREST ===") + rows = cross_validate(embeddings, targets, folds, seed=args.seed) + + metrics_path = args.output_dir / "metrics.csv" + pl.DataFrame(rows).write_csv(metrics_path) + print(f"\n✅ Wrote per-fold metrics to {metrics_path}") + + # Summary stats — mirrors notebook section 5: Mean RMSE uses sqrt(mean(MSE)), + # not mean(sqrt(MSE)). 
+ r2 = np.array([r["r2"] for r in rows]) + pr = np.array([r["pearson_r"] for r in rows]) + mse = np.array([r["mse"] for r in rows]) + summary = { + "mean_r2": float(r2.mean()), + "std_r2": float(r2.std()), + "mean_pearson_r": float(pr.mean()), + "std_pearson_r": float(pr.std()), + "mean_rmse": float(np.sqrt(mse.mean())), + } + summary_path = args.output_dir / "metrics_summary.csv" + pl.DataFrame([summary]).write_csv(summary_path) + + print("\n=== CROSS-VALIDATION RESULTS ===") + print(f"Mean R²: {summary['mean_r2']:.4f} ± {summary['std_r2']:.4f}") + print(f"Mean Pearson r: {summary['mean_pearson_r']:.4f} ± {summary['std_pearson_r']:.4f}") + print(f"Mean RMSE: {summary['mean_rmse']:.4f}") + print(f"✅ Wrote summary metrics to {summary_path}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/preprocess.py b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/preprocess.py new file mode 100644 index 0000000000..f8db7e8de9 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/evaluation/ribonn/preprocess.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download and preprocess the RiboNN translation efficiency dataset. + +Extracted verbatim from notebooks/4-EnCodon-Downstream-Task-riboNN.ipynb (section 3). 
+""" + +import urllib.request +from pathlib import Path + +import polars as pl + + +SCRIPT_DIR = Path(__file__).parent + +# Source URL for the TE dataset +TE_DATASET_URL = "https://raw.githubusercontent.com/CenikLab/TE_classic_ML/refs/heads/main/data/data_with_human_TE_cellline_all_NA_plain.csv" + + +def main() -> None: + """Download the RiboNN TE dataset, slice CDS/UTR regions, and write a parquet handoff.""" + # Configurable dataset path + data_path = SCRIPT_DIR / "data_with_human_TE_cellline_all_NA_plain.csv" + + # Download if missing + if not data_path.exists(): + print(f"Downloading TE dataset to {data_path} ...") + urllib.request.urlretrieve(TE_DATASET_URL, data_path) + print("Download complete.") + else: + print(f"Found existing dataset at {data_path}.") + + # Slice the transcript sequence into CDS / 5'UTR / 3'UTR using utr5_size and cds_size, + # and add a row index column 'id'. + data = pl.read_csv(data_path, separator="\t") + data = data.with_columns( + [ + pl.struct(["utr5_size", "cds_size", "tx_sequence"]) + .map_elements( + lambda row: row["tx_sequence"][row["utr5_size"] : row["utr5_size"] + row["cds_size"]], + return_dtype=pl.Utf8, + ) + .alias("cds_sequence"), + pl.struct(["utr5_size", "tx_sequence"]) + .map_elements(lambda row: row["tx_sequence"][: row["utr5_size"]], return_dtype=pl.Utf8) + .alias("utr5_sequence"), + pl.struct(["utr5_size", "cds_size", "tx_sequence"]) + .map_elements(lambda row: row["tx_sequence"][row["utr5_size"] + row["cds_size"] :], return_dtype=pl.Utf8) + .alias("utr3_sequence"), + ] + ).with_row_index("id") + + print(f"✅ Loaded {len(data)} sequences") + print(f"Shape: {data.shape}") + print(f"Key columns: {[col for col in ['id', 'cds_sequence', 'mean_te', 'fold'] if col in data.columns]}") + + # Show basic statistics + te_stats = data.select( + [ + pl.col("mean_te").mean().alias("mean"), + pl.col("mean_te").std().alias("std"), + pl.col("mean_te").min().alias("min"), + pl.col("mean_te").max().alias("max"), + ] + ) + 
print("\nTranslation Efficiency stats:") + print(f" Mean: {te_stats['mean'][0]:.4f}") + print(f" Range: [{te_stats['min'][0]:.4f}, {te_stats['max'][0]:.4f}]") + + # Write only the columns needed by the downstream embedding-extraction script. + output_path = SCRIPT_DIR / "ribonn_cds.parquet" + data.select(["id", "cds_sequence", "mean_te", "fold"]).write_parquet(output_path) + print(f"\n✅ Wrote {output_path}") + + +if __name__ == "__main__": + main() From 0db38e1ef49ed08ccbebb5713f36409f4cc39f02 Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Sat, 23 May 2026 20:46:33 +0000 Subject: [PATCH 16/17] Renamed slurm script and created new SLURM script for PTL-based recipe --- .../slurm/{1b.sh => pretraining.sh} | 2 +- .../codonfm_ptl_te/slurm/pretraining.sh | 236 ++++++++++++++++++ 2 files changed, 237 insertions(+), 1 deletion(-) rename bionemo-recipes/recipes/codonfm_native_te/slurm/{1b.sh => pretraining.sh} (99%) create mode 100755 bionemo-recipes/recipes/codonfm_ptl_te/slurm/pretraining.sh diff --git a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh b/bionemo-recipes/recipes/codonfm_native_te/slurm/pretraining.sh similarity index 99% rename from bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh rename to bionemo-recipes/recipes/codonfm_native_te/slurm/pretraining.sh index c9aa6eb82f..e1a4953bfe 100755 --- a/bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh +++ b/bionemo-recipes/recipes/codonfm_native_te/slurm/pretraining.sh @@ -20,7 +20,7 @@ else fi # ============================================================================ -# Codon 1B +# CodonFM # ============================================================================ BASE_DIR="" diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/slurm/pretraining.sh b/bionemo-recipes/recipes/codonfm_ptl_te/slurm/pretraining.sh new file mode 100755 index 0000000000..a4a27b7761 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/slurm/pretraining.sh @@ -0,0 +1,236 @@ +#!/bin/bash +#SBATCH --account= 
+#SBATCH --nodes=1 +#SBATCH --partition= +#SBATCH --ntasks-per-node=1 +#SBATCH --time=03:55:00 +#SBATCH --mem=0 +#SBATCH --job-name= +#SBATCH --mail-type=FAIL +#SBATCH --overcommit +#SBATCH --exclusive +set -euxo pipefail + +# ============================================================================ +# This script is adapted from the experiment scripts here: +# https://gitlab-master.nvidia.com/bio-foundation-models/codon-fm/-/tree/405b2315836a9c1c1ae0c5e41d5abcf4f24d6aa8/experiment_scripts/pretraining/encodon_filtered/mlm +# +# Modifications: +# - 'num_jobs' is not supported in the PTL recipe in bionemo-recipes. +# - '--sharded-state-dict' is not supported in the PTL recipe in bionemo-recipes. It is always 'sharded'. +# - Added support for selecting the sequence packing method (thd or bshd). +# - Added support for selecting the distributed strategy (fsdp or ddp). +# - Added support for selecting the gradient accumulation steps to keep the global batch size constant. +# - Added support for selecting the attention backend (xformers or pytorch SDPA). +# ============================================================================ + +# Establish or inherit chain ID: manual launch picks SLURM_JOB_ID; trap-resubmit inherits via --export. 
+if [ -z "${CHAIN_ID:-}" ]; then + export CHAIN_ID="${SLURM_JOB_ID}" + echo "Starting NEW chain: CHAIN_ID=${CHAIN_ID}" +else + echo "Continuing chain ${CHAIN_ID} (current job ${SLURM_JOB_ID})" +fi + +# ============================================================================ +# CodonFM +# ============================================================================ + +BASE_DIR="" +CONTAINER="" +DATA_DIR="${BASE_DIR}/data" +CODE_MOUNT="/workspace/bionemo" + + +: "${WANDB_API_KEY:?Set WANDB_API_KEY in ~/.bash_profile}" +: "${HUGGING_FACE_HUB_TOKEN:?Set HUGGING_FACE_HUB_TOKEN in ~/.bash_profile}" +: "${CLUSTER_NAME:?Set CLUSTER_NAME in ~/.bash_profile}" + +export GLOBAL_BATCH_SIZE=1536 +export MICRO_BATCH_SIZE=96 + +# Experiment parameters +export CONFIG_NAME=encodon_xx +export NPROC_PER_NODE=8 +export DIST_STRATEGY=ddp # fsdp or ddp + +# Training +export NUM_TRAIN_STEPS=100 +export LEARNING_RATE=7.5e-5 +export NUM_WORKERS=12 +export USE_SEQUENCE_PACKING=False + +export PRECISION=bf16-mixed + +# Logging / W&B +export LOGGER_FREQUENCY=10 +export WANDB_PROJECT= + +# Attn-backend +export USE_XFORMERS=1 +export USE_TRANSFORMER_ENGINE=0 + +# Derived: build wandb run name from model size, batch size, and precision recipe +MODEL_SIZE="${CONFIG_NAME##*_}" +PRECISION_TAG="${PRECISION}" + +if [ "${USE_SEQUENCE_PACKING}" = "True" ]; then + BATCH_TYPE_TAG="thd" +else + BATCH_TYPE_TAG="bshd" +fi + +# Derive grad accumulation from GBS / (MBS * GPUs). 
+TOTAL_GPUS=$(( NPROC_PER_NODE * SLURM_JOB_NUM_NODES )) +TOTAL_PER_STEP=$(( MICRO_BATCH_SIZE * TOTAL_GPUS )) +if [ "${TOTAL_PER_STEP}" -eq 0 ] || [ "$(( GLOBAL_BATCH_SIZE % TOTAL_PER_STEP ))" -ne 0 ]; then + echo "ERROR: GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} must be a positive multiple of MICRO_BATCH_SIZE*NPROC_PER_NODE*NODES=${TOTAL_PER_STEP}" >&2 + exit 1 +fi +export GRAD_ACC_STEPS=$(( GLOBAL_BATCH_SIZE / TOTAL_PER_STEP )) +echo "Batch sizing: GBS=${GLOBAL_BATCH_SIZE}, MBS=${MICRO_BATCH_SIZE}, NPROC=${NPROC_PER_NODE}, NODES=${SLURM_JOB_NUM_NODES}, GRAD_ACC=${GRAD_ACC_STEPS}" + +export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}_chain_${CHAIN_ID}" + +# Mounts +RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}" +CKPT_DIR="${BASE_DIR}/checkpoints/${WANDB_RUN_NAME}" + +mkdir -p "${RESULTS_DIR}" "${CKPT_DIR}" + +MOUNTS="${DATA_DIR}:${CODE_MOUNT}/data,${RESULTS_DIR}:${CODE_MOUNT}/results,${CKPT_DIR}:${CODE_MOUNT}/checkpoints" + +# Resolve head node on the host (scontrol is not available inside the container). +MASTER_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) +MASTER_PORT=29500 + + +read -r -d '' COMMAND <<'OUTER_EOF' || true +set -euxo pipefail + +echo "=========================================" +echo "CodonFM ${CONFIG_NAME} - STRATEGY: ${DIST_STRATEGY} - PRECISION: ${PRECISION_TAG} - CLUSTER: ${CLUSTER_NAME}" +echo "Job ID: ${SLURM_JOB_ID}" +echo "Nodes: ${SLURM_JOB_NUM_NODES}" +echo "=========================================" + +export USE_XFORMERS=${USE_XFORMERS:-0} +if [ "${USE_XFORMERS}" = "1" ]; then + echo "Using Xformers" +else + echo "Using PyTorch SDPA attention" +fi + +# cuDNN fused-attn sub-backend 1 OOMs on Blackwell (sm_103) with THD+padding (TE 2.12 / cuDNN 9.19); force flash-attn varlen. 
+if [ "${USE_SEQUENCE_PACKING}" = "True" ]; then + export NVTE_FUSED_ATTN=0 + EXTRA_ARGS="--collate_fn thd --attn_input_format thd" +else + EXTRA_ARGS="--collate_fn bshd --attn_input_format bshd" +fi + +# Pick training script based on distributed strategy. +case "${DIST_STRATEGY}" in + fsdp) + EXTRA_ARGS="${EXTRA_ARGS} --enable_fsdp" + ;; + ddp) + EXTRA_ARGS="${EXTRA_ARGS}" + ;; + *) + echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2 + exit 1 + ;; +esac + +if [ "${PRECISION}" = "bf16-mixed" ]; then + EXTRA_ARGS="${EXTRA_ARGS} --bf16" +fi + +if [ "${USE_TRANSFORMER_ENGINE}" = "1" ]; then + EXTRA_ARGS="${EXTRA_ARGS} --use_transformer_engine" +fi + +torchrun \ + --nproc_per_node=${NPROC_PER_NODE} \ + --rdzv_id=${SLURM_JOB_ID} \ + --rdzv_backend=c10d \ + --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \ + --nnodes=${SLURM_JOB_NUM_NODES} \ + --node-rank=${SLURM_NODEID} \ + -m src.runner pretrain \ + --exp_name ${WANDB_RUN_NAME} \ + --model_name ${CONFIG_NAME} \ + --data_path /workspace/bionemo/data/processed_unfiltered/ \ + --process_item mlm_memmap \ + --dataset_name CodonMemmapDataset \ + --lr ${LEARNING_RATE} \ + --num_gpus ${NPROC_PER_NODE} \ + --num_nodes ${SLURM_JOB_NUM_NODES} \ + --train_batch_size ${MICRO_BATCH_SIZE} \ + --val_batch_size ${MICRO_BATCH_SIZE} \ + --num_workers ${NUM_WORKERS} \ + ${EXTRA_ARGS} \ + --split_name_prefix nopathogen \ + --taxid_exclusion_file /workspace/bionemo/data/taxids_to_remove.json \ + --enable_wandb \ + --project_name ${WANDB_PROJECT} \ + --entity clara-discovery \ + --gradient_accumulation_steps ${GRAD_ACC_STEPS} \ + --max_steps ${NUM_TRAIN_STEPS} \ + --log_every_n_steps ${LOGGER_FREQUENCY} + +echo "=========================================" +echo "Training complete!" +echo "=========================================" +OUTER_EOF + +# Inject environment variables into the command. 
+COMMAND="export DIST_STRATEGY=\"${DIST_STRATEGY}\"; ${COMMAND}" +COMMAND="export PRECISION_TAG=\"${PRECISION_TAG}\"; ${COMMAND}" +COMMAND="export CLUSTER_NAME=\"${CLUSTER_NAME}\"; ${COMMAND}" +COMMAND="export NPROC_PER_NODE=\"${NPROC_PER_NODE}\"; ${COMMAND}" +COMMAND="export CONFIG_NAME=\"${CONFIG_NAME}\"; ${COMMAND}" +COMMAND="export LOGGER_FREQUENCY=\"${LOGGER_FREQUENCY}\"; ${COMMAND}" +COMMAND="export NUM_TRAIN_STEPS=\"${NUM_TRAIN_STEPS}\"; ${COMMAND}" +COMMAND="export GLOBAL_BATCH_SIZE=\"${GLOBAL_BATCH_SIZE}\"; ${COMMAND}" +COMMAND="export MICRO_BATCH_SIZE=\"${MICRO_BATCH_SIZE}\"; ${COMMAND}" +COMMAND="export GRAD_ACC_STEPS=\"${GRAD_ACC_STEPS}\"; ${COMMAND}" +COMMAND="export LEARNING_RATE=\"${LEARNING_RATE}\"; ${COMMAND}" +COMMAND="export NUM_WORKERS=\"${NUM_WORKERS}\"; ${COMMAND}" +COMMAND="export USE_SEQUENCE_PACKING=\"${USE_SEQUENCE_PACKING}\"; ${COMMAND}" +COMMAND="export PRECISION=\"${PRECISION}\"; ${COMMAND}" +COMMAND="export WANDB_RUN_NAME=\"${WANDB_RUN_NAME}\"; ${COMMAND}" +COMMAND="export WANDB_PROJECT=\"${WANDB_PROJECT}\"; ${COMMAND}" +COMMAND="export USE_XFORMERS=\"${USE_XFORMERS}\"; ${COMMAND}" +COMMAND="export MASTER_ADDR=\"${MASTER_ADDR}\"; ${COMMAND}" +COMMAND="export MASTER_PORT=\"${MASTER_PORT}\"; ${COMMAND}" +COMMAND="export USE_TRANSFORMER_ENGINE=\"${USE_TRANSFORMER_ENGINE}\"; ${COMMAND}" +COMMAND="export WANDB_API_KEY=\"${WANDB_API_KEY}\"; ${COMMAND}" +COMMAND="export HUGGING_FACE_HUB_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}" +COMMAND="export HF_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}" + + +echo "Launching: ${WANDB_RUN_NAME}" + +# AUTO-CHAIN: resubmit on timeout. +trap ' + rc=$? + if [ "$rc" -eq 143 ] || [ "$rc" -eq 137 ]; then + echo "Timed out (rc=$rc) — resubmitting chain ${CHAIN_ID}." + sbatch --dependency=singleton \ + --export=ALL,CHAIN_ID="${CHAIN_ID}" \ + "${BASH_SOURCE[0]}" + elif [ "$rc" -eq 0 ]; then + echo "Training finished cleanly — chain ${CHAIN_ID} ends." 
+ else + echo "Real error (rc=$rc) — chain ${CHAIN_ID} ends so you can investigate." + fi + ' EXIT + +srun \ + --output "${RESULTS_DIR}/slurm-%j-%n.out" \ + --error "${RESULTS_DIR}/error-%j-%n.out" \ + --container-image "${CONTAINER}" \ + --container-mounts "${MOUNTS}" \ + bash -c "${COMMAND}" From 1431df393c8f464dfd2ede19fb7e3de39e51613b Mon Sep 17 00:00:00 2001 From: Bruno Alvisio Date: Sun, 24 May 2026 22:47:57 +0000 Subject: [PATCH 17/17] Add new PTL callback to measure the wall-clock time per optimizer step to match native recipe --- .../recipes/codonfm_ptl_te/src/config.py | 2 + .../codonfm_ptl_te/src/utils/__init__.py | 2 + .../src/utils/interval_step_timing.py | 59 +++++++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 bionemo-recipes/recipes/codonfm_ptl_te/src/utils/interval_step_timing.py diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/src/config.py b/bionemo-recipes/recipes/codonfm_ptl_te/src/config.py index 9ccf066a51..ed4d990b30 100644 --- a/bionemo-recipes/recipes/codonfm_ptl_te/src/config.py +++ b/bionemo-recipes/recipes/codonfm_ptl_te/src/config.py @@ -35,6 +35,7 @@ from src.tokenizer import Tokenizer from src.utils.fsdp_config import get_fsdp_strategy from src.utils.grad_norm_callback import GradientNormLogger +from src.utils.interval_step_timing import IntervalStepTimingCallback from src.utils.pred_writer import PredWriter from src.utils.scheduler import linear_scheduler_with_warmup_lr_lambda from src.utils.throughput_logger import ThroughputLogger @@ -136,6 +137,7 @@ def get_callbacks_config(args: Any) -> Dict[str, fdl.Config]: "lr_monitor": fdl.Config(LearningRateMonitor, logging_interval="step", log_weight_decay=True), "grad_norm_callback": fdl.Config(GradientNormLogger, log_every_n_steps=args.log_every_n_steps), "timer_callback": fdl.Config(StepTimingCallback, log_every_n_steps=args.log_every_n_steps, mode="train"), + "interval_timer_callback": fdl.Config(IntervalStepTimingCallback,
log_every_n_steps=args.log_every_n_steps), "throughput_callback": fdl.Config(ThroughputLogger, log_every_n_steps=args.log_every_n_steps, warmup_steps=40), } if args.mode == "eval": diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/src/utils/__init__.py b/bionemo-recipes/recipes/codonfm_ptl_te/src/utils/__init__.py index 43041cf2dd..97bb6a374d 100644 --- a/bionemo-recipes/recipes/codonfm_ptl_te/src/utils/__init__.py +++ b/bionemo-recipes/recipes/codonfm_ptl_te/src/utils/__init__.py @@ -15,6 +15,7 @@ from src.utils.grad_norm_callback import GradientNormLogger +from src.utils.interval_step_timing import IntervalStepTimingCallback from src.utils.pred_writer import PredWriter from src.utils.pylogger import RankedLogger from src.utils.throughput_logger import ThroughputLogger @@ -22,6 +23,7 @@ __all__ = [ "GradientNormLogger", + "IntervalStepTimingCallback", "PredWriter", "RankedLogger", "ThroughputLogger", diff --git a/bionemo-recipes/recipes/codonfm_ptl_te/src/utils/interval_step_timing.py b/bionemo-recipes/recipes/codonfm_ptl_te/src/utils/interval_step_timing.py new file mode 100644 index 0000000000..9c0bd67476 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_ptl_te/src/utils/interval_step_timing.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import time

import torch
from lightning.pytorch.callbacks import Callback


class IntervalStepTimingCallback(Callback):
    """Log the mean wall-clock time per optimizer step over each logging interval.

    Mirrors the semantics of `train/step_time` in the native_te recipe's `PerfLogger`:
    `time.perf_counter()` is sampled only at log boundaries, so the logged value is the
    average optimizer-step wall time over the last interval rather than a per-step
    measurement.

    The elapsed time is divided by the number of optimizer steps actually completed since
    the previous sample. In steady state that equals `log_every_n_steps`; it differs only
    when the first interval is short (e.g. training starts from a `global_step` that is
    not a multiple of the interval) or when a log boundary was skipped, in which case
    dividing by the nominal interval would under- or over-report the step time.

    Args:
        log_every_n_steps: Log whenever `trainer.global_step` is a positive multiple of
            this value. A non-positive value disables the callback.
    """

    def __init__(self, log_every_n_steps: int = 10):
        """Store the logging interval; timing state is initialised in `on_train_start`."""
        super().__init__()
        self.log_every_n_steps = log_every_n_steps
        # Wall-clock timestamp and optimizer-step count captured at the previous sample.
        self.previous_log_time: float | None = None
        self.previous_log_step: int | None = None

    def on_train_start(self, trainer, pl_module):
        """Record the baseline timestamp and step count for the first interval."""
        self.previous_log_time = time.perf_counter()
        self.previous_log_step = trainer.global_step

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log `timing_train/step_time` when an optimizer step lands on a log boundary."""
        # A non-positive interval cannot define a boundary (and would divide by zero below).
        if self.log_every_n_steps <= 0:
            return

        # Only act on the micro-batch that closes a gradient-accumulation window.
        if (batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        step = trainer.global_step
        if step == 0 or step % self.log_every_n_steps != 0:
            return

        # No baseline yet (on_train_start was not observed): start measuring from here.
        if self.previous_log_time is None or self.previous_log_step is None:
            self.previous_log_time = time.perf_counter()
            self.previous_log_step = step
            return

        # Average over the steps that really elapsed, not the nominal interval.
        steps_elapsed = step - self.previous_log_step
        if steps_elapsed <= 0:
            return

        # Wait for queued GPU work so the interval reflects completed computation.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        now = time.perf_counter()
        step_time = (now - self.previous_log_time) / steps_elapsed
        self.previous_log_time = now
        self.previous_log_step = step

        pl_module.log(
            "timing_train/step_time",
            step_time,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )