Skip to content

Commit 199e69e

Browse files
[TTS][Magpietts] Added CFG distillation
1 parent 71fe0c6 commit 199e69e

File tree

5 files changed

+1174
-5
lines changed

5 files changed

+1174
-5
lines changed

examples/tts/magpietts.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,19 @@
2121
MagpieTTSModelOfflinePO,
2222
MagpieTTSModelOfflinePODataGen,
2323
MagpieTTSModelOnlinePO,
24+
OnlineCFGDistillation,
2425
)
2526
from nemo.core.config import hydra_runner
2627
from nemo.utils import logging
2728
from nemo.utils.exp_manager import exp_manager
2829

30+
# Values of cfg.mode that run trainer.fit(); the other supported mode is 'test'.
_TRAIN_MODES: list[str] = [
    "train",
    "online_cfg_distillation_train",
    "dpo_train",
    "onlinepo_train",
]
36+
2937

3038
@hydra_runner(config_path="conf/magpietts", config_name="magpietts_lhotse")
3139
def main(cfg):
@@ -54,8 +62,12 @@ def main(cfg):
5462
pl.seed_everything(seed, workers=True)
5563

5664
mode = cfg.get('mode', 'train')
65+
train_modes_msg = ", ".join(_TRAIN_MODES)
66+
5767
if mode == 'train':
5868
model = MagpieTTSModel(cfg=cfg.model, trainer=trainer)
69+
elif mode == 'online_cfg_distillation_train':
70+
model = OnlineCFGDistillation(cfg=cfg.model, trainer=trainer)
5971
elif mode == 'dpo_train':
6072
model_cfg = cfg.model
6173
with open_dict(model_cfg):
@@ -69,21 +81,19 @@ def main(cfg):
6981
elif mode == 'test':
7082
model = MagpieTTSModelOfflinePODataGen(cfg=cfg.model, trainer=trainer)
7183
else:
72-
raise NotImplementedError(f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}")
84+
raise NotImplementedError(f"Only {train_modes_msg} and test modes are supported. Got {mode}")
7385

7486
model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
7587

