Skip to content

Commit 12fc883

Browse files
committed
Rename slm decoder to predictor
Signed-off-by: Ryan <rlangman@nvidia.com>
1 parent 0b037e1 commit 12fc883

File tree

3 files changed

+19
-16
lines changed

3 files changed

+19
-16
lines changed

nemo/collections/tts/data/vocoder_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ class VocoderDataset(Dataset):
126126
Args:
127127
dataset_meta: Dict of dataset names (string) to dataset metadata.
128128
sample_rate: Sample rate to load audio as. If the audio is stored at a different sample rate, then it will
129-
be resampled.
129+
be resampled using librosa.
130+
resample_rate: Optional sample rate to resample to, using torch-based resampling.
130131
n_samples: Optional int, if provided then n_samples samples will be randomly sampled from the full
131132
audio file.
132133
weighted_sampling_steps_per_epoch: Optional int, if provided, then data will be sampled (with replacement) based on

nemo/collections/tts/models/audio_codec.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from contextlib import nullcontext
1717
from math import ceil
1818
from pathlib import Path
19-
from typing import List, Tuple
19+
from typing import Iterable, List, Tuple
2020

2121
import torch
2222
import torch.nn.functional as F
@@ -128,9 +128,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
128128
semantic_codec = None
129129

130130
if semantic_codec is not None:
131+
semantic_codec.eval()
132+
semantic_codec.freeze()
131133
self.register_nemo_submodule(name="semantic_codec", config_field="semantic_codec", model=semantic_codec)
132-
self.semantic_codec.eval()
133-
self.semantic_codec.freeze()
134134
else:
135135
self.semantic_codec = None
136136

@@ -140,12 +140,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
140140
self.slm_encoder = instantiate(cfg.get("slm_encoder"))
141141
self.slm_encoder.eval()
142142
self.slm_encoder.freeze()
143-
self.slm_decoder = instantiate(cfg.slm_decoder)
143+
self.slm_predictor = instantiate(cfg.slm_predictor)
144144
self.slm_loss_fn = torch.nn.MSELoss()
145145
self.slm_loss_scale = cfg.get("slm_loss_scale", 1.0)
146146
else:
147147
self.slm_encoder = None
148-
self.slm_decoder = None
148+
self.slm_predictor = None
149149
self.slm_loss_fn = None
150150
self.slm_loss_scale = None
151151

@@ -261,7 +261,7 @@ def codebook_size(self):
261261
def state_dict(self, destination=None, prefix='', keep_vars=False):
262262
if hasattr(self, '_no_state_dict') and self._no_state_dict:
263263
return {}
264-
# Don't save the speaker verification and codec model in the state dict
264+
# Avoid saving weights of frozen pretrained models
265265
state_dict = super().state_dict(destination, prefix, keep_vars)
266266
for key in list(state_dict.keys()):
267267
if self.use_scl_loss and "speaker_encoder." in key:
@@ -274,7 +274,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
274274
return state_dict
275275

276276
def load_state_dict(self, state_dict, strict=True):
277-
# Override to load all the keys except .speaker_encoder. and WavLM model
277+
# Avoid loading weights of frozen pretrained models
278278
for key in list(state_dict.keys()):
279279
if self.use_scl_loss and "speaker_encoder." in key:
280280
del state_dict[key]
@@ -327,8 +327,10 @@ def encode_audio(
327327
encoded, encoded_len = self.audio_encoder(audio=audio_preprocessed, audio_len=audio_preprocessed_len)
328328

329329
if self.semantic_codec is not None:
330-
semantic, _ = self.semantic_codec.encode_audio(audio=audio, audio_len=audio_len, sample_rate=sample_rate)
331-
semantic = semantic.detach()
330+
with torch.no_grad():
331+
semantic, _ = self.semantic_codec.encode_audio(
332+
audio=audio, audio_len=audio_len, sample_rate=sample_rate
333+
)
332334
encoded = torch.concat([semantic, encoded], dim=1)
333335

334336
return encoded, encoded_len
@@ -574,7 +576,7 @@ def _process_batch(self, batch):
574576

575577
if self.training and self.use_slm_loss:
576578
slm_emb = self.slm_encoder(audio=audio)
577-
slm_emb_pred = self.slm_decoder(inputs=encoded)
579+
slm_emb_pred = self.slm_predictor(inputs=encoded)
578580
else:
579581
slm_emb = None
580582
slm_emb_pred = None
@@ -886,11 +888,11 @@ def update_lr(self, interval="step"):
886888
if schedulers is None or self.lr_schedule_interval != interval:
887889
return
888890

889-
if self.discriminator is None:
891+
if not isinstance(schedulers, Iterable):
890892
schedulers.step()
891893
else:
892-
schedulers[0].step()
893-
schedulers[1].step()
894+
for sch in schedulers:
895+
sch.step()
894896

895897
def configure_callbacks(self):
896898
if not self.log_config:

nemo/collections/tts/modules/audio_codec_modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ def forward(self, audio):
247247
return slm_emb
248248

249249

250-
class SLMDecoder(NeuralModule):
251-
"""Decoder for predicting SLM embeddings for semantic distillation. This decoder uses transposed convolutions to upsample from
250+
class SLMPredictor(NeuralModule):
251+
"""Module for predicting SLM embeddings for semantic distillation. This predictor uses transposed convolutions to upsample from
252252
the codec's frame rate to the frame rate of the SLM model.
253253
254254
Args:

0 commit comments

Comments
 (0)