@@ -16,7 +16,7 @@
 from contextlib import nullcontext
 from math import ceil
 from pathlib import Path
-from typing import List, Tuple
+from typing import Iterable, List, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -128,9 +128,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             semantic_codec = None
 
         if semantic_codec is not None:
+            semantic_codec.eval()
+            semantic_codec.freeze()
             self.register_nemo_submodule(name="semantic_codec", config_field="semantic_codec", model=semantic_codec)
-            self.semantic_codec.eval()
-            self.semantic_codec.freeze()
         else:
             self.semantic_codec = None
 
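Note: the hunk above applies eval() and freeze() to the local module before it is registered, so the codec is already inert when it joins the parent model. A minimal sketch of what freezing amounts to in plain PyTorch (NeMo's freeze() helper behaves roughly like this; the Conv1d stand-in is hypothetical, not the actual codec):

```python
import torch.nn as nn

# Hypothetical stand-in for the pretrained semantic codec.
semantic_codec = nn.Sequential(nn.Conv1d(1, 8, kernel_size=3, padding=1), nn.ReLU())

# Roughly what NeMo's Module.freeze() does:
semantic_codec.eval()  # put dropout/batch-norm layers into inference mode
for p in semantic_codec.parameters():
    p.requires_grad = False  # exclude the codec from gradient updates

assert all(not p.requires_grad for p in semantic_codec.parameters())
```

Freezing before register_nemo_submodule also means any later parameter collection that filters on requires_grad will skip the codec.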
@@ -140,12 +140,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             self.slm_encoder = instantiate(cfg.get("slm_encoder"))
             self.slm_encoder.eval()
             self.slm_encoder.freeze()
-            self.slm_decoder = instantiate(cfg.slm_decoder)
+            self.slm_predictor = instantiate(cfg.slm_predictor)
             self.slm_loss_fn = torch.nn.MSELoss()
             self.slm_loss_scale = cfg.get("slm_loss_scale", 1.0)
         else:
             self.slm_encoder = None
-            self.slm_decoder = None
+            self.slm_predictor = None
             self.slm_loss_fn = None
             self.slm_loss_scale = None
 
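The slm_decoder → slm_predictor rename propagates through the config key as well (cfg.slm_predictor). For reference, a sketch of how Hydra's instantiate builds a module from such a config; the _target_ and layer sizes below are made-up placeholders, not the values used by this model:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Placeholder config; the real cfg.slm_predictor comes from the model YAML.
cfg = OmegaConf.create(
    {"slm_predictor": {"_target_": "torch.nn.Linear", "in_features": 256, "out_features": 1024}}
)

slm_predictor = instantiate(cfg.slm_predictor)  # builds torch.nn.Linear(256, 1024)
print(type(slm_predictor).__name__)  # Linear
```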
@@ -261,7 +261,7 @@ def codebook_size(self):
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         if hasattr(self, '_no_state_dict') and self._no_state_dict:
             return {}
-        # Don't save the speaker verification and codec model in the state dict
+        # Avoid saving weights of frozen pretrained models
         state_dict = super().state_dict(destination, prefix, keep_vars)
         for key in list(state_dict.keys()):
             if self.use_scl_loss and "speaker_encoder." in key:
@@ -274,7 +274,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         return state_dict
 
     def load_state_dict(self, state_dict, strict=True):
-        # Override to load all the keys except .speaker_encoder. and WavLM model
+        # Avoid loading weights of frozen pretrained models
         for key in list(state_dict.keys()):
             if self.use_scl_loss and "speaker_encoder." in key:
                 del state_dict[key]
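Both overrides use the same pattern: iterate over a copy of the keys and delete entries belonging to frozen pretrained submodules. A self-contained sketch of the pattern (the FROZEN_PREFIXES tuple is an illustrative generalization; the diff checks prefixes such as "speaker_encoder." individually):

```python
import torch.nn as nn

# Prefixes of frozen pretrained submodules whose weights should not be
# saved or loaded with the checkpoint (illustrative list).
FROZEN_PREFIXES = ("speaker_encoder.", "semantic_codec.")

def filter_frozen_keys(state_dict):
    for key in list(state_dict.keys()):  # copy the keys: we mutate the dict
        if any(prefix in key for prefix in FROZEN_PREFIXES):
            del state_dict[key]
    return state_dict

model = nn.ModuleDict({"speaker_encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
print(sorted(filter_frozen_keys(model.state_dict())))  # ['decoder.bias', 'decoder.weight']
```

Because the filtered keys are absent at load time, the surrounding load path has to tolerate missing keys for these submodules.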
@@ -327,8 +327,10 @@ def encode_audio(
         encoded, encoded_len = self.audio_encoder(audio=audio_preprocessed, audio_len=audio_preprocessed_len)
 
         if self.semantic_codec is not None:
-            semantic, _ = self.semantic_codec.encode_audio(audio=audio, audio_len=audio_len, sample_rate=sample_rate)
-            semantic = semantic.detach()
+            with torch.no_grad():
+                semantic, _ = self.semantic_codec.encode_audio(
+                    audio=audio, audio_len=audio_len, sample_rate=sample_rate
+                )
             encoded = torch.concat([semantic, encoded], dim=1)
 
         return encoded, encoded_len
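Switching from .detach() to a torch.no_grad() context stops autograd from recording the codec's forward pass at all, rather than cutting the graph after it has been built. A shape-level sketch of the concat (dimensions are illustrative):

```python
import torch

B, T = 2, 50
encoded = torch.randn(B, 32, T, requires_grad=True)  # trainable encoder output

with torch.no_grad():  # no autograd graph is recorded for the frozen codec
    semantic = torch.randn(B, 16, T)  # stand-in for semantic codec features

combined = torch.concat([semantic, encoded], dim=1)  # channel dim: [B, 16 + 32, T]
print(combined.shape, combined.requires_grad)  # torch.Size([2, 48, 50]) True
```

Gradients still flow into encoded through the concatenation; only the semantic branch is cut.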
@@ -574,7 +576,7 @@ def _process_batch(self, batch):
 
         if self.training and self.use_slm_loss:
             slm_emb = self.slm_encoder(audio=audio)
-            slm_emb_pred = self.slm_decoder(inputs=encoded)
+            slm_emb_pred = self.slm_predictor(inputs=encoded)
         else:
             slm_emb = None
             slm_emb_pred = None
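These two tensors later feed the MSE objective configured in __init__ (slm_loss_fn, slm_loss_scale). A hedged sketch of the loss computation; the shapes and any resampling between SLM frames and codec frames are assumptions, not taken from this PR:

```python
import torch

slm_loss_fn = torch.nn.MSELoss()
slm_loss_scale = 1.0  # cfg.get("slm_loss_scale", 1.0) in the model

slm_emb = torch.randn(2, 100, 768)       # targets from the frozen SLM encoder
slm_emb_pred = torch.randn(2, 100, 768)  # predictions from slm_predictor

slm_loss = slm_loss_scale * slm_loss_fn(slm_emb_pred, slm_emb)
print(slm_loss.item())
```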
@@ -886,11 +888,11 @@ def update_lr(self, interval="step"):
         if schedulers is None or self.lr_schedule_interval != interval:
             return
 
-        if self.discriminator is None:
+        if not isinstance(schedulers, Iterable):
             schedulers.step()
         else:
-            schedulers[0].step()
-            schedulers[1].step()
+            for sch in schedulers:
+                sch.step()
 
     def configure_callbacks(self):
         if not self.log_config:
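The rewritten update_lr no longer assumes exactly two schedulers when a discriminator exists; it steps a single scheduler or any iterable of them. A runnable sketch of the same branch (the toy parameters, optimizers, and gamma are illustrative):

```python
import torch
from typing import Iterable  # isinstance checks work with typing.Iterable

params = [torch.nn.Parameter(torch.zeros(1))]
opt_g = torch.optim.Adam(params, lr=1e-3)
opt_d = torch.optim.Adam(params, lr=1e-3)
schedulers = [
    torch.optim.lr_scheduler.ExponentialLR(opt_g, gamma=0.999),
    torch.optim.lr_scheduler.ExponentialLR(opt_d, gamma=0.999),
]

opt_g.step()  # step optimizers first to avoid a scheduler-order warning
opt_d.step()

if not isinstance(schedulers, Iterable):
    schedulers.step()  # a lone scheduler object is not iterable
else:
    for sch in schedulers:  # works for any number of schedulers
        sch.step()
```

A single _LRScheduler instance is not Iterable, so it takes the first branch; a list or tuple of schedulers takes the loop.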