@@ -163,15 +163,13 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
163163 k = key_seq .view (batch_dim_size , - 1 , self .num_heads , self .embed_dim_per_head ) # [B, T', #heads, F']
164164
165165 if self .learnable_pos_emb :
166- pos_seq_q = torch .arange (time_dim_size , device = input_tensor .device )
167- pos_seq_k = torch .arange (time_dim_size , device = input_tensor .device )
168-
169- distance_mat = pos_seq_k [None , :] - pos_seq_q [:, None ]
170- distance_mat_clipped = torch .clamp (distance_mat , - self .rel_pos_clip , self .rel_pos_clip )
171-
172- final_mat = distance_mat_clipped + self .rel_pos_clip
173-
174- rel_pos_embeddings = self .rel_pos_embeddings [final_mat ] # [T, T', pos_emb_dim]
166+ # 1D optimization: 2T-1 unique relative positions instead of T×T distance matrix.
167+ # Build [-(T-1), ..., -1, 0, 1, ..., T-1] directly.
168+ rel_pos = torch .arange (- (time_dim_size - 1 ), time_dim_size , device = input_tensor .device )
169+ indices = torch .clamp (rel_pos , - self .rel_pos_clip , self .rel_pos_clip ) + self .rel_pos_clip
170+ rel_pos_embeddings = self .rel_pos_embeddings [indices ].view (
171+ 1 , 2 * time_dim_size - 1 , self .pos_emb_dim
172+ ) # [1, T+T'-1, pos_emb_dim]
175173 else :
176174 rel_pos_embeddings = (
177175 self ._sinusoidal_pe (
@@ -207,9 +205,7 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
207205 q_with_bias_v ,
208206 rel_pos_embeddings .to (device = q_with_bias_v .device , dtype = q_with_bias_v .dtype ),
209207 ) # [B, #heads, T, T+T'-1] for learnable pos. emb.; sinusoidal path was noted as T+T'+1 (TODO confirm)
210- if not self .learnable_pos_emb :
211- attn_bd = self ._rel_shift_bhij (attn_bd , k_len = time_dim_size ) # [B, #heads, T, T']
212-
208+ attn_bd = self ._rel_shift_bhij (attn_bd , k_len = time_dim_size ) # [B, #heads, T, T']
213209 # We use attn_mask to add BD matrix to attention scores.
214210 #
215211 # Inside torch's SDPA the mask is added after regular scaling, so to get correct
0 commit comments