Skip to content

Commit 7ec53cc

Browse files
Judyxujj and Jingjing Xu authored
Improve relative positional encoding implementation (#95)
A more efficient implementation of the relative positional encoding.

Co-authored-by: Jingjing Xu <jxu@gw-02.apptek.local>
1 parent f347906 commit 7ec53cc

1 file changed

Lines changed: 8 additions & 12 deletions

File tree

i6_models/parts/conformer/mhsa_rel_pos.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,13 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
163163
k = key_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T', #heads, F']
164164

165165
if self.learnable_pos_emb:
166-
pos_seq_q = torch.arange(time_dim_size, device=input_tensor.device)
167-
pos_seq_k = torch.arange(time_dim_size, device=input_tensor.device)
168-
169-
distance_mat = pos_seq_k[None, :] - pos_seq_q[:, None]
170-
distance_mat_clipped = torch.clamp(distance_mat, -self.rel_pos_clip, self.rel_pos_clip)
171-
172-
final_mat = distance_mat_clipped + self.rel_pos_clip
173-
174-
rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', pos_emb_dim]
166+
# 1D optimization: 2T-1 unique relative positions instead of T×T distance matrix.
167+
# Build [-(T-1), ..., -1, 0, 1, ..., T-1] directly.
168+
rel_pos = torch.arange(-(time_dim_size - 1), time_dim_size, device=input_tensor.device)
169+
indices = torch.clamp(rel_pos, -self.rel_pos_clip, self.rel_pos_clip) + self.rel_pos_clip
170+
rel_pos_embeddings = self.rel_pos_embeddings[indices].view(
171+
1, 2 * time_dim_size - 1, self.pos_emb_dim
172+
) # [1, T+T'-1, pos_emb_dim]
175173
else:
176174
rel_pos_embeddings = (
177175
self._sinusoidal_pe(
@@ -207,9 +205,7 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
207205
q_with_bias_v,
208206
rel_pos_embeddings.to(device=q_with_bias_v.device, dtype=q_with_bias_v.dtype),
209207
) # [B, #heads, T, T+T'-1]
210-
if not self.learnable_pos_emb:
211-
attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T']
212-
208+
attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T']
213209
# We use attn_mask to add BD matrix to attention scores.
214210
#
215211
# Inside torch's SDPA the mask is added after regular scaling, so to get correct

0 commit comments

Comments
 (0)