Skip to content

Commit 5ccc6c8

Browse files
authored
Add ASR fine-tuning skill (#15733)
* Add ASR fine-tuning skill draft * Refine ASR fine-tuning skill guidance * Improve ASR fine-tuning guardrails * Add ASR refinement and export guidance * Refine ASR fine-tuning iteration guidance * Add transcript style preflight to ASR finetune skill * Refine ASR finetune skill evaluation guidance * Address ASR finetune skill review feedback
1 parent e1fcfc4 commit 5ccc6c8

8 files changed

Lines changed: 1110 additions & 0 deletions

File tree

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
---
2+
name: nemo-speech-asr-finetune
3+
description: Guide NeMo Speech users through ASR fine-tuning with container setup and Lhotse training.
4+
---
5+
6+
# NeMo Speech ASR Fine-Tuning
7+
8+
Use this skill when a user wants to fine-tune a NeMo Speech ASR model, choose a checkpoint, adapt a tokenizer,
9+
configure Lhotse dataloading, train, average checkpoints, or evaluate a fine-tuned ASR `.nemo` checkpoint.
10+
Also use it for post-run refinement planning after fine-tuning.
11+
12+
Default posture:
13+
14+
- Use the NeMo container unless the user explicitly asks for local execution.
15+
- Prefer Lhotse for train and validation dataloaders.
16+
- Use `trainer.max_steps`, not `trainer.max_epochs`.
17+
- Use `val_wer` as the checkpoint monitor for validation.
18+
- By default, evaluate WER without capitalization and punctuation effects. Change that only when the user explicitly
19+
asks for raw/cased/punctuated scoring.
20+
- Report final quality from standalone evaluation, not only in-training validation logs.
21+
22+
## Staged Workflow
23+
24+
Load only the reference file needed for the current stage:
25+
26+
1. Setup and checkpoint selection: read `references/setup-checkpoints.md`.
27+
2. Data prep, transcript-style preflight, Lhotse, bucketing, validation dataloader, and blends: read
28+
`references/data-lhotse.md`.
29+
3. Architecture detection, tokenizer changes, and AED/Canary multitask metrics: read
30+
`references/architecture-tokenizer-metrics.md`.
31+
4. Training, checkpoint averaging, and evaluation: read `references/training-evaluation.md` and, when reporting WER,
32+
`references/evaluation-style-contract.md`.
33+
5. Post-run refinement, error analysis, curriculum, and general-vs-domain evaluation: read
34+
`references/refinement-iteration.md`.
35+
36+
If the user explicitly asks for parallel/sub-agent work, split the work by these same stages. Keep each agent scoped to
37+
one stage and have the main agent integrate the final command/config.
38+
39+
## Core Commands
40+
41+
Generic fine-tuning uses `examples/asr/speech_to_text_finetune.py`. For architecture-specific recipes, route to:
42+
43+
- CTC: `examples/asr/asr_ctc/speech_to_text_ctc_bpe.py`
44+
- RNNT: `examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py`
45+
- Hybrid RNNT/CTC or TDT/CTC: `examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py`
46+
- AED/Canary: `examples/asr/speech_multitask/speech_to_text_aed.py`
47+
48+
Always check the current repo docs before giving version-sensitive claims:
49+
50+
- `README.md`
51+
- `docs/source/asr/fine_tuning.rst`
52+
- `docs/source/asr/datasets.rst`
53+
- `docs/source/dataloaders.rst`
54+
- `docs/source/asr/featured_models.rst`
55+
- `docs/source/asr/asr_checkpoints.rst`
56+
- `nemo/collections/common/data/lhotse/dataloader.py`
57+
58+
## Non-Negotiable Pitfalls
59+
60+
- When changing Lhotse batch modes, explicitly null conflicting options. For OOMptimizer profiles, set
61+
`batch_size=null`, `batch_duration=null`, and `quadratic_duration=null` when adding `bucket_batch_size`.
62+
- Set `model.validation_ds.use_lhotse=true`, but prefer static validation `batch_size` with bucketing disabled.
63+
- Do not use fused loss/WER or tune `fused_batch_size` for RNNT/TDT fine-tuning guidance from this skill.
64+
- Run the first OOMptimizer pass with default CLI settings; lower `--memory-fraction` only after a real training OOM.
65+
- Run preflight checks before long jobs: disk space, free GPUs, manifest validity, and duration/text sanity.
66+
- Before any fine-tuning, audit transcript style within and across all fine-tuning/validation/test sources. Do not
67+
train on mixed casing, punctuation, inverse-text-normalization, or symbol conventions; choose and fix one target style
68+
first, and compare it with the original checkpoint's prediction style when applicable.
69+
- For small domain adaptation, start with a lower LR than large-data fine-tuning; do not blindly use `1e-4`.
70+
- Do not train a tokenizer on validation or test transcripts.
71+
- Do not ignore silent Lhotse filtering from `min_duration`, `max_duration`, `min_tps`, and `max_tps`.
72+
- Do not use `amp=true` for inference/evaluation; use `amp=false compute_dtype=bfloat16`.
73+
- Unless the user asks otherwise, report the default WER with capitalization and punctuation removed, and record any raw
74+
WER separately when it helps diagnose transcript-style mismatch.
75+
- For AED/Canary, configure `multitask_metrics_cfg` so ASR and translation/task-specific samples are evaluated with
76+
the right constrained metrics.
77+
- If checkpoint averaging is used, evaluate the averaged checkpoint and keep it only if it beats the best individual
78+
checkpoint.
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# ASR Fine-Tuning Experiment Ledger
2+
3+
## Goal
4+
5+
- User goal:
6+
- Target checkpoint:
7+
- Architecture:
8+
- Success metric:
9+
- Optional guardrail metrics:
10+
11+
## Data
12+
13+
- Train manifests or input config:
14+
- Validation manifest:
15+
- Test manifest:
16+
- Data sources and weights:
17+
- Tarred/non-tarred:
18+
- Manifest sharding:
19+
- Transcript style policy:
20+
- Style transform artifact:
21+
- Original checkpoint output style:
22+
- Style mismatch decision:
23+
24+
## Preflight
25+
26+
- Disk space:
27+
- GPUs:
28+
- Container/image:
29+
- Manifest validation:
30+
- Duration distribution:
31+
- Text/token distribution:
32+
- Duration/token filters:
33+
- Examples/hours filtered:
34+
35+
## Lhotse And OOMptimizer
36+
37+
- Lhotse train:
38+
- Lhotse validation:
39+
- Bucketing mode:
40+
- Duration bins:
41+
- Bucket batch sizes:
42+
- Static batch size:
43+
- `bucket_buffer_size`:
44+
- `shuffle_buffer_size`:
45+
- `seed`:
46+
- `shard_seed`:
47+
- OOMptimizer settings:
48+
- Training pilot utilization:
49+
- CPU memory notes:
50+
51+
## Training
52+
53+
- Init checkpoint:
54+
- Script/config:
55+
- Precision:
56+
- `sync_batchnorm`:
57+
- `max_steps`:
58+
- `limit_train_batches`:
59+
- `val_check_interval`:
60+
- LR:
61+
- Scheduler:
62+
- Warmup:
63+
- Min LR:
64+
- Save top K:
65+
- Command/log path:
66+
67+
## Evaluation
68+
69+
| Model | Artifact | Prediction Manifest | Raw WER | Default WER | CER | Notes |
70+
| --- | --- | --- | --- | --- | --- | --- |
71+
| baseline | | | | | | |
72+
| final `.nemo` | | | | | | |
73+
| best single | | | | | | |
74+
| averaged | | | | | | |
75+
76+
Default WER uses capitalization and punctuation removal unless the user requested a different metric.
77+
78+
## Error Analysis
79+
80+
- Raw vs default WER gap:
81+
- Worst sources/domains:
82+
- Worst categories:
83+
- Label/audio defects:
84+
- Decoding findings:
85+
86+
## Decision
87+
88+
- Keep artifact:
89+
- Drop artifacts:
90+
- Next intervention:
91+
- Reason:
92+
- If validation/test influenced data or weights, blind holdout plan:
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Stage 3: Architecture, Tokenizer, And AED Metrics
2+
3+
## Architecture Detection
4+
5+
Inspect the model config before choosing scripts and overrides:
6+
7+
```python
8+
from nemo.collections.asr.models import ASRModel
9+
10+
cfg = ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", return_config=True)
11+
print(cfg.target)
12+
print(cfg.get("decoder", None))
13+
print(cfg.get("joint", None))
14+
print(cfg.get("loss", None))
15+
print(cfg.get("decoding", None))
16+
```
17+
18+
Classify:
19+
20+
- CTC: `EncDecCTC*`, `ConvASRDecoder`, no RNNT-style `joint`.
21+
- RNNT: `EncDecRNNT*` with `decoder` and `joint`.
22+
- TDT: RNNT-family config with `loss.loss_name: tdt`, `decoding.model_type: tdt`, durations, or extra duration
23+
outputs.
24+
- Hybrid RNNT/CTC or TDT/CTC: `EncDecHybridRNNTCTC*`, `aux_ctc`, `ctc_decoder`.
25+
- AED/Canary: `EncDecMultiTaskModel`, Transformer decoder, `prompt_format`.
26+
27+
Use `examples/asr/speech_to_text_finetune.py` for compatible-architecture fine-tuning. For architecture-specific
28+
recipes:
29+
30+
- CTC: `examples/asr/asr_ctc/speech_to_text_ctc_bpe.py`
31+
- RNNT: `examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py`
32+
- Hybrid RNNT/CTC or TDT/CTC: `examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py`
33+
- AED/Canary: `examples/asr/speech_multitask/speech_to_text_aed.py`
34+
35+
Reference configs to inspect before writing overrides:
36+
37+
- CTC: `examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml`
38+
- RNNT: `examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml`
39+
- TDT: `examples/asr/conf/conformer/tdt/conformer_tdt_bpe.yaml`
40+
- Hybrid TDT/CTC: `examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_tdt_ctc_bpe.yaml`
41+
- AED/Canary: `examples/asr/conf/speech_multitask/fast-conformer_aed.yaml`
42+
43+
## Tokenizer Decisions
44+
45+
Keep the pretrained tokenizer when language/script, casing, punctuation, and symbols match. Replace or extend it when
46+
the target language/script is new, important symbols are missing, normalization changes substantially, the run is
47+
multilingual/code-switching, or Canary/AED prompt/special tokens change.
48+
49+
Train from training text only:
50+
51+
```bash
52+
python scripts/tokenizers/process_asr_text_tokenizer.py \
53+
--manifest=/data/train.json \
54+
--data_root=/data/tokenizers/my_tokenizer \
55+
--vocab_size=1024 \
56+
--tokenizer=spe \
57+
--spe_type=unigram \
58+
--log
59+
```
60+
61+
Replace in generic fine-tuning:
62+
63+
```bash
64+
model.tokenizer.update_tokenizer=true \
65+
model.tokenizer.dir=/data/tokenizers/my_tokenizer/tokenizer_spe_unigram_v1024 \
66+
model.tokenizer.type=bpe
67+
```
68+
69+
Changing tokenizer size usually reinitializes decoder-side parameters such as CTC projection or RNNT/TDT
70+
decoder/joint pieces. Use conservative LR and validate early.
71+
72+
Aggregate tokenizer:
73+
74+
```yaml
75+
tokenizer:
76+
type: agg
77+
langs:
78+
en:
79+
dir: /data/tokenizers/en/tokenizer_spe_unigram_v1024
80+
type: bpe
81+
es:
82+
dir: /data/tokenizers/es/tokenizer_spe_unigram_v1024
83+
type: bpe
84+
```
85+
86+
Standard aggregate ASR configs expect a manifest language field such as `lang`. AED/Canary configs use their own
87+
prompt and language fields; follow `examples/asr/conf/speech_multitask/fast-conformer_aed.yaml`.
88+
89+
## Architecture-Specific Knobs
90+
91+
CTC:
92+
93+
- Main tokenizer-sensitive module is the decoder projection.
94+
- Useful when alignments or non-autoregressive decoding matter.
95+
- Long transcripts can violate CTC length constraints; subword tokenization helps reduce target length.
96+
97+
RNNT:
98+
99+
- Disable fused loss/WER for this skill: `model.joint.fuse_loss_wer=false`.
100+
- Do not tune `fused_batch_size`; use Lhotse bucketing plus OOMptimizer-generated `bucket_batch_size`.
101+
- `model.compute_eval_loss=false` is common when validation samples are long and WER is the main metric.
102+
- Use CUDA graphs for inference/evaluation when supported.
103+
104+
TDT:
105+
106+
- Preserve `loss.loss_name=tdt`, duration settings, extra outputs, and `decoding.model_type=tdt`.
107+
- Disable fused loss/WER and use Lhotse bucketing plus OOMptimizer.
108+
- Use CUDA graphs for inference/evaluation when supported.
109+
110+
Hybrid:
111+
112+
- Check `model.aux_ctc.ctc_loss_weight`; reference configs often use `0.3`.
113+
- Evaluate both decoder paths when relevant with `decoder_type=ctc` and `decoder_type=rnnt`.
114+
115+
AED/Canary:
116+
117+
- Use `examples/asr/speech_multitask/speech_to_text_aed.py`.
118+
- Preserve `prompt_format` and expected manifest fields.
119+
- Prefer 2D Lhotse buckets plus OOMptimizer.
120+
121+
## AED/Canary Multitask Metrics
122+
123+
`EncDecMultiTaskModel` reads `model.multitask_metrics_cfg` and constructs `MultiTaskMetric`
124+
(`nemo/collections/asr/metrics/multitask.py`). Metric constraints are evaluated against each Lhotse cut's `custom`
125+
dict, including manifest fields and `input_cfg.tags`.
126+
127+
Reference config:
128+
129+
```yaml
130+
model:
131+
multitask_metrics_cfg:
132+
log_predictions: true
133+
metrics:
134+
wer:
135+
_target_: nemo.collections.asr.metrics.WER
136+
constraint: ".source_lang==.target_lang"
137+
bleu:
138+
_target_: nemo.collections.asr.metrics.BLEU
139+
constraint: ".source_lang!=.target_lang"
140+
bleu_tokenizer: 13a
141+
check_cuts_for_bleu_tokenizers: false
142+
```
143+
144+
Use constraints to route ASR samples to WER and translation samples to BLEU. Add dataset/task/domain metadata through
145+
manifest fields or `input_cfg.tags`, then reference it in constraints such as `.domain==target` or
146+
`.task==asr and .source_lang==.target_lang`.
147+
148+
Current implementation supports only one instance of each metric class in a single `multitask_metrics_cfg`. For
149+
multiple WER slices by language/domain, prefer separate validation manifests/dataloaders or extend metric aggregation
150+
rather than defining duplicate WER metrics.
151+
152+
For AED validation data, set `use_lhotse: true`, `use_bucketing: false`, static `batch_size`, `text_field: "text"`,
153+
and `lang_field: "target_lang"`.

0 commit comments

Comments
 (0)