Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b814a5f
WIP scaled dot product attention
dorian-K Nov 5, 2025
b69bbfd
wip
dorian-K Nov 7, 2025
1194018
wip
dorian-K Nov 7, 2025
9a2cae2
wip
dorian-K Nov 7, 2025
e5f636c
fix formatting
dorian-K Dec 19, 2025
a0627ec
more
dorian-K Dec 19, 2025
90316e9
more
dorian-K Dec 19, 2025
b91ea9d
fix tests
dorian-K Dec 19, 2025
54beb8f
more
dorian-K Dec 19, 2025
f9adbf3
more
dorian-K Dec 19, 2025
eada973
fix
dorian-K Dec 19, 2025
62f825e
fix pycharm
dorian-K Dec 19, 2025
6c8fd3f
Merge branch 'master' into doriank-sdpa
dorian-K Jan 9, 2026
c8bd06b
remove debug prints, update
dorian-K Jan 9, 2026
f08d16a
more
dorian-K Jan 9, 2026
2de26f8
use is_causal=True
dorian-K Jan 9, 2026
ecb79ed
add back a test
dorian-K Jan 9, 2026
a1b6c5d
more tests
dorian-K Jan 9, 2026
6eb0906
fix test
dorian-K Jan 9, 2026
c8b7c11
more tests
dorian-K Jan 16, 2026
b31e9b8
Merge branch 'master' into doriank-sdpa
dorian-K Jan 16, 2026
d6e25f9
fix formatting
dorian-K Jan 16, 2026
8480c3d
fix import
dorian-K Jan 16, 2026
6b35b43
remove some formatting only changes
dorian-K Jan 16, 2026
ef324be
more sdpa vs fallback tests
dorian-K Jan 16, 2026
4480aa5
Add benchmarks
dorian-K Jan 16, 2026
5469e1e
Merge branch 'master' into doriank-sdpa
dorian-K Jan 23, 2026
547d139
force flash att
dorian-K Jan 23, 2026
4962ced
also upgrade cross attention
dorian-K Jan 23, 2026
f91e5be
add back att_dropout_broadcast
dorian-K Jan 23, 2026
33a045b
convert some asserts to exceptions
dorian-K Jan 23, 2026
322112c
fix private member
dorian-K Jan 23, 2026
96470b4
oops
dorian-K Jan 23, 2026
6d3a585
reduce line length
dorian-K Jan 23, 2026
0162ef3
Make tensors contigous to enable fused kernels
dorian-K Jan 30, 2026
7da3b66
_embed_dim -> _feat_dim
dorian-K Jan 30, 2026
315fd86
pycharm
dorian-K Jan 30, 2026
a388048
Merge branch 'master' into doriank-sdpa
dorian-K Jan 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions returnn/frontend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,72 @@ def lstm(
"""
raise NotImplementedError

@classmethod
def scaled_dot_product_attention(
    cls,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    *,
    attention_mask: Optional[Tensor] = None,
    att_dropout: float = 0.0,
    att_dropout_broadcast: bool,
    v_feat_dim: Dim,
    qk_feat_dim: Dim,
    kv_spatial_dim: Dim,
    query_spatial_dim: Dim,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> Tensor:
    """
    Scaled dot-product attention.

    This is the generic default implementation, built from :func:`rf.matmul`, :func:`rf.softmax`
    and :func:`rf.dropout`.

    :param query: query tensor, must contain ``qk_feat_dim``
    :param key: key tensor, must contain ``qk_feat_dim`` and ``kv_spatial_dim``
    :param value: value tensor, must contain ``kv_spatial_dim``
    :param attention_mask: optional mask.
        If its dtype is ``"bool"``, positions where it is true get bias 0.0 and all other positions
        get bias ``-inf`` added to the energies (i.e. true means "attend").
        Any other dtype is assumed to be float-like and is added to the energies as-is.
    :param att_dropout: dropout for attention weights
    :param att_dropout_broadcast: if true, ``kv_spatial_dim`` is passed as ``axis`` to :func:`rf.dropout`,
        i.e. the dropout mask is broadcast over all other dims.
        normally not wanted. disabled by default since behavior version 19.
    :param v_feat_dim: Embedding dimension of value. Not used by this default implementation.
    :param qk_feat_dim: Embedding dimension of key and query. It is reduced in the query-key matmul.
    :param kv_spatial_dim: Spatial axis of key/value to attend over
    :param query_spatial_dim: Spatial axis of query. Not used by this default implementation.
    :param is_causal: Special case when the attention mask should be causal (e.g. for auto-regressive decoding).
        Allows for more efficient implementation in some backends.
        Cannot be combined with ``attention_mask`` here.
    :param scale: Scaling factor applied to the query prior to softmax.
        If None, ``qk_feat_dim.dimension ** -0.5`` is used.
    :return: attention output
    :raises NotImplementedError: if ``is_causal`` is set together with ``attention_mask``
    :raises ValueError: if ``is_causal`` is set but ``kv_spatial_dim`` is not a dim of ``query``
    """
    # Explicit rebind instead of ``*=``: the tensor passed in by the caller must never be modified in place.
    query = query * (qk_feat_dim.dimension**-0.5 if scale is None else scale)

    if is_causal:
        if attention_mask is not None:
            raise NotImplementedError("causal attention with attention_mask is not supported")
        if kv_spatial_dim not in query.dims_set:
            raise ValueError("query and key/value must share the same spatial axis for causal attention")
        # Causal masking via a history dim: its size tensor is ``range_over_dim(kv_spatial_dim) + 1``,
        # so the sequence mask of this dim limits each query position to the frames up to and including itself.
        hist_dim = Dim(rf.range_over_dim(kv_spatial_dim, device="cpu") + 1, name=f"{kv_spatial_dim.description}:kv")
        key, _ = rf.replace_dim(key, in_dim=kv_spatial_dim, out_dim=hist_dim)
        value, _ = rf.replace_dim(value, in_dim=kv_spatial_dim, out_dim=hist_dim)
        kv_spatial_dim = hist_dim

    # Turn the (optional) mask into an additive bias on the energies.
    attn_bias = None
    if attention_mask is not None:
        if attention_mask.dtype == "bool":
            attn_bias = rf.where(attention_mask, 0.0, float("-inf"))
        else:
            attn_bias = attention_mask  # assume float-like

    energy = rf.matmul(query, key, reduce=qk_feat_dim)  # [.., Q_spatial, K_spatial]
    if attn_bias is not None:
        energy = energy + attn_bias
    att_weights = rf.softmax(energy, axis=kv_spatial_dim)  # [.., Q_spatial, K_spatial]
    # ``axis`` is either False (no broadcasting) or kv_spatial_dim (broadcast over all other dims).
    att_weights = rf.dropout(att_weights, att_dropout, axis=att_dropout_broadcast and kv_spatial_dim)
    # no need for mask because softmax already sets those weights to zero
    att = rf.matmul(att_weights, value, reduce=kv_spatial_dim, use_mask=False)
    # Carry over the feature dim of the values to the output, if it survived the matmul.
    if value.feature_dim in att.dims:
        att.feature_dim = value.feature_dim
    return att

# For eager-based backends, this is a reasonable default implementation and type.
TensorArrayType = List[Tensor]

Expand Down
203 changes: 178 additions & 25 deletions returnn/frontend/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


__all__ = [
"scaled_dot_product_attention",
"dot_attention",
"SelfAttentionBase",
"SelfAttention",
Expand All @@ -28,6 +29,94 @@
]


def scaled_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    *,
    attention_mask: Optional[Tensor] = None,
    att_dropout: float = 0.0,
    att_dropout_broadcast: Optional[bool] = None,
    v_feat_dim: Dim,
    qk_feat_dim: Dim,
    kv_spatial_dim: Dim,
    query_spatial_dim: Dim,
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    """
    Scaled dot-product attention.

    Thin frontend entry point: resolves the default for ``att_dropout_broadcast``
    and then forwards all arguments unchanged to ``scaled_dot_product_attention`` of the backend of ``query``.

    :param query:
    :param key:
    :param value:
    :param attention_mask: passed on to the backend as-is
    :param att_dropout: dropout for attention weights
    :param att_dropout_broadcast: whether to broadcast the dropout mask over all but the attended axis.
        normally not wanted. disabled by default since behavior version 19.
        If None, the value of ``_att_dropout_broadcast_default()`` is used.
    :param v_feat_dim: Embedding dimension of value
    :param qk_feat_dim: Embedding dimension of key (and query)
    :param kv_spatial_dim: Spatial axis of key/value to attend over
    :param query_spatial_dim: Spatial axis of query
    :param is_causal: Special case when the attention mask should be causal (e.g. for auto-regressive decoding).
        Allows for more efficient implementation in some backends.
    :param scale: Scaling factor applied prior to softmax
    :return: attention output
    """
    dropout_broadcast = _att_dropout_broadcast_default() if att_dropout_broadcast is None else att_dropout_broadcast
    # noinspection PyProtectedMember
    backend = query._raw_backend
    return backend.scaled_dot_product_attention(
        query,
        key,
        value,
        attention_mask=attention_mask,
        att_dropout=att_dropout,
        att_dropout_broadcast=dropout_broadcast,
        v_feat_dim=v_feat_dim,
        qk_feat_dim=qk_feat_dim,
        kv_spatial_dim=kv_spatial_dim,
        query_spatial_dim=query_spatial_dim,
        is_causal=is_causal,
        scale=scale,
    )


def _infer_att_dims(
    query: Tensor, keys: Tensor, values: Tensor, *, qk_feat_dim: Dim, kv_spatial_dim: Dim
) -> Tuple[Tensor, Dim, Dim, bool]:
    """
    Infer the dims needed for :func:`scaled_dot_product_attention` from the shapes of query/keys/values.

    :param query: query tensor. If it has no dim besides those shared with ``keys``,
        a dummy spatial dim of size 1 is added to it.
    :param keys: key tensor, must contain ``kv_spatial_dim``
    :param values: value tensor
    :param qk_feat_dim: feature dim of query and keys
    :param kv_spatial_dim: spatial dim of keys/values to attend over
    :return: (query, v_feat_dim, query_spatial_dim, added_dummy_spatial_dim_to_query).
        ``query`` is the input query, or the expanded one if the dummy dim was added.
        If the flag is true, the caller is expected to squeeze ``query_spatial_dim``
        out of the attention output again.
    :raises ValueError: if ``kv_spatial_dim`` is not in ``keys``,
        if ``query`` has more than one non-batch dim,
        or if the value feature dim cannot be determined
    """
    if kv_spatial_dim not in keys.dims_set:
        raise ValueError(f"kv_spatial_dim {kv_spatial_dim} not in keys.dims {keys.dims}")

    # infer query spatial dim, necessary for pytorch backend.
    # Every dim which keys has as well (except the kv spatial dim itself) counts as batch-like here
    # (this also covers qk_feat_dim); whatever remains in query is a candidate for its spatial dim.
    query_non_batch_dims = query.remaining_dims(keys.dims_set - {kv_spatial_dim})
    if len(query_non_batch_dims) > 1:
        raise ValueError(
            "Query must have at most one non-batch dim (the spatial dimension), "
            f"got {query.dims}, keys.dims={keys.dims}"
        )
    added_dummy_spatial_dim = len(query_non_batch_dims) == 0
    if added_dummy_spatial_dim:
        # E.g. a single query vector: add a dummy spatial dim of size 1.
        query_spatial = Dim(1, name="dot_att_query_spatial_dummy")
        query = rf.expand_dim(query, dim=query_spatial)
    else:
        query_spatial = query_non_batch_dims[0]

    # infer the feature dim of the values (v_feat_dim)
    v_feat_dim = values.feature_dim
    if v_feat_dim is None:
        if qk_feat_dim in values.dims_set:
            # Values use the same feature dim as query/keys.
            v_feat_dim = qk_feat_dim
        else:
            # Otherwise it must be the one dim which values has but keys does not.
            possible_feat_dims = values.dims_set - keys.dims_set
            if len(possible_feat_dims) != 1:
                raise ValueError(f"Cannot infer v_feat_dim from values.dims={values.dims}, keys.dims={keys.dims}")
            v_feat_dim = next(iter(possible_feat_dims))

    return query, v_feat_dim, query_spatial, added_dummy_spatial_dim


def dot_attention(
query: Tensor,
keys: Tensor,
Expand Down Expand Up @@ -55,17 +144,25 @@ def dot_attention(
normally not wanted. disabled by default since behavior version 19.
:return: like values but with axis removed, and maybe any additional axes from query
"""
query *= key_dim.dimension**-0.5
energy = rf.matmul(query, keys, reduce=key_dim)
att_weights = rf.softmax(energy, axis=axis)
if att_dropout_broadcast is None:
att_dropout_broadcast = _att_dropout_broadcast_default()
att_weights = rf.dropout(att_weights, att_dropout, axis=att_dropout_broadcast and axis)
# Masking not needed because softmax should already have masked,
# so we have 0.0 att weights for padded frames.
att = rf.matmul(att_weights, values, reduce=axis, use_mask=False)
if values.feature_dim in att.dims:
att.feature_dim = values.feature_dim
query, v_feat_dim, query_spatial, added_dummy_spat_dim_to_query = _infer_att_dims(
query, keys, values, qk_feat_dim=key_dim, kv_spatial_dim=axis
)

att = scaled_dot_product_attention(
query,
keys,
values,
att_dropout=att_dropout,
att_dropout_broadcast=att_dropout_broadcast,
v_feat_dim=v_feat_dim,
qk_feat_dim=key_dim,
kv_spatial_dim=axis,
query_spatial_dim=query_spatial,
is_causal=False,
)
if added_dummy_spat_dim_to_query:
att = rf.squeeze(att, axis=query_spatial)

return att


Expand All @@ -75,7 +172,7 @@ class SelfAttentionBase(rf.Module):
Shared base class for (non-causal) self attention (:class:`SelfAttention`)
and causal self attention (:class:`CausalSelfAttention`).

It uses :func:`dot_attention` for multi-headed dot-attention.
It uses :func:`scaled_dot_product_attention` for multi-headed dot-attention.
"""

def __init__(
Expand Down Expand Up @@ -156,15 +253,25 @@ def forward_qkv(self, source: Tensor) -> Tuple[Tensor, Tensor, Tensor]:

def attention(self, q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) -> Tensor:
"""apply attention"""
att = dot_attention(
q,
query, v_feat_dim, query_spatial, added_dummy_spat_dim_to_query = _infer_att_dims(
q, k, v, qk_feat_dim=self.key_dim_per_head, kv_spatial_dim=kv_axis
)

att = scaled_dot_product_attention(
query,
k,
v,
key_dim=self.key_dim_per_head,
axis=kv_axis,
att_dropout=self.att_dropout,
att_dropout_broadcast=self.att_dropout_broadcast,
v_feat_dim=v_feat_dim,
qk_feat_dim=self.key_dim_per_head,
kv_spatial_dim=kv_axis,
query_spatial_dim=query_spatial,
is_causal=False,
)
if added_dummy_spat_dim_to_query:
att = rf.squeeze(att, axis=query_spatial)

output, _ = rf.merge_dims(att, dims=(self.num_heads, self.value_dim_per_head), out_dim=self.value_dim_total)
if self.proj:
output = self.proj(output)
Expand Down Expand Up @@ -219,6 +326,34 @@ def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> CausalSelfAtten
accum_axis=expand_dim,
)

def attention(self, q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) -> Tensor:
    """
    Apply (causal) attention, merge the heads and apply the output projection if there is one.

    :param q: query
    :param k: keys, must contain ``kv_axis``
    :param v: values
    :param kv_axis: spatial dim of keys/values to attend over
    :return: attention output with ``self.num_heads`` and ``self.value_dim_per_head``
        merged into ``self.value_dim_total``, after ``self.proj`` if that is set
    """
    query, v_feat_dim, query_spatial, added_dummy_spat_dim_to_query = _infer_att_dims(
        q, k, v, qk_feat_dim=self.key_dim_per_head, kv_spatial_dim=kv_axis
    )

    # Query and keys share the spatial axis, i.e. we are not in single-step mode.
    # Compare by equality, not identity: dims are matched by equality elsewhere as well (e.g. via ``dims_set``),
    # and an equal but non-identical Dim object must not silently turn off the causal masking.
    is_causal = query_spatial == kv_axis

    att = scaled_dot_product_attention(
        query,
        k,
        v,
        att_dropout=self.att_dropout,
        att_dropout_broadcast=self.att_dropout_broadcast,
        v_feat_dim=v_feat_dim,
        qk_feat_dim=self.key_dim_per_head,
        kv_spatial_dim=kv_axis,
        query_spatial_dim=query_spatial,
        is_causal=is_causal,
    )
    if added_dummy_spat_dim_to_query:
        # Remove the dummy query spatial dim again which _infer_att_dims added.
        att = rf.squeeze(att, axis=query_spatial)

    output, _ = rf.merge_dims(att, dims=(self.num_heads, self.value_dim_per_head), out_dim=self.value_dim_total)
    if self.proj:
        output = self.proj(output)
    return output


def _causal_self_att_step(
k: Tensor,
Expand All @@ -227,6 +362,7 @@ def _causal_self_att_step(
axis: Dim,
state: Optional[CausalSelfAttentionState],
self: rf.Module,
with_causal_masking: bool = False,
) -> Tuple[Tensor, Tensor, Dim, CausalSelfAttentionState]:
new_state = CausalSelfAttentionState()
if axis == single_step_dim:
Expand All @@ -247,9 +383,14 @@ def _causal_self_att_step(
new_state.v_accum = v
new_state.accum_axis = axis
# See CumConcatLayer and https://github.com/rwth-i6/returnn/issues/391 for the idea.
hist_dim = Dim(rf.range_over_dim(axis, device="cpu") + 1, name=f"{axis.description}:kv")
k, _ = rf.replace_dim(k, in_dim=axis, out_dim=hist_dim)
v, _ = rf.replace_dim(v, in_dim=axis, out_dim=hist_dim)
if with_causal_masking:
# no longer needed if is_causal=True in scaled_dot_product_attention
hist_dim = Dim(rf.range_over_dim(axis, device="cpu") + 1, name=f"{axis.description}:kv")
k, _ = rf.replace_dim(k, in_dim=axis, out_dim=hist_dim)
v, _ = rf.replace_dim(v, in_dim=axis, out_dim=hist_dim)
else:
hist_dim = axis

return k, v, hist_dim, new_state


Expand Down Expand Up @@ -605,7 +746,9 @@ def __call__(
) -> Tuple[Tensor, CausalSelfAttentionState]:
"""forward"""
q, k, v = self.forward_qkv(source)
k, v, hist_dim, new_state = _causal_self_att_step(k, v, axis=axis, state=state, self=self)
k, v, hist_dim, new_state = _causal_self_att_step(
k, v, axis=axis, state=state, self=self, with_causal_masking=True
)

if self.learned_pos_emb is not None:
pos_emb, pos_emb_spatial_dim = self.learned_pos_emb(query_spatial_dim=axis, key_value_spatial_dim=hist_dim)
Expand Down Expand Up @@ -656,7 +799,7 @@ class CrossAttention(rf.Module):
"""
Cross attention

It uses :func:`dot_attention` for multi-headed dot-attention.
It uses :func:`scaled_dot_product_attention` for multi-headed dot-attention.
"""

def __init__(
Expand Down Expand Up @@ -765,15 +908,25 @@ def __call__(self, q: Tensor, encoder: rf.State) -> Tensor:

def attention(self, q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) -> Tensor:
"""apply attention"""
att = dot_attention(
q,
query, v_feat_dim, query_spatial, added_dummy_spat_dim_to_query = _infer_att_dims(
q, k, v, qk_feat_dim=self.key_dim_per_head, kv_spatial_dim=kv_axis
)

att = scaled_dot_product_attention(
query,
k,
v,
key_dim=self.key_dim_per_head,
axis=kv_axis,
att_dropout=self.att_dropout,
att_dropout_broadcast=self.att_dropout_broadcast,
v_feat_dim=v_feat_dim,
qk_feat_dim=self.key_dim_per_head,
kv_spatial_dim=kv_axis,
query_spatial_dim=query_spatial,
is_causal=False,
)
if added_dummy_spat_dim_to_query:
att = rf.squeeze(att, axis=query_spatial)

output, _ = rf.merge_dims(att, dims=(self.num_heads, self.value_dim_per_head), out_dim=self.value_dim_total)
if self.proj:
output = self.proj(output)
Expand Down
Loading