Skip to content

Commit cae54b3

Browse files
authored
Permit text context with frame stacking (#15585)
If the configured context length is too short to fit commonly used text contexts, we detect that and error out. Signed-off-by: Fejgin, Roy <rfejgin@nvidia.com>
1 parent c42e444 commit cae54b3

1 file changed

Lines changed: 16 additions & 3 deletions

File tree

nemo/collections/tts/models/magpietts.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ def check_frame_stacking_config_validity(self):
835835
Check if the configuration is compatible with frame stacking.
836836
"""
837837
if self.frame_stacking_factor > 1:
838-
# The settings below are not supported with frame stacking.
838+
# Reject configurations that are not supported with frame stacking.
839839
# Some of them may work - but they have not been tested.
840840

841841
# disallow alignment encoder
@@ -847,9 +847,22 @@ def check_frame_stacking_config_validity(self):
847847
# disallow training prior
848848
if self.cfg.prior_scaling_factor is not None and self.cfg.prior_scaling_factor > 0:
849849
raise ValueError("Training-time attention prior is not supported for frame stacking")
850-
# disallow text conditioning
850+
# With frame stacking, the audio context sequence length is divided by the
851+
# frame stacking factor (e.g., 108 tokens at 21fps --> 54 positions with 2x stacking).
852+
# The text context is NOT stacked but must fit within the same sequence length
853+
# as the audio context. If needed, this constraint could likely be removed by also
854+
# stacking the text context, but that would require some experimentation.
851855
if self.use_text_conditioning_encoder:
852-
raise ValueError("Text conditioning is not supported for frame stacking")
856+
# Use 5 seconds as the baseline context length since it is known to fit
857+
# existing text contexts.
858+
min_required_context_sec = 5.0 * self.frame_stacking_factor
859+
actual_context_length_sec = self.cfg.get('context_duration_max')
860+
if actual_context_length_sec < min_required_context_sec:
861+
raise ValueError(
862+
f"With text context and a frame stacking factor of {self.frame_stacking_factor}, "
863+
f"context_duration_max must be >= {min_required_context_sec} seconds "
864+
f"(5 seconds x frame_stacking_factor); got context_duration_max={actual_context_length_sec}"
865+
)
853866

854867
@property
855868
def has_baked_context_embedding(self) -> bool:

0 commit comments

Comments
 (0)