Commit 919c945

Authored by pzelasko, claude, stevehuang52, and claude[bot]
SALM with NeMo Automodel integration for Nemotron Nano V3 LLM backbone (#15447)
* WIP: bringing Yifan's changes to main
* Add workaround for exp_manager issue
* Support reading indexed JSONL datasets with ShareGPT format
* Support reading indexed tarred datasets with ShareGPT format
* Refactor for compactness
* Fixes for real-life data
* Fixes for real-life data
* Fixes for real-life data
* Fixes for missing wids-meta.json
* Fixes for tarfile edge cases
* Fixes for real-world tar files
* Move SALM LLM init to configure_model
* Fix: delayed perception init
* Add AutomodelParallelStrategy for Automodel LLM support
* Replace HF Automodel with NeMo Automodel for SALM's LLM backbone
* Update SALM default config with new options
* Init fixes
* Fix dtype initialization
* Fix mesh selection for speech encoder
* Fix for mismatched device_mesh axis names in gradient clipping - use Automodel's utility
* Fix for using embed_tokens in FSDP context before running forward on full LLM
* Definitive fix for using embed_tokens outside of LLM with FSDP
* This version actually works with Automodel
* Fix from_pretrained with transformers v5
* Fix from_pretrained with transformers v5
* Fix generate/eval
* Fix to_hf
* Fixes for AutoTokenizer decoding in v5
* Flag to run configure_model() at the end of __init__ for safetensors-converted models
* Preliminary: support distributed models in to_hf.py
* Fix passing Automodel kwargs
* Fix
* Enable inference with model parallelism
* Fix for Lightning save_hyperparameters() call
* Fix for loading into DTensor
* Accelerate loading DTensor
* Accelerate loading DTensor
* Accelerate loading DTensor
* Fix for PE buffers not in ckpt (essentially strict=False)
* Add Nemotron Nano V3 prompt formatter with <think> reasoning support

  Implements NemotronNanoV3PromptFormatter (NAME="nemotron-nano-v3") using a ChatML-style <|im_start|>/<|im_end|> template with an encode_dialog override that handles: auto-inserting an empty system turn, truncating thinking from history, prepending <think></think> for non-thinking assistant turns, and a dynamic inference prefix (thinking on/off). Includes Lhotse Cut integration via registered_prompt_format_fn. Verified against HF apply_chat_template for nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 (both string and token match).

* Fix
* Automodel LoRA support
* Fixes for model parallel
* LoRA fix
* Small ckpt conversion/inference fix
* Separate SALM and SALMAutomodel into independent classes

  Restore salm.py to its original HF Transformers + PEFT LoRA implementation from main, and extract the NeMo Automodel-based implementation into a new SALMAutomodel class in salm_automodel.py. This keeps both backends available and independent, with scripts auto-detecting the model class from config.json.

  - salm.py: restored from main (eager init, HF PEFT, move_embedding)
  - salm_automodel.py: new file with SALMAutomodel (deferred init, Automodel LoRA)
  - salm_train.py: selects the model class via the model.use_nemo_automodel config key
  - salm_eval.py / salm_generate.py: auto-detect the model class from config.json
  - salm_automodel.yaml: new config for SALMAutomodel training
  - Tests split into test_salm.py (CPU) and test_salm_automodel.py (CUDA)
  - New functional test SPEECHLM_Automodel_Training_SALM.sh

* Fix linters
* Add SALMAutomodel docs and speechlm2 pip extra

  Add documentation for SALMAutomodel (the NeMo Automodel variant of SALM) across all speechlm2 doc pages: intro, models, configs, and training_and_scaling. Create a pip install nemo-toolkit[speechlm2] extra that composes speechlm2-only (nemo_automodel git dep) + asr + tts.

* Add SALMAutomodel tutorial notebook and fix EP/FSDP2 docs

  Add tutorials/speechlm2/SpeechLM_With_NeMo_Automodel.ipynb covering the full pipeline: data download, training, checkpoint conversion, and evaluation with the Nemotron Nano V3 MoE backbone on 2 GPUs. Fix docs to clarify that Expert Parallelism reuses the FSDP2 data-parallel axis — dense layers are sharded via FSDP2 while MoE layers use EP on the same GPUs, not a separate dimension.

* Fix uv torch index conflict for speechlm2 extra

  The docs CI runs `uv sync --all-extras --all-groups`, which resolves the speechlm2 extra pulling nemo_automodel from git. uv treats git source deps as workspace members and applies their [tool.uv.sources], causing a conflict: Automodel maps torch to per-platform indexes while NeMo defaulted to PyPI for all platforms. Add a matching [tool.uv.sources] entry for torch to pyproject.toml and regenerate uv.lock with nemo_automodel included.

* Remove direction arg
* Apply isort and black reformatting
* Fix linter
* Fix tests
* Fixes
* Fixes for trust_remote_code
* Apply isort and black reformatting
* Add explicit enable_thinking support to SALM eval paths
* Apply isort and black reformatting
* Fix inference with ep_size=1 for Automodel models
* Fixes
* Fixes for inference and tutorial
* Apply isort and black reformatting
* Remove deprecated activation_checkpointing parameter everywhere
* Fix CI
* Fix to_hf.py crash when run without torchrun

  Guard dist.init_process_group on RANK env var presence so the script works with plain `python` (single-file checkpoints) as well as `torchrun` (distributed checkpoints).

* Apply suggestions from code review
* Add flashoptim support and bf16-automodel half-precision setup
* Apply isort and black reformatting
* Patch flashoptim handling of unevenly sharded state dicts
* Reproducibility fix
* Refactor AutomodelPrecision to FlashPrecision to enable re-use by other collections in subsequent PRs
* Apply isort and black reformatting
* Address code review
* Disable linter
* Fix test
* Fix for torch.compile config
* Apply isort and black reformatting
* Fix tests
* Dataloader DP rank patch for Automodel's device_mesh
* Apply isort and black reformatting
* Fix sloppy fix
* Fix CI HF tokenizer download issue
* Add tests for correct DP rank resolution in the dataloader
* Apply isort and black reformatting
* xfail tests with corrupted tokenizer in CI
* Update test PyTorch version safeguard
* Fix new peft version requiring newer torchao than available in CI container
* Fixes
* Bump Automodel pin for transformers compat

---------

Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Signed-off-by: Piotr Żelasko <petezor@gmail.com>
Signed-off-by: pzelasko <pzelasko@users.noreply.github.com>
Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: pzelasko <pzelasko@users.noreply.github.com>
Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
1 parent 73a5e7d · commit 919c945

61 files changed: 10,327 additions & 2,296 deletions

docs/source/features/mixed_precision.rst

Lines changed: 66 additions & 2 deletions
@@ -5,7 +5,7 @@ Mixed Precision Training
 
 Mixed precision training enhances computational efficiency by conducting operations in low-precision
 format while selectively maintaining critical data in single-precision. NeMo supports FP16 and BF16
-precision via PyTorch Lightning, in both mixed and true half-precision modes.
+precision via PyTorch Lightning, in mixed, true, and flash half-precision modes.
 
 Precision Modes
 ---------------
@@ -23,6 +23,16 @@ PyTorch Lightning provides two categories of half-precision training:
   but requires the model to be numerically stable in half-precision.
   SpeechLM2 models use ``"bf16-true"`` by default for training.
 
+**Flash Precision** (``"bf16-flash"`` / ``"fp16-flash"``):
+  The model also runs in half-precision, but NeMo avoids Lightning's global
+  default-dtype override and autocast context. This mode is intended for use
+  with FlashOptim, a library of drop-in optimizers that reduces training
+  memory by shrinking optimizer states, master weights, and gradients. In
+  practice, this may be a better fit than AMP / mixed precision when
+  optimizer-state memory or checkpoint size is the bottleneck, and may lead to
+  improved convergence compared to Lightning's true half-precision, as it keeps
+  track of the residual between half- and full-precision weights.
+
 Configuration
 -------------
 
@@ -37,6 +47,8 @@ In YAML (with Hydra):
     # precision: "16-mixed" # FP16 mixed precision
     # precision: "bf16-true" # True BF16 half precision
     # precision: "fp16-true" # True FP16 half precision
+    # precision: "bf16-flash" # BF16 flash precision
+    # precision: "fp16-flash" # FP16 flash precision
 
 In Python:
 
@@ -71,7 +83,7 @@ the substring ``"audio"`` is kept in its original precision (typically FP32). Al
 tensors are cast to the target half-precision dtype.
 
 This plugin is used automatically when you launch training with NeMo's ``resolve_trainer_cfg``
-utility (used by all NeMo example training scripts). When the trainer config specifies
+utility (used by many NeMo example training scripts). When the trainer config specifies
 ``precision: "bf16-true"`` or ``precision: "fp16-true"``, ``resolve_trainer_cfg`` replaces
 the precision setting with the ``HalfPrecisionForAudio`` plugin:
 
@@ -94,3 +106,55 @@ If you construct the trainer manually, you can install the plugin directly:
         devices=2,
         accelerator="gpu",
     )
+
+FlashPrecision
+--------------
+
+NeMo provides the ``FlashPrecision`` plugin (in
+``nemo.utils.trainer_utils``) primarily for FlashOptim-backed training.
+According to the official FlashOptim README, FlashOptim provides drop-in
+optimizer replacements that reduce training memory by compressing optimizer
+states, master weights, and gradients while preserving the standard PyTorch
+optimizer API.
+
+FlashOptim generally expects the model parameters to already be in bf16/fp16,
+while the optimizer manages reduced-precision state and master-weight
+correction internally. ``FlashPrecision`` fits that model: it preserves the
+same audio-aware input casting behavior as ``HalfPrecisionForAudio``, but does
+not enter autocast and does not change PyTorch's global default dtype. This
+avoids layering Lightning's global precision policy on top of FlashOptim's own
+reduced-precision optimizer behavior.
+
+When the trainer config specifies ``precision: "bf16-flash"`` or
+``precision: "fp16-flash"``, ``resolve_trainer_cfg`` replaces the precision
+setting with the ``FlashPrecision`` plugin:
+
+.. code-block:: python
+
+    from nemo.utils.trainer_utils import resolve_trainer_cfg
+
+    # In YAML: trainer.precision = "bf16-flash"
+    trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer))
+
+If you construct the trainer manually, you can install the plugin directly:
+
+.. code-block:: python
+
+    from nemo.utils.trainer_utils import FlashPrecision
+
+    trainer = pl.Trainer(
+        plugins=[FlashPrecision("bf16-flash")],
+        devices=2,
+        accelerator="gpu",
+    )
+
+If you are going to use ``FlashPrecision``, make sure to also set up a ``flashoptim`` optimizer, e.g.:
+
+.. code-block:: yaml
+
+    optimizer:
+      _target_: flashoptim.FlashAdamW
+      lr: 1e-4
+      betas: [0.9, 0.999]
+      weight_decay: 5e-2
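For illustration, here is how the two halves of this setup could be wired up directly in Python. This is a minimal sketch that assumes ``flashoptim.FlashAdamW`` accepts the standard AdamW arguments shown in the YAML above; it is not a verified end-to-end recipe.

.. code-block:: python

    import torch
    import lightning.pytorch as pl
    from flashoptim import FlashAdamW  # optimizer class named in the YAML above
    from nemo.utils.trainer_utils import FlashPrecision

    # FlashPrecision installs the audio-aware input casting but skips autocast
    # and the global default-dtype override, leaving precision management to
    # the FlashOptim optimizer.
    trainer = pl.Trainer(
        plugins=[FlashPrecision("bf16-flash")],
        devices=2,
        accelerator="gpu",
    )

    # FlashOptim expects parameters already in half precision; the optimizer
    # keeps compressed state and the half/full weight residual internally.
    model = torch.nn.Linear(1024, 1024).to(torch.bfloat16)
    optimizer = FlashAdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=5e-2)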

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ NVIDIA NeMo Speech Developer Docs
     </a>
     <a class="task-card" href="speechlm2/intro.html">
       <h3>🧠 Speech Language Models</h3>
-      <p>Audio-aware LLMs that understand and generate speech. Speech-to-text, speech-to-speech, and more.</p>
+      <p>Audio-aware LLMs that understand and generate speech. Use HuggingFace Transformers, or NeMo Automodel for efficient MoE and model parallelism. Speech-to-text, speech-to-speech, and more.</p>
       <strong>Quick Start →</strong>
     </a>
     <a class="task-card" href="audio/intro.html">

docs/source/speechlm2/configs.rst

Lines changed: 128 additions & 2 deletions
@@ -40,7 +40,10 @@ See the `SALM paper <https://arxiv.org/abs/2310.09424>`_ for more details.
     pretrained_llm: "TinyLlama/TinyLlama_v1.1" # HF model path
     pretrained_asr: "stt_en_fastconformer_hybrid_large_streaming_80ms" # NeMo checkpoint name
     pretrained_weights: True # Whether to load weights or just architecture
-
+
+    # Fine-tune from a previous training checkpoint (weights only, fresh optimizer)
+    init_from_checkpoint: null # path to .ckpt, DCP dir, or HF dir
+
     # Special token settings
     audio_locator_tag: "<audio>" # Tag to replace with audio embeddings
@@ -94,6 +97,68 @@ See the `SALM paper <https://arxiv.org/abs/2310.09424>`_ for more details.
       dropout_pre_encoder: 0
       dropout_emb: 0.0
 
+SALMAutomodel Configuration
+----------------------------
+
+The SALMAutomodel configuration extends the SALM configuration with NeMo Automodel
+support. The key difference is ``use_nemo_automodel: true`` and the use of
+``AutomodelParallelStrategy`` instead of ``DDPStrategy``.
+
+The example below shows a configuration for training with NVIDIA Nemotron Nano V3
+MoE as the LLM backbone, with Expert Parallelism across 8 GPUs:
+
+.. code-block:: yaml
+
+    model:
+      use_nemo_automodel: true  # Selects SALMAutomodel in salm_train.py
+      pretrained_llm: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
+      pretrained_asr: "nvidia/canary-1b-flash"
+      pretrained_weights: True
+
+      freeze_params:
+        - "^llm\\..+$"
+        - "^perception\\.preprocessor\\..+$"
+        - "^perception\\.encoder\\..+$"
+      prevent_freeze_params: []
+
+      # LoRA uses Automodel-native format (not HF PEFT):
+      # lora:
+      #   dim: 128
+      #   alpha: 256
+      #   dropout: 0.01
+      #   target_modules: ["q_proj", "v_proj"]
+
+      perception:
+        target: nemo.collections.speechlm2.modules.perception.AudioPerceptionModule
+        output_dim: 2048
+        modality_adapter:
+          _target_: nemo.collections.speechlm2.modules.perception.IdentityConnector
+          d_model: 1024
+
+    trainer:
+      strategy:
+        _target_: nemo.collections.speechlm2.parts.parallel.AutomodelParallelStrategy
+        ep_size: 8  # Expert Parallelism across 8 GPUs for MoE
+        # tp_size: 1
+        # dp_size: null  # inferred
+
+NeMo Automodel applies MoE-specific optimizations automatically when an MoE model
+is detected:
+
+* **Grouped GEMM** — fuses expert computations into a single batched matrix multiply
+  for higher GPU throughput.
+* **DeepEP** (Deep Expert Parallelism) — efficient all-to-all expert routing across
+  GPUs, minimizing communication overhead for MoE layers.
+
+Note the differences from the SALM configuration:
+
+* ``model.use_nemo_automodel: true`` — selects ``SALMAutomodel`` in the training script.
+* ``model.pretrained_llm`` can point to MoE models like ``nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16``.
+* ``trainer.strategy._target_`` uses ``AutomodelParallelStrategy`` instead of ``ModelParallelStrategy``.
+* ``ep_size`` controls Expert Parallelism on the FSDP data-parallel axis — dense layers are sharded via FSDP2, while MoE layers use EP for expert routing on the same GPUs.
+* LoRA config uses ``dim``/``alpha`` keys (Automodel native) instead of ``r``/``lora_alpha`` (HF PEFT).
+* No ``embed_tokens`` freeze pattern — embeddings stay inside the LLM.
+
 DuplexS2SModel Configuration
 -----------------------------
 
@@ -264,6 +329,7 @@ Model Parameters
 - **pretrained_llm**: Path to the pretrained HuggingFace LLM
 - **pretrained_asr**: Name of the pretrained NeMo ASR model used for perception
 - **pretrained_audio_codec**: Path to the pretrained audio codec model (for speech generation)
+- **init_from_checkpoint**: Path to a training checkpoint to initialize model weights from (see :ref:`fine-tuning-from-checkpoint` below)
 - **freeze_params**: Regex patterns of parameters to freeze during training
 - **audio_loss_weight/text_loss_weight**: Weighting of different loss components
 
@@ -291,6 +357,7 @@ Example Configuration Files
 Example configurations for all model types can be found in the example directory:
 
 - SALM: `examples/speechlm2/conf/salm.yaml`
+- SALMAutomodel: `examples/speechlm2/conf/salm_automodel.yaml`
 - DuplexS2SModel: `examples/speechlm2/conf/s2s_duplex.yaml`
 - DuplexS2SSpeechDecoderModel: `examples/speechlm2/conf/s2s_duplex_speech_decoder.yaml`
 - DuplexSTTModel: `examples/speechlm2/conf/duplex_stt.yaml`
@@ -307,6 +374,10 @@ You can use these configurations with the training scripts by specifying the con
         --config-path=conf \
         --config-name=salm
 
+    # Train SALMAutomodel
+    python examples/speechlm2/salm_train.py \
+        --config-name=salm_automodel
+
 You can also override configuration values from the command line:
 
 .. code-block:: bash
@@ -316,4 +387,59 @@ You can also override configuration values from the command line:
         --config-name=salm \
         model.pretrained_llm="different/llm/path" \
         trainer.max_steps=1000 \
-        data.train_ds.batch_size=8
+        data.train_ds.batch_size=8
+
+.. _fine-tuning-from-checkpoint:
+
+Fine-Tuning from a Previous Checkpoint
+---------------------------------------
+
+To start a new training run initialized from a previous checkpoint — with a fresh
+optimizer, LR scheduler, and step counter — set ``model.init_from_checkpoint``:
+
+.. code-block:: yaml
+
+    model:
+      init_from_checkpoint: /path/to/checkpoints/step=6375.ckpt
+
+Or pass it as a Hydra override:
+
+.. code-block:: bash
+
+    python examples/speechlm2/salm_train.py \
+        --config-name=salm_automodel \
+        ++model.init_from_checkpoint=/path/to/checkpoints/step=6375.ckpt
+
+This differs from ``exp_manager.resume_from_checkpoint``, which restores the
+**full** training state (optimizer, scheduler, step counter) to continue an
+interrupted run. ``init_from_checkpoint`` only loads model weights, giving you a
+clean starting point for fine-tuning on different data or with different
+hyperparameters.
+
+Supported Checkpoint Formats
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Three checkpoint formats are supported:
+
+* **Distributed checkpoints (DCP)**: Directories with a ``.metadata`` file, produced
+  by ``ModelParallelStrategy`` / ``AutomodelParallelStrategy``. This is the default
+  format when training with FSDP2 or TP. DCP loading handles automatic resharding
+  when the parallelism configuration differs between the source and target runs.
+
+* **HuggingFace model directories**: Directories containing ``model.safetensors``,
+  such as the output of ``to_hf.py``.
+
+* **Single-file checkpoints**: Standard ``.ckpt`` or ``.pt`` files with a
+  ``state_dict`` key.
+
+The model architecture is still defined by ``pretrained_llm`` and ``pretrained_asr``
+(needed for config and tokenizer initialization), but all weights are overridden by
+the checkpoint.
+
+This feature works with both ``SALM`` and ``SALMAutomodel``.
+
+.. note::
+   ``init_from_checkpoint`` requires the source and target models to use the
+   same model class (e.g., both ``SALMAutomodel``). Cross-model loading
+   (e.g., ``SALM`` checkpoint into ``SALMAutomodel``) will encounter state dict
+   key mismatches because the two classes structure the embedding layer differently.
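To make the three formats concrete, here is a minimal sketch of how a loader could tell them apart on disk, following the markers described above (``.metadata`` for DCP directories, ``model.safetensors`` for HF directories, a ``state_dict`` key for single files). The helper name is hypothetical; the actual ``init_from_checkpoint`` code may differ.

.. code-block:: python

    from pathlib import Path

    def detect_checkpoint_format(path: str) -> str:
        """Hypothetical helper mirroring the format rules described above."""
        p = Path(path)
        if p.is_dir():
            if (p / ".metadata").exists():
                # Distributed checkpoint (DCP) directory produced by
                # ModelParallelStrategy / AutomodelParallelStrategy.
                return "dcp"
            if (p / "model.safetensors").exists():
                # HuggingFace model directory, e.g. the output of to_hf.py.
                return "hf"
            raise ValueError(f"Unrecognized checkpoint directory: {p}")
        # Single-file .ckpt / .pt checkpoint holding a 'state_dict' key.
        return "single_file"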

docs/source/speechlm2/intro.rst

Lines changed: 61 additions & 7 deletions
@@ -4,16 +4,20 @@ SpeechLM2
 .. note::
    The SpeechLM2 collection is still in active development and the code is likely to keep changing.
 
+.. note::
+   Install with ``pip install nemo-toolkit[speechlm2]`` to get all required dependencies including NeMo Automodel.
 
+SpeechLM2 refers to a collection that augments pre-trained Large Language Models (LLMs) with speech understanding and generation capabilities.
 
-SpeechLM2 refers to a collection that augments pre-trained Large Language Models (LLMs) with speech understanding and generation capabilities.
-
-This collection is designed to be compact, efficient, and to support easy swapping of different LLMs backed by HuggingFace AutoModel.
+This collection is designed to be compact, efficient, and to support easy swapping of different LLMs backed by HuggingFace AutoModel or NeMo Automodel.
 It has first-class support for using dynamic batch sizes via Lhotse and various model parallelism techniques (e.g., FSDP2, Tensor Parallel, Sequence Parallel) via PyTorch DTensor API.
 
 We currently support six main model types:
 
-* **SALM** (Speech-Augmented Language Model) - a simple but effective approach to augmenting pre-trained LLMs with speech understanding capabilities.
+* **SALM** (Speech-Augmented Language Model) - a simple but effective approach to augmenting pre-trained LLMs with speech understanding capabilities. Available in two variants:
+
+  * ``SALM`` — uses HuggingFace Transformers for the LLM backbone with optional HF PEFT LoRA.
+  * ``SALMAutomodel`` — uses `NeMo Automodel <https://github.com/NVIDIA-NeMo/Automodel>`_ for the LLM backbone with native LoRA, advanced parallelism (FSDP2, TP, SP, EP via ``AutomodelParallelStrategy``), and MoE optimizations (Grouped GEMM, DeepEP) for efficient training with models like `NVIDIA Nemotron Nano V3 <https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16>`_.
 * **DuplexS2SModel** - a full-duplex speech-to-speech model with an ASR encoder, directly predicting discrete audio codes.
 * **DuplexS2SSpeechDecoderModel** - a variant of DuplexS2SModel with a separate transformer decoder for speech generation.
 * **DuplexEARTTS** - a ready-to-use duplex text-to-speech model that supports user interruption via a special text interruption token.
@@ -71,7 +75,7 @@ You can run inference using the loaded pretrained SALM model:
     prompt = [{"role": "user", "content": f"{model.audio_locator_tag}"}]
 
     # Generate response
-    with torch.no_grad():
+    with torch.inference_mode():
         output = model.generate(
             prompts=[prompt],
             audios=audio_signal,
@@ -83,6 +87,43 @@ You can run inference using the loaded pretrained SALM model:
     response = model.tokenizer.ids_to_text(output[0])
     print(f"Model response: {response}")
 
+SALMAutomodel
+*************
+
+``SALMAutomodel`` is the NeMo Automodel variant of SALM. It enables efficient training of
+Speech LLMs with MoE architectures like `NVIDIA Nemotron Nano V3 <https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16>`_
+using MoE-specific optimizations (Grouped GEMM, DeepEP). It uses deferred initialization
+(``configure_model()``) and supports distributed training and inference via
+``AutomodelParallelStrategy``.
+
+.. code-block:: python
+
+    import torch
+    import nemo.collections.speechlm2 as slm
+    from nemo.collections.speechlm2.parts.parallel import setup_distributed
+
+    # Initialize distributed and create an Automodel-compatible device mesh with EP=2.
+    # setup_distributed delegates mesh creation to nemo_automodel, which builds
+    # the full (pp, dp_replicate, dp_shard, cp, tp) mesh with MoE submeshes.
+    strategy = setup_distributed(ep_size=2)
+
+    # Load a pretrained SALMAutomodel with the Automodel device mesh
+    model = slm.models.SALMAutomodel.from_pretrained(
+        "path/to/checkpoint",
+        device_mesh=strategy.device_mesh,
+        distributed_config=strategy.distributed_config,
+        moe_config=strategy.moe_config,
+        moe_mesh=strategy.moe_mesh,
+    ).eval()
+
+    # Inference is identical to SALM
+    with torch.inference_mode():
+        output = model.generate(
+            prompts=[prompt],
+            audios=audio_signal,
+            audio_lens=audio_len,
+        )
+
 DuplexS2SModel
 **************
 
@@ -310,22 +351,35 @@ Alternatively, you can train a model using the provided training scripts in the
         --config-path=examples/speechlm2/conf \
         --config-name=salm
 
-    # For SALM inference/evaluation
+    # For SALM inference/evaluation
     python examples/speechlm2/salm_eval.py \
         pretrained_name=/path/to/checkpoint \
        inputs=/path/to/test_manifest \
        batch_size=64 \
        max_new_tokens=128 \
        output_manifest=generations.jsonl
 
+To train the SALMAutomodel variant (with NeMo Automodel backend), use the ``salm_automodel`` config:
+
+.. code-block:: bash
+
+    # Train SALMAutomodel with NVIDIA Nemotron Nano V3 MoE backbone on 8 GPUs
+    torchrun --nproc_per_node=8 examples/speechlm2/salm_train.py \
+        --config-name=salm_automodel \
+        model.pretrained_llm=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
+
+The ``salm_automodel.yaml`` config sets ``model.use_nemo_automodel: true``, which selects the
+``SALMAutomodel`` class. This variant supports ``AutomodelParallelStrategy`` for FSDP2/TP/EP
+parallelism and MoE optimizations (Grouped GEMM, DeepEP).
+
 For more detailed information on training at scale, model parallelism, and SLURM-based training, see :doc:`training and scaling <training_and_scaling>`.
 
 Collection Structure
 --------------------
 
 The speechlm2 collection is organized into the following key components:
 
-- **Models**: Contains implementations of DuplexS2SModel, DuplexS2SSpeechDecoderModel, DuplexSTTModel, SALM, DuplexEARTTS, and the inference-only NemotronVoiceChat.
+- **Models**: Contains implementations of DuplexS2SModel, DuplexS2SSpeechDecoderModel, DuplexSTTModel, SALM, SALMAutomodel, DuplexEARTTS, and the inference-only NemotronVoiceChat.
 - **Modules**: Contains audio perception and speech generation modules.
 - **Data**: Includes dataset classes and data loading utilities.
 
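The eval and generate scripts auto-detect the model class from a checkpoint's ``config.json``. As a rough sketch of that dispatch (the exact key layout in ``config.json`` is an assumption here, mirroring the ``model.use_nemo_automodel`` flag used in training):

.. code-block:: python

    import json
    from pathlib import Path

    import nemo.collections.speechlm2 as slm

    def select_model_class(checkpoint_dir: str):
        """Pick SALM or SALMAutomodel from the saved config (key layout assumed)."""
        cfg = json.loads((Path(checkpoint_dir) / "config.json").read_text())
        if cfg.get("use_nemo_automodel", False):
            return slm.models.SALMAutomodel
        return slm.models.SALM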
