Skip to content

Commit 2cc0f47

Browse files
committed
Apply isort and black reformatting
Signed-off-by: Edresson <Edresson@users.noreply.github.com>
1 parent: 5b43960 · commit: 2cc0f47

2 files changed

Lines changed: 14 additions & 7 deletions

File tree

nemo/collections/common/data/lhotse/cutset.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,6 @@
3434
from lhotse.utils import fastcopy
3535
from omegaconf import DictConfig, ListConfig, OmegaConf
3636

37-
from nemo.collections.speechlm2.parts.precision import fp32_precision
38-
3937
from nemo.collections.common.data.lhotse.nemo_adapters import (
4038
LazyNeMoIterator,
4139
LazyNeMoTarredIterator,
@@ -55,6 +53,7 @@
5553
TextTurn,
5654
)
5755
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
56+
from nemo.collections.speechlm2.parts.precision import fp32_precision
5857

5958

6059
def temperature_reweighting(weights: List[Union[float, int]], temperature: float = 1.0) -> List[float]:

nemo/collections/speechlm2/models/duplex_ear_tts.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ def get_codec_silence_frame_last_one(self):
131131
audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame)
132132

133133
with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
134-
sil_codes, sil_codes_lens = self.audio_codec.encode(audio.unsqueeze(1).to(self.audio_codec_run_dtype), audio_len)
134+
sil_codes, sil_codes_lens = self.audio_codec.encode(
135+
audio.unsqueeze(1).to(self.audio_codec_run_dtype), audio_len
136+
)
135137
return sil_codes[0, -1]
136138

137139
def get_codec_silence_frame(self):
@@ -142,7 +144,9 @@ def get_codec_silence_frame(self):
142144
audio, audio_len = self.pad_audio_to_factor(audio, audio_len, self.target_samples_per_frame)
143145

144146
with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
145-
sil_codes, _ = self.audio_codec.encode(audio.unsqueeze(1).to(self.audio_codec_run_dtype), audio_len) # [1, T, C]
147+
sil_codes, _ = self.audio_codec.encode(
148+
audio.unsqueeze(1).to(self.audio_codec_run_dtype), audio_len
149+
) # [1, T, C]
146150
sil_codes = sil_codes[0] # [T, C]
147151

148152
# Convert each frame (C tokens) into a tuple
@@ -328,7 +332,9 @@ def prepare_inputs(self, batch: dict):
328332
target_audio, target_audio_lens, self.target_samples_per_frame, 1
329333
)
330334
with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
331-
target_codes, target_codes_lens = self.audio_codec.encode(target_audio.unsqueeze(1).to(self.audio_codec_run_dtype), target_audio_lens)
335+
target_codes, target_codes_lens = self.audio_codec.encode(
336+
target_audio.unsqueeze(1).to(self.audio_codec_run_dtype), target_audio_lens
337+
)
332338

333339
with fp32_precision():
334340
target_len = target_codes.shape[1]
@@ -1013,7 +1019,9 @@ def set_init_inputs(self, speaker_audio=None, speaker_audio_lens=None, system_pr
10131019
[target_audio.size(-1)] * target_audio.size(0), dtype=torch.long, device=self.device
10141020
)
10151021
with ensures_target_precision(self.audio_codec_run_dtype), torch.no_grad():
1016-
code, _ = self.audio_codec.encode(target_audio.unsqueeze(1).to(self.audio_codec_run_dtype), target_audio_len)
1022+
code, _ = self.audio_codec.encode(
1023+
target_audio.unsqueeze(1).to(self.audio_codec_run_dtype), target_audio_len
1024+
)
10171025

10181026
# get context hidden
10191027
if self.cfg.tts_config.context_hidden_size is not None:
@@ -1683,7 +1691,7 @@ def setup_audio_codec(model):
16831691
p.requires_grad = False
16841692

16851693
model.audio_codec.eval()
1686-
model.audio_codec.to(model.device) # force codec to run in the same device as the main model
1694+
model.audio_codec.to(model.device) # force codec to run in the same device as the main model
16871695

16881696
assert callable(model.tts_model.set_rvq_embs)
16891697

0 commit comments

Comments (0)