Skip to content

Commit 6a0ecb7

Browse files
bring back zeroshot disabling code (#15564)
Signed-off-by: Paarth Neekhara <paarth.n@gmail.com>
1 parent 64f01b9 commit 6a0ecb7

1 file changed

Lines changed: 22 additions & 3 deletions

File tree

nemo/collections/tts/models/magpietts.py

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -619,6 +619,10 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
619619
self.register_buffer('_baked_embedding_T', None) # Time dimension
620620
self.register_buffer('_baked_embedding_D', None) # Embedding dimension
621621
self.register_buffer('baked_context_embedding_len', None) # Per-speaker lengths (N,)
622+
# Probability of bypassing the context encoder during training and instead feeding
623+
# batch-shuffled raw context embeddings, so the model learns not to clone voices
624+
# from untransformed (i.e. not encoded by the context encoder) input.
625+
self.train_shuffle_context_embedding_prob = cfg.get('train_shuffle_context_embedding_prob', 0.0)
622626
else:
623627
raise ValueError(f"Unsupported model type {self.model_type}")
624628

@@ -1773,9 +1777,24 @@ def _prepare_decoder_context(
17731777
context_input_lens = context_input_lens.to(text.device)
17741778
context_mask = get_mask_from_lengths(context_input_lens)
17751779
else:
1776-
context_embeddings = self.context_encoder(
1777-
context_input_embedded, context_mask, cond=None, cond_mask=None
1778-
)['output']
1780+
# Zero-shot disable: with some probability, bypass the context encoder and feed
1781+
# batch-shuffled raw embeddings so the model learns not to clone voices from untransformed input.
1782+
# Skip when batch_size == 1: rolling a single sample maps it back to itself,
1783+
# so the context would remain matched to the correct speaker.
1784+
batch_size = context_input_embedded.size(0)
1785+
if (
1786+
self.training
1787+
and batch_size > 1
1788+
and self.train_shuffle_context_embedding_prob > 0
1789+
and random.random() < self.train_shuffle_context_embedding_prob
1790+
):
1791+
shift = random.randint(1, batch_size - 1)
1792+
context_embeddings = context_input_embedded.roll(shift, dims=0)
1793+
context_mask = context_mask.roll(shift, dims=0)
1794+
else:
1795+
context_embeddings = self.context_encoder(
1796+
context_input_embedded, context_mask, cond=None, cond_mask=None
1797+
)['output']
17791798
else:
17801799
raise ValueError(f"Unsupported model type for decoder context: {self.model_type}")
17811800

0 commit comments

Comments
 (0)