Skip to content

Commit 2ea3e0f

Browse files
authored
Causal transformer (#15730)
* Add full causal support Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com> * fix black/isort Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com> --------- Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
1 parent e6df25c commit 2ea3e0f

2 files changed

Lines changed: 65 additions & 8 deletions

File tree

nemo/collections/asr/modules/transformer_encoder.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import torch
1818
import torch.nn as nn
19-
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
19+
from torch.nn.attention.flex_attention import and_masks, create_block_mask, flex_attention
2020

2121
flex_attention_compiled = torch.compile(flex_attention, dynamic=True)
2222

@@ -33,8 +33,8 @@ class TransformerEncoderConfig:
3333
ff_expansion: float = 4.0
3434
pre_block_norm: bool = True
3535
subsampling_factor: int = 4
36-
# Attention mode — currently only "full" is supported.
37-
# Future: "causal", "lookahead", "local", "sliding_window"
36+
# Attention mode: "full" (bidirectional) or "causal" (each token only attends to itself and earlier tokens).
37+
# Future: "lookahead", "local", "sliding_window".
3838
attn_mode: str = "full"
3939

4040

@@ -47,6 +47,18 @@ def pad_mask(b, h, q_idx, kv_idx):
4747
return pad_mask
4848

4949

50+
def _make_causal_mod():
51+
"""Strictly causal — each query only attends to its own and earlier kv positions."""
52+
53+
def causal(b, h, q_idx, kv_idx):
54+
return q_idx >= kv_idx
55+
56+
return causal
57+
58+
59+
_SUPPORTED_ATTN_MODES = ("full", "causal")
60+
61+
5062
class FeatureStacking(nn.Module):
5163
"""Stacks consecutive input frames and projects to model dimension.
5264
@@ -174,7 +186,8 @@ class TransformerEncoder(nn.Module):
174186
such as Whisper or GPT-2 — required when loading pretrained weights from those
175187
checkpoints.
176188
subsampling_factor: Frame stacking factor for the pre-encoder.
177-
attn_mode: Attention pattern — currently only "full" (bidirectional) is supported.
189+
attn_mode: Attention pattern — "full" (bidirectional, default) or "causal" (each token
190+
only attends to itself and earlier tokens).
178191
"""
179192

180193
def __init__(
@@ -194,8 +207,11 @@ def __init__(
194207
super().__init__()
195208
if d_model % n_heads != 0:
196209
raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads}).")
197-
if attn_mode != "full":
198-
raise ValueError(f"attn_mode='{attn_mode}' is not yet supported. Currently only 'full' is available.")
210+
if attn_mode not in _SUPPORTED_ATTN_MODES:
211+
raise ValueError(
212+
f"attn_mode='{attn_mode}' is not yet supported. " f"Supported modes: {_SUPPORTED_ATTN_MODES}."
213+
)
214+
self.attn_mode = attn_mode
199215

200216
cfg = TransformerEncoderConfig(
201217
feat_in=feat_in,
@@ -231,7 +247,11 @@ def forward(self, audio_signal, length):
231247
x = self.embed_norm(x)
232248

233249
B, T, _ = x.shape
234-
block_mask = create_block_mask(_make_padding_mod(length), B=B, H=1, Q_LEN=T, KV_LEN=T, device=x.device)
250+
if self.attn_mode == "causal":
251+
mask_mod = and_masks(_make_causal_mod(), _make_padding_mod(length))
252+
else:
253+
mask_mod = _make_padding_mod(length)
254+
block_mask = create_block_mask(mask_mod, B=B, H=1, Q_LEN=T, KV_LEN=T, device=x.device)
235255

236256
for layer in self.layers:
237257
x = layer(x, block_mask=block_mask)

tests/collections/asr/test_transformer_encoder.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,44 @@ def test_model_creation_without_qk_norm(self):
124124
@pytest.mark.unit
125125
def test_invalid_attn_mode(self):
126126
with pytest.raises(ValueError, match="not yet supported"):
127-
TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, attn_mode="causal")
127+
TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, attn_mode="sliding_window")
128+
129+
@pytest.mark.unit
130+
def test_causal_forward_cpu(self):
131+
model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, attn_mode="causal")
132+
model.eval()
133+
134+
x = torch.randn(2, 80, 400)
135+
lengths = torch.tensor([400, 300])
136+
137+
with torch.no_grad():
138+
out, out_lengths = model(x, lengths)
139+
140+
assert out.shape == (2, 64, 100)
141+
assert out_lengths.tolist() == [100, 75]
142+
assert not torch.isnan(out).any()
143+
144+
@pytest.mark.unit
145+
def test_causal_future_does_not_affect_past(self):
146+
"""Output at position t must be invariant to changes at positions > t."""
147+
model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, attn_mode="causal")
148+
model.eval()
149+
150+
B, C, T = 1, 80, 400
151+
x_a = torch.randn(B, C, T)
152+
x_b = x_a.clone()
153+
# Perturb only the second half of frames.
154+
x_b[:, :, T // 2 :] = torch.randn(B, C, T - T // 2)
155+
lengths = torch.tensor([T])
156+
157+
with torch.no_grad():
158+
out_a, _ = model(x_a, lengths)
159+
out_b, _ = model(x_b, lengths)
160+
161+
# Output frames covering only past + present should be identical.
162+
# First half of *output* frames corresponds to first half of input frames after subsampling.
163+
safe_t = (T // 2) // model.pre_encode.subsampling_factor
164+
assert torch.allclose(out_a[:, :, :safe_t], out_b[:, :, :safe_t], atol=1e-5)
128165

129166
@pytest.mark.unit
130167
def test_forward_cpu(self):

0 commit comments

Comments
 (0)