7688
try:
77-
if mode in ['train', 'dpo_train', 'onlinepo_train']:
89+
if mode in _TRAIN_MODES:
7890
logging.info("Starting training...")
7991
trainer.fit(model)
8092
elif mode == 'test':
8193
logging.info("Starting testing...")
8294
trainer.test(model)
8395
else:
84-
raise NotImplementedError(
85-
f"Only train, dpo_train, onlinepo_train and test modes are supported. Got {mode}"
86-
)
96+
raise NotImplementedError(f"Only {train_modes_msg} and test modes are supported. Got {mode}")
8797
logging.info("Training/testing completed successfully.")
8898
finally:
8999
# Ensure WandB completes all uploads before Python thread shutdown
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Losses used in CFG distillation of the MagpieTTS model.
16+
"""
17+
18+
from typing import Generator, Optional
19+
20+
import torch
21+
from torch import Tensor, nn
22+
23+
from nemo.core.classes import Loss, typecheck
24+
from nemo.core.neural_types import LabelsType, LogitsType, LossType, MaskType, NeuralType
25+
26+
# Public API of this module.
__all__ = [
    "KLDivergenceLoss",
    "CodesCrossEntropyLoss",
    "NRMSELogitsLoss",
]
31+
32+
33+
def _iter_slices(
34+
num_codebooks: int,
35+
num_tokens_per_codebook: int,
36+
frame_stacking_factor: int,
37+
mask: Tensor,
38+
) -> Generator[[int, int, int, int, Tensor, Tensor], None, None]:
39+
for fs_index in range(frame_stacking_factor):
40+
slice_mask = mask[:, fs_index::frame_stacking_factor].float()
41+
slice_len = slice_mask.sum(dim=-1).clamp_min(1)
42+
offset = num_codebooks * fs_index * num_tokens_per_codebook
43+
44+
for codebook in range(num_codebooks):
45+
start = offset + codebook * num_tokens_per_codebook
46+
end = start + num_tokens_per_codebook
47+
48+
yield fs_index, codebook, start, end, slice_mask, slice_len
49+
50+
51+
class KLDivergenceLoss(Loss):
    """Masked Kullback-Leibler divergence between student and teacher logit distributions.

    The logit dimension is split into per-codebook, per-frame-stacking-position slices;
    KL is averaged over valid (masked) frames per slice, then over all slices.
    """

    @property
    def input_types(self) -> dict[str, NeuralType]:
        """Input port definitions.

        Returns:
            dict[str, NeuralType]: A dictionary describing expected input tensors.
        """
        return {
            "student_logits": NeuralType(("B", "T", "D"), LogitsType()),
            "teacher_logits": NeuralType(("B", "T", "D"), LogitsType()),
            "mask": NeuralType(("B", "T"), MaskType()),
            "sample_weights": NeuralType(tuple("B"), MaskType(), optional=True),
        }

    @property
    def output_types(self) -> dict[str, NeuralType]:
        """Output port definitions.

        Returns:
            dict[str, NeuralType]: A dictionary describing expected output tensors.
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(
        self,
        num_codebooks: int,
        num_tokens_per_codebook: int,
        frame_stacking_factor: int,
    ) -> None:
        super().__init__()
        self.num_codebooks = num_codebooks
        self.num_tokens_per_codebook = num_tokens_per_codebook
        self.frame_stacking_factor = frame_stacking_factor
        # Element-wise KL; masking and normalization are applied manually in forward().
        self.criterion = nn.KLDivLoss(reduction="none", log_target=False)

    @typecheck()
    def forward(
        self,
        student_logits: Tensor,
        teacher_logits: Tensor,
        mask: Tensor,
        sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute the masked KL divergence between student and teacher logits.

        Args:
            student_logits (Tensor): Student logits of shape `(B, T', D)` — `B` batch size,
                `T'` frame-stacked sequence length, `D` concatenated logit dimension over
                all codebooks and frame-stacking positions.
            teacher_logits (Tensor): Teacher logits of shape `(B, T', D)`.
            mask (Tensor): Binary mask of shape `(B, T)` over the unstacked time dimension;
                the stacked-time mask per frame-stacking position is obtained by slicing.
            sample_weights (Optional[Tensor]): Optional per-sample weights of shape `(B,)`
                applied before averaging; `None` weights all samples equally.

        Returns:
            Tensor: Scalar averaged masked KL divergence loss.
        """
        log_probs = student_logits.log_softmax(dim=-1)
        probs = teacher_logits.softmax(dim=-1)
        total = 0.0

        slice_iter = _iter_slices(
            self.num_codebooks,
            self.num_tokens_per_codebook,
            self.frame_stacking_factor,
            mask,
        )
        for _, _, start, end, slice_mask, slice_len in slice_iter:
            per_token = self.criterion(input=log_probs[:, :, start:end], target=probs[:, :, start:end])
            # Sum over the token dimension, then average over valid frames per sample.
            per_frame = per_token.sum(dim=-1)
            total = total + (per_frame * slice_mask).sum(dim=-1) / slice_len

        total = total / (self.num_codebooks * self.frame_stacking_factor)

        if sample_weights is not None:
            total = total * sample_weights

        return total.mean()
136+
137+
138+
class CodesCrossEntropyLoss(Loss):
    """Time-masked cross-entropy over discretized audio codes with frame stacking."""

    @property
    def input_types(self) -> dict[str, NeuralType]:
        """Input port definitions.

        Returns:
            dict[str, NeuralType]: A dictionary describing expected input tensors.
        """
        return {
            "predicted_logits": NeuralType(("B", "T", "D"), LogitsType()),
            "target_codes": NeuralType(("B", "C", "T"), LabelsType()),
            "mask": NeuralType(("B", "T"), MaskType()),
            "sample_weights": NeuralType(tuple("B"), MaskType(), optional=True),
        }

    @property
    def output_types(self) -> dict[str, NeuralType]:
        """Output port definitions.

        Returns:
            dict[str, NeuralType]: A dictionary describing expected output tensors.
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(
        self,
        num_codebooks: int,
        num_tokens_per_codebook: int,
        frame_stacking_factor: int,
    ) -> None:
        super().__init__()
        self.num_codebooks = num_codebooks
        self.num_tokens_per_codebook = num_tokens_per_codebook
        self.frame_stacking_factor = frame_stacking_factor
        # Per-position CE; masking and normalization are applied manually in forward().
        self.criterion = nn.CrossEntropyLoss(reduction="none")

    @typecheck()
    def forward(
        self,
        predicted_logits: Tensor,
        target_codes: Tensor,
        mask: Tensor,
        sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute masked cross-entropy for code sequences with frame stacking.

        Args:
            predicted_logits (Tensor): Predicted logits of shape `(B, T', D)` — `B` batch
                size, `T'` frame-stacked sequence length, `D` concatenated logit dimension
                over all codebooks and frame-stacking positions.
            target_codes (Tensor): Target code indices of shape `(B, C, T)` — `C` codebooks,
                `T` unstacked time dimension.
            mask (Tensor): Binary mask of shape `(B, T)` over the unstacked time dimension.
            sample_weights (Optional[Tensor]): Optional per-sample weights of shape `(B,)`
                applied before averaging; `None` weights all samples equally.

        Returns:
            Tensor: Scalar averaged masked cross-entropy loss.
        """
        total = 0.0

        slice_iter = _iter_slices(
            self.num_codebooks,
            self.num_tokens_per_codebook,
            self.frame_stacking_factor,
            mask,
        )
        for fs_index, codebook, start, end, slice_mask, slice_len in slice_iter:
            targets = target_codes[:, codebook][:, fs_index :: self.frame_stacking_factor]
            # CrossEntropyLoss expects class logits on dim 1: (B, num_tokens, T').
            logits = predicted_logits[:, :, start:end].transpose(1, 2)
            per_frame = self.criterion(input=logits, target=targets)
            total = total + (per_frame * slice_mask).sum(dim=-1) / slice_len

        total = total / (self.num_codebooks * self.frame_stacking_factor)

        if sample_weights is not None:
            total = total * sample_weights

        return total.mean()
220+
221+
222+
class NRMSELogitsLoss(Loss):
    """Normalized RMSE between raw student and teacher logits, masked over time.

    The per-frame RMSE of each slice is normalized by the teacher logits' standard
    deviation over the token dimension.
    """

    @property
    def input_types(self) -> dict[str, NeuralType]:
        """Input port definitions.

        Returns:
            dict[str, NeuralType]: A dictionary describing expected input tensors.
        """
        return {
            "student_logits": NeuralType(("B", "T", "D"), LogitsType()),
            "teacher_logits": NeuralType(("B", "T", "D"), LogitsType()),
            "mask": NeuralType(("B", "T"), MaskType()),
            "sample_weights": NeuralType(tuple("B"), MaskType(), optional=True),
        }

    @property
    def output_types(self) -> dict[str, NeuralType]:
        """Output port definitions.

        Returns:
            dict[str, NeuralType]: A dictionary describing expected output tensors.
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(
        self,
        num_codebooks: int,
        num_tokens_per_codebook: int,
        frame_stacking_factor: int,
    ) -> None:
        super().__init__()
        self.num_codebooks = num_codebooks
        self.num_tokens_per_codebook = num_tokens_per_codebook
        self.frame_stacking_factor = frame_stacking_factor
        # Floor for the normalization term so division is always finite.
        self.eps = 1e-8
        self.criterion = nn.MSELoss(reduction="none")

    @typecheck()
    def forward(
        self,
        student_logits: Tensor,
        teacher_logits: Tensor,
        mask: Tensor,
        sample_weights: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute the masked normalized RMSE between student and teacher logits.

        Args:
            student_logits (Tensor): Student logits of shape `(B, T', D)` — `B` batch size,
                `T'` frame-stacked sequence length, `D` concatenated logit dimension over
                all codebooks and frame-stacking positions.
            teacher_logits (Tensor): Teacher logits of shape `(B, T', D)`.
            mask (Tensor): Binary mask of shape `(B, T)` over the unstacked time dimension.
            sample_weights (Optional[Tensor]): Optional per-sample weights of shape `(B,)`
                applied before averaging; `None` weights all samples equally.

        Returns:
            Tensor: Scalar averaged masked normalized RMSE loss.
        """
        # Zero out positions where either side is +/-inf (e.g. masked-out vocabulary
        # entries) so they do not poison the MSE or the std normalization.
        inf_mask = torch.isinf(teacher_logits) | torch.isinf(student_logits)
        teacher_logits = teacher_logits.masked_fill(inf_mask, 0.0)
        student_logits = student_logits.masked_fill(inf_mask, 0.0)
        total = 0.0

        slice_iter = _iter_slices(
            self.num_codebooks,
            self.num_tokens_per_codebook,
            self.frame_stacking_factor,
            mask,
        )
        for _, _, start, end, slice_mask, slice_len in slice_iter:
            student_part = student_logits[:, :, start:end]
            teacher_part = teacher_logits[:, :, start:end]
            squared_err = self.criterion(input=student_part, target=teacher_part)
            rmse = torch.sqrt(squared_err.mean(dim=-1))
            scale = teacher_part.std(dim=-1).clamp_min(self.eps)
            per_frame = rmse / scale
            total = total + (per_frame * slice_mask).sum(dim=-1) / slice_len

        total = total / (self.num_codebooks * self.frame_stacking_factor)

        if sample_weights is not None:
            total = total * sample_weights

        return total.mean()

nemo/collections/tts/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL
1919
from nemo.collections.tts.models.hifigan import HifiGanModel
2020
from nemo.collections.tts.models.magpietts import InferBatchOutput, MagpieTTSModel
21+
from nemo.collections.tts.models.magpietts_cfg_distillation import OnlineCFGDistillation
2122
from nemo.collections.tts.models.magpietts_preference_optimization import (
2223
MagpieTTSModelOfflinePO,
2324
MagpieTTSModelOfflinePODataGen,
@@ -34,6 +35,7 @@
3435
"HifiGanModel",
3536
"InferBatchOutput",
3637
"MagpieTTSModel",
38+
"OnlineCFGDistillation",
3739
"MagpieTTSModelOfflinePODataGen",
3840
"MagpieTTSModelOfflinePO",
3941
"MagpieTTSModelOnlinePO",

nemo/collections/tts/models/magpietts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,7 @@ def load_state_dict(self, state_dict, strict=True):
10511051
'eval_speaker_verification_model',
10521052
'whisper_model',
10531053
'squim_objective_model',
1054+
'_teacher_model',
10541055
]
10551056
# Skip context_encoder if checkpoint has baked embedding (weights won't be in checkpoint)
10561057
if has_baked_embedding_in_ckpt:

0 commit comments

Comments
 (0)