1616
1717import torch
1818import torch .nn as nn
19- from torch .nn .attention .flex_attention import create_block_mask , flex_attention
19+ from torch .nn .attention .flex_attention import and_masks , create_block_mask , flex_attention
2020
2121flex_attention_compiled = torch .compile (flex_attention , dynamic = True )
2222
@@ -33,8 +33,8 @@ class TransformerEncoderConfig:
3333 ff_expansion : float = 4.0
3434 pre_block_norm : bool = True
3535 subsampling_factor : int = 4
36- # Attention mode — currently only "full" is supported .
37- # Future: "causal", " lookahead", "local", "sliding_window"
36+ # Attention mode: "full" (bidirectional) or "causal" (each token only attends to itself and earlier tokens) .
37+ # Future: "lookahead", "local", "sliding_window".
3838 attn_mode : str = "full"
3939
4040
@@ -47,6 +47,18 @@ def pad_mask(b, h, q_idx, kv_idx):
4747 return pad_mask
4848
4949
50+ def _make_causal_mod ():
51+ """Build a causal mask_mod: each query attends only to its own and earlier kv positions (q_idx >= kv_idx)."""
52+
53+ def causal (b , h , q_idx , kv_idx ): # b and h are accepted but unused: the rule is the same for every batch/head
54+ return q_idx >= kv_idx # ">=" keeps the diagonal, so a position may attend to itself
55+
56+ return causal
57+
58+
59+ _SUPPORTED_ATTN_MODES = ("full" , "causal" ) # attn_mode values accepted by TransformerEncoder.__init__; anything else raises ValueError
60+
61+
5062class FeatureStacking (nn .Module ):
5163 """Stacks consecutive input frames and projects to model dimension.
5264
@@ -174,7 +186,8 @@ class TransformerEncoder(nn.Module):
174186 such as Whisper or GPT-2 — required when loading pretrained weights from those
175187 checkpoints.
176188 subsampling_factor: Frame stacking factor for the pre-encoder.
177- attn_mode: Attention pattern — currently only "full" (bidirectional) is supported.
189+ attn_mode: Attention pattern — "full" (bidirectional, default) or "causal" (each token
190+ only attends to itself and earlier tokens).
178191 """
179192
180193 def __init__ (
@@ -194,8 +207,11 @@ def __init__(
194207 super ().__init__ ()
195208 if d_model % n_heads != 0 :
196209 raise ValueError (f"d_model ({ d_model } ) must be divisible by n_heads ({ n_heads } )." )
197- if attn_mode != "full" :
198- raise ValueError (f"attn_mode='{ attn_mode } ' is not yet supported. Currently only 'full' is available." )
210+ if attn_mode not in _SUPPORTED_ATTN_MODES :
211+ raise ValueError (
212+ f"attn_mode='{ attn_mode } ' is not yet supported. " f"Supported modes: { _SUPPORTED_ATTN_MODES } ."
213+ )
214+ self .attn_mode = attn_mode
199215
200216 cfg = TransformerEncoderConfig (
201217 feat_in = feat_in ,
@@ -231,7 +247,11 @@ def forward(self, audio_signal, length):
231247 x = self .embed_norm (x )
232248
233249 B , T , _ = x .shape
234- block_mask = create_block_mask (_make_padding_mod (length ), B = B , H = 1 , Q_LEN = T , KV_LEN = T , device = x .device )
250+ if self .attn_mode == "causal" :
251+ mask_mod = and_masks (_make_causal_mod (), _make_padding_mod (length ))
252+ else :
253+ mask_mod = _make_padding_mod (length )
254+ block_mask = create_block_mask (mask_mod , B = B , H = 1 , Q_LEN = T , KV_LEN = T , device = x .device )
235255
236256 for layer in self .layers :
237257 x = layer (x , block_mask = block_mask )
0 commit comments