@@ -619,6 +619,10 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
619619 self .register_buffer ('_baked_embedding_T' , None ) # Time dimension
620620 self .register_buffer ('_baked_embedding_D' , None ) # Embedding dimension
621621 self .register_buffer ('baked_context_embedding_len' , None ) # Per-speaker lengths (N,)
622+ # Probability of bypassing the context encoder during training and instead feeding
623+ # batch-shuffled raw context embeddings, so the model learns not to clone voices
624+ # from untransformed (i.e. not encoded by the context encoder) input.
625+ self .train_shuffle_context_embedding_prob = cfg .get ('train_shuffle_context_embedding_prob' , 0.0 )
622626 else :
623627 raise ValueError (f"Unsupported model type { self .model_type } " )
624628
@@ -1773,9 +1777,24 @@ def _prepare_decoder_context(
17731777 context_input_lens = context_input_lens .to (text .device )
17741778 context_mask = get_mask_from_lengths (context_input_lens )
17751779 else :
1776- context_embeddings = self .context_encoder (
1777- context_input_embedded , context_mask , cond = None , cond_mask = None
1778- )['output' ]
1780+ # Zero-shot disable: with some probability, bypass the context encoder and feed
1781+ # batch-shuffled raw embeddings so the model learns to not clone from untransformed input.
1782+ # Skip when batch_size == 1: rolling a single sample maps it back to itself,
1783+ # so the context would remain matched to the correct speaker.
1784+ batch_size = context_input_embedded .size (0 )
1785+ if (
1786+ self .training
1787+ and batch_size > 1
1788+ and self .train_shuffle_context_embedding_prob > 0
1789+ and random .random () < self .train_shuffle_context_embedding_prob
1790+ ):
1791+ shift = random .randint (1 , batch_size - 1 )
1792+ context_embeddings = context_input_embedded .roll (shift , dims = 0 )
1793+ context_mask = context_mask .roll (shift , dims = 0 )
1794+ else :
1795+ context_embeddings = self .context_encoder (
1796+ context_input_embedded , context_mask , cond = None , cond_mask = None
1797+ )['output' ]
17791798 else :
17801799 raise ValueError (f"Unsupported model type for decoder context: { self .model_type } " )
17811800
0 commit comments