Skip to content

Commit f1f3182

Browse files
committed
LTX2 context windows - Fix audio index mapping for wrapped/strided primary windows
The previous window-level calculation collapsed wrapped or strided primary windows into a contiguous audio tail, so audio attended to a different temporal region than the video. Replace with per-frame mapping that computes each primary index's audio span independently and concatenates in order.
1 parent ae3830a commit f1f3182

1 file changed

Lines changed: 15 additions & 8 deletions

File tree

comfy/model_base.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,17 +1088,24 @@ def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
10881088
return result
10891089

10901090
video_total = latent_shapes[0][dim]
1091-
video_window_len = len(primary_indices)
10921091

10931092
for i in range(1, len(latent_shapes)):
10941093
mod_total = latent_shapes[i][dim]
1095-
# Length proportional to video window frame count
1096-
mod_window_len = max(round(video_window_len * mod_total / video_total), 1)
1097-
# Anchor to end of video range
1098-
v_end = max(primary_indices) + 1
1099-
mod_end = min(round(v_end * mod_total / video_total), mod_total)
1100-
mod_start = max(mod_end - mod_window_len, 0)
1101-
result.append(list(range(mod_start, min(mod_start + mod_window_len, mod_total))))
1094+
# Map each primary index to its proportional range of modality indices and
1095+
# concatenate in order. Preserves wrapped/strided geometry so the modality
1096+
# attends to the same temporal regions as the primary window.
1097+
mod_indices = []
1098+
seen = set()
1099+
for v_idx in primary_indices:
1100+
a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1)
1101+
a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total)
1102+
if a_end <= a_start:
1103+
a_end = a_start + 1
1104+
for a in range(a_start, a_end):
1105+
if a not in seen:
1106+
seen.add(a)
1107+
mod_indices.append(a)
1108+
result.append(mod_indices)
11021109

11031110
return result
11041111

0 commit comments

Comments
 (0)