@@ -131,7 +131,9 @@ def get_codec_silence_frame_last_one(self):
131131 audio , audio_len = self .pad_audio_to_factor (audio , audio_len , self .target_samples_per_frame )
132132
133133 with ensures_target_precision (self .audio_codec_run_dtype ), torch .no_grad ():
134- sil_codes , sil_codes_lens = self .audio_codec .encode (audio .unsqueeze (1 ).to (self .audio_codec_run_dtype ), audio_len )
134+ sil_codes , sil_codes_lens = self .audio_codec .encode (
135+ audio .unsqueeze (1 ).to (self .audio_codec_run_dtype ), audio_len
136+ )
135137 return sil_codes [0 , - 1 ]
136138
137139 def get_codec_silence_frame (self ):
@@ -142,7 +144,9 @@ def get_codec_silence_frame(self):
142144 audio , audio_len = self .pad_audio_to_factor (audio , audio_len , self .target_samples_per_frame )
143145
144146 with ensures_target_precision (self .audio_codec_run_dtype ), torch .no_grad ():
145- sil_codes , _ = self .audio_codec .encode (audio .unsqueeze (1 ).to (self .audio_codec_run_dtype ), audio_len ) # [1, T, C]
147+ sil_codes , _ = self .audio_codec .encode (
148+ audio .unsqueeze (1 ).to (self .audio_codec_run_dtype ), audio_len
149+ ) # [1, T, C]
146150 sil_codes = sil_codes [0 ] # [T, C]
147151
148152 # Convert each frame (C tokens) into a tuple
@@ -328,7 +332,9 @@ def prepare_inputs(self, batch: dict):
328332 target_audio , target_audio_lens , self .target_samples_per_frame , 1
329333 )
330334 with ensures_target_precision (self .audio_codec_run_dtype ), torch .no_grad ():
331- target_codes , target_codes_lens = self .audio_codec .encode (target_audio .unsqueeze (1 ).to (self .audio_codec_run_dtype ), target_audio_lens )
335+ target_codes , target_codes_lens = self .audio_codec .encode (
336+ target_audio .unsqueeze (1 ).to (self .audio_codec_run_dtype ), target_audio_lens
337+ )
332338
333339 with fp32_precision ():
334340 target_len = target_codes .shape [1 ]
@@ -1013,7 +1019,9 @@ def set_init_inputs(self, speaker_audio=None, speaker_audio_lens=None, system_pr
10131019 [target_audio .size (- 1 )] * target_audio .size (0 ), dtype = torch .long , device = self .device
10141020 )
10151021 with ensures_target_precision (self .audio_codec_run_dtype ), torch .no_grad ():
1016- code , _ = self .audio_codec .encode (target_audio .unsqueeze (1 ).to (self .audio_codec_run_dtype ), target_audio_len )
1022+ code , _ = self .audio_codec .encode (
1023+ target_audio .unsqueeze (1 ).to (self .audio_codec_run_dtype ), target_audio_len
1024+ )
10171025
10181026 # get context hidden
10191027 if self .cfg .tts_config .context_hidden_size is not None :
@@ -1683,7 +1691,7 @@ def setup_audio_codec(model):
16831691 p .requires_grad = False
16841692
16851693 model .audio_codec .eval ()
1686- model .audio_codec .to (model .device ) # force codec to run in the same device as the main model
1694+ model .audio_codec .to (model .device ) # force codec to run in the same device as the main model
16871695
16881696 assert callable (model .tts_model .set_rvq_embs )
16891697
0 commit comments