@classmethod
def scaled_dot_product_attention(
    cls,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    *,
    attention_mask: Optional[Tensor] = None,
    att_dropout: float = 0.0,
    att_dropout_broadcast: bool,
    v_feat_dim: Dim,
    qk_feat_dim: Dim,
    kv_spatial_dim: Dim,
    query_spatial_dim: Dim,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> Tensor:
    """
    Scaled dot-product attention.

    Generic default implementation, built only from ``rf`` functions.
    The query is scaled, multiplied with the key (reducing ``qk_feat_dim``),
    an optional bias is added, softmax is taken over ``kv_spatial_dim``,
    dropout is applied on the attention weights,
    and the weights are multiplied with the value (reducing ``kv_spatial_dim``).

    :param query: must contain ``qk_feat_dim``
    :param key: must contain ``qk_feat_dim`` and ``kv_spatial_dim``
    :param value: must contain ``kv_spatial_dim``
    :param attention_mask: optional. If the dtype is bool, True means "attend" (bias 0.0)
        and False means "do not attend" (bias -inf).
        Any other dtype is assumed to be float-like and is added to the energies as-is.
    :param att_dropout: dropout for attention weights
    :param att_dropout_broadcast: whether to broadcast the dropout over all but ``kv_spatial_dim``.
        normally not wanted.
        No default here, the caller has to resolve it.
    :param v_feat_dim: Embedding dimension of value. Not used by this default implementation.
    :param qk_feat_dim: Embedding dimension of key and query
    :param kv_spatial_dim: Spatial axis of key/value to attend over
    :param query_spatial_dim: Spatial axis of query. Not used by this default implementation.
    :param is_causal: Special case when the attention mask should be causal (e.g. for auto-regressive decoding).
        Allows for more efficient implementation in some backends.
        Requires that ``kv_spatial_dim`` is also a dim of ``query``,
        and cannot be combined with ``attention_mask``.
    :param scale: Scaling factor applied prior to softmax.
        If None, ``qk_feat_dim.dimension ** -0.5`` is used.
    :return: attention output
    :raises NotImplementedError: ``is_causal`` together with ``attention_mask``
    :raises ValueError: ``is_causal`` but ``kv_spatial_dim`` is not a dim of ``query``
    """
    # Validate the arguments first, before any computation is done.
    if is_causal:
        if attention_mask is not None:
            raise NotImplementedError("causal attention with attention_mask is not supported")
        if kv_spatial_dim not in query.dims_set:
            raise ValueError("query and key/value must share the same spatial axis for causal attention")

    query *= qk_feat_dim.dimension**-0.5 if scale is None else scale

    if is_causal:
        # The size of hist_dim depends on the position in the original axis (position + 1).
        # Key and value get this new dim instead of the original axis, while query keeps the original axis.
        hist_dim = Dim(rf.range_over_dim(kv_spatial_dim, device="cpu") + 1, name=f"{kv_spatial_dim.description}:kv")
        key, _ = rf.replace_dim(key, in_dim=kv_spatial_dim, out_dim=hist_dim)
        value, _ = rf.replace_dim(value, in_dim=kv_spatial_dim, out_dim=hist_dim)
        kv_spatial_dim = hist_dim

    attn_bias = None
    if attention_mask is not None:
        if attention_mask.dtype == "bool":
            attn_bias = rf.where(attention_mask, 0.0, float("-inf"))
        else:
            attn_bias = attention_mask  # assume float-like

    energy = rf.matmul(query, key, reduce=qk_feat_dim)  # [.., Q_spatial, K_spatial]
    if attn_bias is not None:
        energy = energy + attn_bias
    att_weights = rf.softmax(energy, axis=kv_spatial_dim)  # [.., Q_spatial, K_spatial]
    att_weights = rf.dropout(att_weights, att_dropout, axis=att_dropout_broadcast and kv_spatial_dim)
    # no need for mask because softmax already sets those weights to zero
    att = rf.matmul(att_weights, value, reduce=kv_spatial_dim, use_mask=False)
    if value.feature_dim in att.dims:
        att.feature_dim = value.feature_dim
    return att

# For eager-based backends, this is a reasonable default implementation and type.
def scaled_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    *,
    attention_mask: Optional[Tensor] = None,
    att_dropout: float = 0.0,
    att_dropout_broadcast: Optional[bool] = None,
    v_feat_dim: Dim,
    qk_feat_dim: Dim,
    kv_spatial_dim: Dim,
    query_spatial_dim: Dim,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> Tensor:
    """
    Scaled dot-product attention.

    This resolves the default for ``att_dropout_broadcast``
    and then dispatches to the backend of ``query``.

    :param query:
    :param key:
    :param value:
    :param attention_mask: optional, passed on to the backend as-is
    :param att_dropout: dropout for attention weights
    :param att_dropout_broadcast: whether to broadcast the dropout over all but ``kv_spatial_dim``.
        normally not wanted. disabled by default since behavior version 19.
        If None, the behavior-version-dependent default is used.
    :param v_feat_dim: Embedding dimension of value
    :param qk_feat_dim: Embedding dimension of key (and query)
    :param kv_spatial_dim: Spatial axis of key/value to attend over
    :param query_spatial_dim: Spatial axis of query
    :param is_causal: Special case when the attention mask should be causal (e.g. for auto-regressive decoding).
        Allows for more efficient implementation in some backends.
    :param scale: Scaling factor applied prior to softmax
    :return: attention output
    """
    if att_dropout_broadcast is None:
        att_dropout_broadcast = _att_dropout_broadcast_default()
    # noinspection PyProtectedMember
    att = query._raw_backend.scaled_dot_product_attention(
        query,
        key,
        value,
        attention_mask=attention_mask,
        att_dropout=att_dropout,
        att_dropout_broadcast=att_dropout_broadcast,
        v_feat_dim=v_feat_dim,
        qk_feat_dim=qk_feat_dim,
        kv_spatial_dim=kv_spatial_dim,
        query_spatial_dim=query_spatial_dim,
        is_causal=is_causal,
        scale=scale,
    )
    return att


def _infer_att_dims(
    query: Tensor, keys: Tensor, values: Tensor, *, qk_feat_dim: Dim, kv_spatial_dim: Dim
) -> Tuple[Tensor, Dim, Dim, bool]:
    """
    Infers the dims which :func:`scaled_dot_product_attention` needs explicitly
    but which the older attention API (query, keys, values, key dim, axis) does not provide.

    :param query:
    :param keys: must contain ``kv_spatial_dim``
    :param values:
    :param qk_feat_dim: feature dim of query and keys
    :param kv_spatial_dim: spatial axis of keys/values to attend over
    :return: (query, v_feat_dim, query_spatial_dim, added_dummy).
        If ``query`` has no spatial dim of its own, a dummy dim of size 1 is appended to it,
        and ``added_dummy`` is True. The caller should then squeeze that dim from the attention output.
    :raises ValueError: ``kv_spatial_dim`` is missing in ``keys``,
        ``query`` has more than one dim which is not shared with ``keys``,
        or the feature dim of ``values`` is ambiguous
    """
    if kv_spatial_dim not in keys.dims_set:
        raise ValueError(f"kv_spatial_dim {kv_spatial_dim} not in keys.dims {keys.dims}")

    # infer query spatial dim, necessary for pytorch backend
    query_non_batch_dims = query.remaining_dims(keys.dims_set - {kv_spatial_dim})
    added_dummy_spatial_dim = len(query_non_batch_dims) == 0
    if added_dummy_spatial_dim:
        query_spatial = Dim(1, name="dot_att_query_spatial_dummy")
        query = rf.expand_dim(query, dim=query_spatial)
    else:
        if len(query_non_batch_dims) != 1:
            raise ValueError(
                "Query vector must have exactly one non-batch dim (the spatial dimension),"
                f" got {query.dims}, keys.dims={keys.dims}"
            )
        query_spatial = query_non_batch_dims[0]

    # infer the feature dim of the values (v_feat_dim)
    v_feat_dim = values.feature_dim
    if v_feat_dim is None:
        if qk_feat_dim in values.dims_set:
            v_feat_dim = qk_feat_dim
        else:
            # It must be the single dim which the values have but the keys do not have.
            possible_feat_dims = values.dims_set - keys.dims_set
            if len(possible_feat_dims) != 1:
                raise ValueError(f"Cannot infer v_feat_dim from values.dims={values.dims}, keys.dims={keys.dims}")
            v_feat_dim = next(iter(possible_feat_dims))

    return query, v_feat_dim, query_spatial, added_dummy_spatial_dim
disabled by default since behavior version 19. :return: like values but with axis removed, and maybe any additional axes from query """ - query *= key_dim.dimension**-0.5 - energy = rf.matmul(query, keys, reduce=key_dim) - att_weights = rf.softmax(energy, axis=axis) - if att_dropout_broadcast is None: - att_dropout_broadcast = _att_dropout_broadcast_default() - att_weights = rf.dropout(att_weights, att_dropout, axis=att_dropout_broadcast and axis) - # Masking not needed because softmax should already have masked, - # so we have 0.0 att weights for padded frames. - att = rf.matmul(att_weights, values, reduce=axis, use_mask=False) - if values.feature_dim in att.dims: - att.feature_dim = values.feature_dim + query, v_feat_dim, query_spatial, added_dummy_spat_dim_to_query = _infer_att_dims( + query, keys, values, qk_feat_dim=key_dim, kv_spatial_dim=axis + ) + + att = scaled_dot_product_attention( + query, + keys, + values, + att_dropout=att_dropout, + att_dropout_broadcast=att_dropout_broadcast, + v_feat_dim=v_feat_dim, + qk_feat_dim=key_dim, + kv_spatial_dim=axis, + query_spatial_dim=query_spatial, + is_causal=False, + ) + if added_dummy_spat_dim_to_query: + att = rf.squeeze(att, axis=query_spatial) + return att @@ -75,7 +172,7 @@ class SelfAttentionBase(rf.Module): Shared base class for (non-causal) self attention (:class:`SelfAttention`) and causal self attention (:class:`CausalSelfAttention`). - It uses :func:`dot_attention` for multi-headed dot-attention. + It uses :func:`scaled_dot_product_attention` for multi-headed dot-attention. 
def attention(self, q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) -> Tensor:
    """
    Apply attention, with causal masking when operating on a whole sequence.

    :param q: query
    :param k: key, must contain ``kv_axis``
    :param v: value
    :param kv_axis: spatial axis of key/value to attend over
    :return: attention output, with heads and value-dim-per-head merged into ``value_dim_total``,
        and with the output projection applied if there is one
    """
    query, v_feat_dim, query_spatial, added_dummy_spat_dim_to_query = _infer_att_dims(
        q, k, v, qk_feat_dim=self.key_dim_per_head, kv_spatial_dim=kv_axis
    )

    # Causal masking is needed iff query and key/value share the spatial axis,
    # i.e. we are not in single-step mode.
    # Compare by equality and not by identity:
    # query_spatial is taken from the dims of the query tensor, kv_axis is passed by the caller,
    # and an equal but not identical Dim object must not silently disable the causal masking.
    is_causal = query_spatial == kv_axis

    att = scaled_dot_product_attention(
        query,
        k,
        v,
        att_dropout=self.att_dropout,
        att_dropout_broadcast=self.att_dropout_broadcast,
        v_feat_dim=v_feat_dim,
        qk_feat_dim=self.key_dim_per_head,
        kv_spatial_dim=kv_axis,
        query_spatial_dim=query_spatial,
        is_causal=is_causal,
    )
    if added_dummy_spat_dim_to_query:
        # remove the dummy query spatial dim again
        att = rf.squeeze(att, axis=query_spatial)

    output, _ = rf.merge_dims(att, dims=(self.num_heads, self.value_dim_per_head), out_dim=self.value_dim_total)
    if self.proj:
        output = self.proj(output)
    return output
- hist_dim = Dim(rf.range_over_dim(axis, device="cpu") + 1, name=f"{axis.description}:kv") - k, _ = rf.replace_dim(k, in_dim=axis, out_dim=hist_dim) - v, _ = rf.replace_dim(v, in_dim=axis, out_dim=hist_dim) + if with_causal_masking: + # no longer needed if is_causal=True in scaled_dot_product_attention + hist_dim = Dim(rf.range_over_dim(axis, device="cpu") + 1, name=f"{axis.description}:kv") + k, _ = rf.replace_dim(k, in_dim=axis, out_dim=hist_dim) + v, _ = rf.replace_dim(v, in_dim=axis, out_dim=hist_dim) + else: + hist_dim = axis + return k, v, hist_dim, new_state @@ -605,7 +746,9 @@ def __call__( ) -> Tuple[Tensor, CausalSelfAttentionState]: """forward""" q, k, v = self.forward_qkv(source) - k, v, hist_dim, new_state = _causal_self_att_step(k, v, axis=axis, state=state, self=self) + k, v, hist_dim, new_state = _causal_self_att_step( + k, v, axis=axis, state=state, self=self, with_causal_masking=True + ) if self.learned_pos_emb is not None: pos_emb, pos_emb_spatial_dim = self.learned_pos_emb(query_spatial_dim=axis, key_value_spatial_dim=hist_dim) @@ -656,7 +799,7 @@ class CrossAttention(rf.Module): """ Cross attention - It uses :func:`dot_attention` for multi-headed dot-attention. + It uses :func:`scaled_dot_product_attention` for multi-headed dot-attention. 
# If True, always use the generic implementation of the base class instead of the fused PyTorch kernel.
# The tests toggle this to compare both code paths.
ForceFallbackSDPA = False

@classmethod
def scaled_dot_product_attention(
    cls,
    query: _TT,
    key: _TT,
    value: _TT,
    *,
    attention_mask: Optional[_TT] = None,
    att_dropout: float = 0.0,
    att_dropout_broadcast: bool,
    v_feat_dim: Dim,
    qk_feat_dim: Dim,
    kv_spatial_dim: Dim,
    query_spatial_dim: Dim,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> _TT:
    """
    Scaled dot-product attention, via :func:`torch.nn.functional.scaled_dot_product_attention`.

    Falls back to the generic implementation of the base class if ``ForceFallbackSDPA`` is set,
    if ``att_dropout_broadcast`` is enabled,
    or if the dyn size of ``kv_spatial_dim`` depends on dims which are not in ``key``.

    :param query: [batch..., query_spatial_dim, qk_feat_dim], in any order
    :param key: [batch..., kv_spatial_dim, qk_feat_dim], in any order
    :param value: [batch..., kv_spatial_dim, v_feat_dim], in any order
    :param attention_mask: optional. Must contain ``query_spatial_dim`` and ``kv_spatial_dim``.
        All other dims must be a subset of the batch dims, missing batch dims are broadcasted.
        Not implemented together with a dynamic ``kv_spatial_dim`` or with ``is_causal``.
    :param att_dropout: dropout for attention weights. Only applied if the train flag is enabled for rf.dropout.
    :param att_dropout_broadcast: if True, the generic fallback implementation is used
    :param v_feat_dim: Embedding dimension of value
    :param qk_feat_dim: Embedding dimension of key and query
    :param kv_spatial_dim: Spatial axis of key/value to attend over
    :param query_spatial_dim: Spatial axis of query
    :param is_causal: causal attention. Requires ``kv_spatial_dim == query_spatial_dim``.
    :param scale: Scaling factor applied prior to softmax. If None, the PyTorch default is used.
    :return: attention output, [batch..., query_spatial_dim, v_feat_dim]
    :raises NotImplementedError: unsupported combination with ``attention_mask``, see above
    :raises ValueError: ``is_causal`` with different spatial dims, or unexpected dims in ``attention_mask``
    """
    kv_dyn_size = kv_spatial_dim.dyn_size_ext
    if (
        TorchBackend.ForceFallbackSDPA
        or (  # kv_spatial_dim and query dims are co-dependent, legacy causalselfattention code relies on this...
            kv_dyn_size is not None and any(d not in key.dims_set for d in kv_dyn_size.dims)
        )
        or att_dropout_broadcast
    ):
        # the legacy CausalSelfAttention implementation has a Dimension in the key which depends
        # on another Dimension that only exists in the query matrix.
        # Therefore the key/value matrices are only well-defined once they are multiplied with the query matrix...
        # In this case, we just fall back to the old implementation to not break old setups.
        return super().scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            att_dropout=att_dropout,
            att_dropout_broadcast=att_dropout_broadcast,
            v_feat_dim=v_feat_dim,
            qk_feat_dim=qk_feat_dim,
            kv_spatial_dim=kv_spatial_dim,
            query_spatial_dim=query_spatial_dim,
            is_causal=is_causal,
            scale=scale,
        )

    if is_causal:
        if attention_mask is not None:
            raise NotImplementedError("causal attention with attention_mask is not supported")
        # Checked for static and dynamic dims alike,
        # so that this code path does not return a result where the fallback implementation raises an error.
        if not kv_spatial_dim == query_spatial_dim:
            raise ValueError("causal attention only supported for kv_spatial_dim == query_spatial_dim")

    if value.feature_dim is not None:
        assert value.feature_dim == v_feat_dim  # maybe unnecessary check?
    batch_dims = query.remaining_dims([qk_feat_dim, query_spatial_dim])
    assert set(batch_dims) == set(key.remaining_dims([qk_feat_dim, kv_spatial_dim]))
    assert set(batch_dims) == set(value.remaining_dims([v_feat_dim, kv_spatial_dim]))

    # Bring all inputs into the layout [batch..., spatial, feature].
    query_raw = torch.permute(
        query.raw_tensor,
        [query.get_axis_from_description(d) for d in batch_dims + [query_spatial_dim, qk_feat_dim]],
    ).contiguous()  # contiguous is a requirement for fused kernels
    key_raw = torch.permute(
        key.raw_tensor,
        [key.get_axis_from_description(d) for d in batch_dims + [kv_spatial_dim, qk_feat_dim]],
    ).contiguous()
    value_raw = torch.permute(
        value.raw_tensor,
        [value.get_axis_from_description(d) for d in batch_dims + [kv_spatial_dim, v_feat_dim]],
    ).contiguous()

    attention_mask_raw: Optional[torch.Tensor] = None
    if attention_mask is not None:
        # assumes that query and kv spatial dim are present in the attention mask,
        # this requirement could be relaxed...
        att_mask_batch_dims = attention_mask.remaining_dims([query_spatial_dim, kv_spatial_dim])
        if not set(att_mask_batch_dims).issubset(set(batch_dims)):
            raise ValueError(
                f"attention_mask has unexpected batch dims {att_mask_batch_dims}, expected subset of {batch_dims}"
            )
        # we totally ignore kv_spatial_dim dyn_sizes here... TODO
        if kv_spatial_dim.is_dynamic():
            raise NotImplementedError("attention_mask with dynamic kv_spatial_dim is not implemented")
        # order
        att_mask_batch_dims = [d for d in batch_dims if d in att_mask_batch_dims]
        attention_mask_raw = torch.permute(
            attention_mask.raw_tensor,
            [
                attention_mask.get_axis_from_description(d)
                for d in att_mask_batch_dims + [query_spatial_dim, kv_spatial_dim]
            ],
        )
        # Insert broadcast axes for the batch dims which the mask does not have.
        # PyTorch broadcasting is right-aligned, so without this,
        # e.g. a [B, Q, K] mask would be matched against the wrong axes of [B, H, Q, K] energies.
        for i, b_dim in enumerate(batch_dims):
            if b_dim not in att_mask_batch_dims:
                attention_mask_raw = attention_mask_raw.unsqueeze(i)
    elif kv_spatial_dim.is_dynamic() and kv_dyn_size is not None and not is_causal:
        # no attention mask, but we know that some keys/values are padding
        # create mask based on that
        kv_spat_dyn_dims = kv_dyn_size.dims_set
        assert kv_spat_dyn_dims.issubset(set(batch_dims))
        attention_mask_raw = kv_spatial_dim.get_mask(
            dim_order=[d for d in batch_dims if d in kv_spat_dyn_dims] + [kv_spatial_dim]
        ).raw_tensor
        # insert 1 dim at -2 (for query_spatial_dim)
        attention_mask_raw = attention_mask_raw.unsqueeze(-2)
        # now add all other batch dims
        for i, b_dim in enumerate(batch_dims):
            if b_dim not in kv_spat_dyn_dims:
                attention_mask_raw = attention_mask_raw.unsqueeze(i)

    if attention_mask_raw is not None:
        attention_mask_raw = attention_mask_raw.contiguous()  # necessary for fused kernels
    # dont need to check query spatial dim for is_dynamic, as we assign that dim to the result tensor so
    # downstream code automatically handles it
    att_raw = torch.nn.functional.scaled_dot_product_attention(
        query_raw,
        key_raw,
        value_raw,
        attn_mask=attention_mask_raw,
        dropout_p=att_dropout if rf.get_run_ctx().is_train_flag_enabled(func=rf.dropout) else 0.0,
        is_causal=is_causal,
        # scale is only available in PyTorch 2.1+, some tests still run 2.0
        **({"scale": scale} if scale is not None else {}),
    )

    att = rf.convert_to_tensor(
        att_raw,
        dims=batch_dims + [query_spatial_dim, v_feat_dim],
        name="scaled_dot_product_attention",
    )

    if value.feature_dim in att.dims:
        att.feature_dim = value.feature_dim
    return att
@pytest.fixture(autouse=True)
def reset_sdpa():
    """
    Make sure that every test starts with the default ``TorchBackend.ForceFallbackSDPA = False``,
    and also restore the default after the test,
    so that a test which enables the fallback (or fails in between) does not leak the flag
    into whatever runs afterwards, including tests from other modules.
    """
    TorchBackend.ForceFallbackSDPA = False
    yield
    # Teardown, executed by pytest after the test, also when the test failed.
    TorchBackend.ForceFallbackSDPA = False
run_model( - extern_data, - lambda *, epoch, step: _Net(), - _forward_step, - # TF needs TensorArray unstack, not implemented yet - test_tensorflow=False, - ) + res = {} + for force_fallback_sdpa in [True, False]: + TorchBackend.ForceFallbackSDPA = force_fallback_sdpa + res[force_fallback_sdpa] = run_model( + extern_data, + lambda *, epoch, step: _Net(), + _forward_step, + # TF needs TensorArray unstack, not implemented yet + test_tensorflow=False, + ) + + torch.testing.assert_close(res[False].data["output"].raw_tensor, res[True].data["output"].raw_tensor) def test_rotary_embedding(): @@ -262,8 +283,8 @@ def test_rotary_embedding(): assert out_rf_sin.raw_tensor.shape == out_hf_sin.shape assert out_rf_cos.raw_tensor.shape == out_hf_cos.shape - torch.testing.assert_allclose(out_rf_sin.raw_tensor, out_hf_sin) - torch.testing.assert_allclose(out_rf_cos.raw_tensor, out_hf_cos) + torch.testing.assert_close(out_rf_sin.raw_tensor, out_hf_sin) + torch.testing.assert_close(out_rf_cos.raw_tensor, out_hf_cos) def test_rope_causal_self_att(): @@ -274,6 +295,7 @@ def test_rope_causal_self_att(): # noinspection PyProtectedMember from returnn.frontend.attention import _apply_rope as rf_apply_rope from returnn.frontend.conversions.hf_llama import import_params_hf_llama_att_to_rf_rotary_att + from returnn.frontend._backend import Backend from transformers.models.llama.modeling_llama import ( LlamaAttention, @@ -313,8 +335,34 @@ def test_rope_causal_self_att(): in_ = rf.random_uniform([batch_dim, seq_dim, model_dim]) in_.name = "input" + TorchBackend.ForceFallbackSDPA = True with PyTracer( - [rf.RotaryPosCausalSelfAttention.__call__, rf.sinusoidal_encoding, rf.dot_attention, rf_apply_rope], + [ + rf.RotaryPosCausalSelfAttention.__call__, + rf.sinusoidal_encoding, + rf.scaled_dot_product_attention, + Backend.scaled_dot_product_attention, + rf_apply_rope, + ], + (Tensor, Dim), + ) as trace_rf_fallback: + out_rf_fallback, _ = model_rf( + in_, + axis=seq_dim, + 
state=model_rf.default_initial_state(batch_dims=[batch_dim]), + ) + out_rf_fallback = out_rf_fallback.copy_transpose((batch_dim, seq_dim, model_dim)) + pprint(trace_rf_fallback.captured_locals) + TorchBackend.ForceFallbackSDPA = False + + with PyTracer( + [ + rf.RotaryPosCausalSelfAttention.__call__, + rf.sinusoidal_encoding, + rf.scaled_dot_product_attention, + Backend.scaled_dot_product_attention, + rf_apply_rope, + ], (Tensor, Dim), ) as trace_rf: out_rf, _ = model_rf(in_, axis=seq_dim, state=model_rf.default_initial_state(batch_dims=[batch_dim])) @@ -348,103 +396,134 @@ def test_rope_causal_self_att(): print("First HF att weight tensor:") print(trace_hf.captured_locals[LlamaAttention.forward][0]["attn_weights"][-1][0, 0, 0].detach().numpy()) - check_py_traces_rf_to_pt_equal( - trace_rf.captured_locals, - trace_hf.captured_locals, - [ - ( - (rf.RotaryPosCausalSelfAttention.__call__, 0, "q", 0), - # input: batch_dim, seq_dim, model_dim - # input_shape: batch_dim, seq_dim - # HF query_states': (batch_dim, seq_dim, num_heads, self.head_dim), - # then transposed to (batch_dim, num_heads, seq_dim, self.head_dim) - (LlamaAttention.forward, 0, "query_states", 0), - lambda x, *, name, **_: rf.convert_to_tensor( - # reorder complex numbers - x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2), - dims=(batch_dim, model_rf.num_heads, seq_dim, model_rf.key_dim_per_head), - name=name, + trace_all = [ + ( + (rf.RotaryPosCausalSelfAttention.__call__, 0, "q", 0), + # input: batch_dim, seq_dim, model_dim + # input_shape: batch_dim, seq_dim + # HF query_states': (batch_dim, seq_dim, num_heads, self.head_dim), + # then transposed to (batch_dim, num_heads, seq_dim, self.head_dim) + (LlamaAttention.forward, 0, "query_states", 0), + lambda x, *, name, **_: rf.convert_to_tensor( + # reorder complex numbers + x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2), + dims=( + batch_dim, + model_rf.num_heads, + seq_dim, + model_rf.key_dim_per_head, ), + name=name, ), - ( 
- (rf.RotaryPosCausalSelfAttention.__call__, 0, "k", 0), - (LlamaAttention.forward, 0, "key_states", 0), - lambda x, *, name, **_: rf.convert_to_tensor( - # reorder complex numbers - x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2), - dims=(batch_dim, model_rf.num_heads, seq_dim, model_rf.key_dim_per_head), - name=name, + ), + ( + (rf.RotaryPosCausalSelfAttention.__call__, 0, "k", 0), + (LlamaAttention.forward, 0, "key_states", 0), + lambda x, *, name, **_: rf.convert_to_tensor( + # reorder complex numbers + x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2), + dims=( + batch_dim, + model_rf.num_heads, + seq_dim, + model_rf.key_dim_per_head, ), + name=name, ), - ( - (rf.sinusoidal_encoding, 0, "div_term", 0), - (LlamaRotaryEmbedding.forward, 0, "inv_freq_expanded", 0), - lambda x, *, name, **_: rf.convert_to_tensor( - x[0, :, 0], dims=[model_rf.key_dim_per_head.div_left(2)], name=name - ), + ), + ( + (rf_apply_rope, 0, "pe_imag", 0), + (apply_rotary_pos_emb, 0, "sin", 0), + lambda x, *, name, **_: rf.convert_to_tensor( + x[0, :, : x.shape[2] // 2], + dims=(seq_dim, model_rf.key_dim_per_head.div_left(2)), + name=name, ), - ( - (rf.sinusoidal_encoding, 0, "arg_sin", 0), - (LlamaRotaryEmbedding.forward, 0, "freqs", 0), - lambda x, *, name, resolve_dim, **_: rf.convert_to_tensor( - x[0], - dims=(resolve_dim("arg_sin.dims[0]"), model_rf.key_dim_per_head.div_left(2)), - name=name, - ), + ), + ( + (rf_apply_rope, 0, "pe_imag", 0), + (apply_rotary_pos_emb, 0, "sin", 0), + lambda x, *, name, **_: rf.convert_to_tensor( + x[-1, :, x.shape[2] // 2 :], + dims=(seq_dim, model_rf.key_dim_per_head.div_left(2)), + name=name, ), - ( - (rf_apply_rope, 0, "pe_imag", 0), - (apply_rotary_pos_emb, 0, "sin", 0), - lambda x, *, name, **_: rf.convert_to_tensor( - x[0, :, : x.shape[2] // 2], dims=(seq_dim, model_rf.key_dim_per_head.div_left(2)), name=name - ), + ), + ( + (rf_apply_rope, 0, "pe_real", 0), + (apply_rotary_pos_emb, 0, "cos", 0), + lambda x, *, name, **_: 
rf.convert_to_tensor( + x[0, :, : x.shape[2] // 2], + dims=(seq_dim, model_rf.key_dim_per_head.div_left(2)), + name=name, ), - ( - (rf_apply_rope, 0, "pe_imag", 0), - (apply_rotary_pos_emb, 0, "sin", 0), - lambda x, *, name, **_: rf.convert_to_tensor( - x[-1, :, x.shape[2] // 2 :], dims=(seq_dim, model_rf.key_dim_per_head.div_left(2)), name=name + ), + ( + (rf.RotaryPosCausalSelfAttention.__call__, 0, "q", -1), + (LlamaAttention.forward, 0, "query_states", -1), + lambda x, *, name, **_: rf.convert_to_tensor( + x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2), + dims=( + batch_dim, + model_rf.num_heads, + seq_dim, + model_rf.key_dim_per_head, ), + name=name, ), - ( - (rf_apply_rope, 0, "pe_real", 0), - (apply_rotary_pos_emb, 0, "cos", 0), - lambda x, *, name, **_: rf.convert_to_tensor( - x[0, :, : x.shape[2] // 2], - dims=(seq_dim, model_rf.key_dim_per_head.div_left(2)), - name=name, - ), + ), + ( + (rf.scaled_dot_product_attention, 0, "att", 0), + (LlamaAttention.forward, 0, "attn_output", 0), + (batch_dim, seq_dim, model_rf.num_heads, model_rf.value_dim_per_head), + ), + ] + + trace_only_fallback = [ + *trace_all, + # fallback is executed first, and sinusoidal is cached. 
Therefore we only test in fallback + ( + (rf.sinusoidal_encoding, 0, "div_term", 0), + (LlamaRotaryEmbedding.forward, 0, "inv_freq_expanded", 0), + lambda x, *, name, **_: rf.convert_to_tensor( + x[0, :, 0], dims=[model_rf.key_dim_per_head.div_left(2)], name=name ), - ( - (rf.RotaryPosCausalSelfAttention.__call__, 0, "q", -1), - (LlamaAttention.forward, 0, "query_states", -1), - lambda x, *, name, **_: rf.convert_to_tensor( - x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2), - dims=(batch_dim, model_rf.num_heads, seq_dim, model_rf.key_dim_per_head), - name=name, + ), + ( + (rf.sinusoidal_encoding, 0, "arg_sin", 0), + (LlamaRotaryEmbedding.forward, 0, "freqs", 0), + lambda x, *, name, resolve_dim, **_: rf.convert_to_tensor( + x[0], + dims=( + resolve_dim("arg_sin.dims[0]"), + model_rf.key_dim_per_head.div_left(2), ), + name=name, ), - ( - (rf.dot_attention, 0, "energy", 0), - (eager_attention_forward, 0, "attn_weights", 0), - (batch_dim, model_rf.num_heads, seq_dim, "axis"), - ), - ( - (rf.dot_attention, 0, "att_weights", 0), - (LlamaAttention.forward, 0, "attn_weights", -1), - (batch_dim, model_rf.num_heads, seq_dim, "axis"), - ), - ( - (rf.dot_attention, 0, "att", 0), - (LlamaAttention.forward, 0, "attn_output", 0), - (batch_dim, seq_dim, model_rf.num_heads, model_rf.value_dim_per_head), - ), - ], - ) + ), + ( + (Backend.scaled_dot_product_attention, 0, "energy", 0), + (eager_attention_forward, 0, "attn_weights", 0), + (batch_dim, model_rf.num_heads, seq_dim, "kv_spatial_dim"), + ), + ( + (Backend.scaled_dot_product_attention, 0, "att_weights", 0), + (LlamaAttention.forward, 0, "attn_weights", -1), + (batch_dim, model_rf.num_heads, seq_dim, "kv_spatial_dim"), + ), + ] + + assert Backend.scaled_dot_product_attention not in trace_rf.captured_locals + assert Backend.scaled_dot_product_attention in trace_rf_fallback.captured_locals + + check_py_traces_rf_to_pt_equal(trace_rf.captured_locals, trace_hf.captured_locals, trace_all) + 
check_py_traces_rf_to_pt_equal(trace_rf_fallback.captured_locals, trace_hf.captured_locals, trace_only_fallback) print("Final check...") assert out_rf.raw_tensor.shape == out_hf.shape torch.testing.assert_close(out_rf.raw_tensor, out_hf) + assert out_rf_fallback.raw_tensor.shape == out_hf.shape + torch.testing.assert_close(out_rf_fallback.raw_tensor, out_hf) print(" all matched!") @@ -506,24 +585,42 @@ def _make_rel_pos_causal_self_att(**_kwargs): models = [_make_causal_self_att, _make_rope_causal_self_att, _make_rel_pos_causal_self_att] - for get_model in models: - print("> Testing model:", get_model.__name__) - res = run_model( - extern_data, - get_model, - _forward_step, - # TF needs TensorArray unstack, not implemented yet - test_tensorflow=False, - ) - - # Check that the single-step and the seq-level output are the same. - res_seq_level = res.data["out_seq_level"].raw_tensor - for key in ["out_seq_level_explicit_initial_state", "out_single_steps"]: - res_other = res.data[key].raw_tensor - assert res_seq_level.shape == res_other.shape - numpy.testing.assert_allclose( - res_other, res_seq_level, atol=1e-5, rtol=1e-5, err_msg=f"output {key} differs" + resdict = {} + for use_fallback in [True, False]: + TorchBackend.ForceFallbackSDPA = use_fallback + print("=== ForceFallbackSDPA =", use_fallback, "===") + for get_model in models: + print("> Testing model:", get_model.__name__) + res = run_model( + extern_data, + get_model, + _forward_step, + # TF needs TensorArray unstack, not implemented yet + test_tensorflow=False, ) + resdict.setdefault(get_model.__name__, {})[use_fallback] = res + + # Check that the single-step and the seq-level output are the same. 
+ res_seq_level = res.data["out_seq_level"].raw_tensor + for key in ["out_seq_level_explicit_initial_state", "out_single_steps"]: + res_other = res.data[key].raw_tensor + assert res_seq_level.shape == res_other.shape + numpy.testing.assert_allclose( + res_other, + res_seq_level, + atol=1e-5, + rtol=1e-5, + err_msg=f"output {key} differs", + ) + # test fallback & no-fallback + for name, v in resdict.items(): + numpy.testing.assert_allclose( + v[True].data["out_seq_level"].raw_tensor, + v[False].data["out_seq_level"].raw_tensor, + atol=1e-5, + rtol=1e-5, + err_msg=f"output {name} differs", + ) def test_relative_positional_encoding(): @@ -572,6 +669,8 @@ def _forward_step(**_kwargs): def test_rel_pos_self_attention(): + import torch + time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) in_dim = Dim(8, name="in") extern_data = TensorDict( @@ -630,7 +729,15 @@ def _forward_step(*, model: _Net, extern_data: TensorDict): run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step) check_batching = True - run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step, test_tensorflow=False) + + res = {} + for force_fallback_sdpa in [True, False]: + TorchBackend.ForceFallbackSDPA = force_fallback_sdpa + res[force_fallback_sdpa] = run_model( + extern_data, lambda *, epoch, step: _Net(), _forward_step, test_tensorflow=False + ) + + torch.testing.assert_close(res[False].data["output"].raw_tensor, res[True].data["output"].raw_tensor) def test_sinusoidal_positional_encoding(): @@ -660,6 +767,8 @@ def _forward_step(**_kwargs): def test_CausalSelfAttention(): + import torch + time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) feat_dim = Dim(8, name="feat") key_dim = Dim(6, name="key") @@ -678,21 +787,26 @@ def _forward_step(*, model: rf.CausalSelfAttention, extern_data: TensorDict): out.mark_as_default_output(shape=(batch_dim, time_dim, value_dim)) model.qkv.weight.mark_as_output("qkv_weight", shape=[feat_dim, 2 * key_dim + value_dim]) - res = 
run_model( - extern_data, - lambda *, epoch, step: rf.CausalSelfAttention( - in_dim=feat_dim, - proj_dim=None, - key_dim_total=key_dim, - value_dim_total=value_dim, - num_heads=2, - with_bias=False, - ), - _forward_step, - # Some problem with dimension tags currently in the TF-layers-dict backend... - # Anyway, we compare to the TF SelfAttentionLayer with attention_left_only=True below. - test_tensorflow=False, - ) + res = {} + for force_fallback_sdpa in [True, False]: + TorchBackend.ForceFallbackSDPA = force_fallback_sdpa + res[force_fallback_sdpa] = run_model( + extern_data, + lambda *, epoch, step: rf.CausalSelfAttention( + in_dim=feat_dim, + proj_dim=None, + key_dim_total=key_dim, + value_dim_total=value_dim, + num_heads=2, + with_bias=False, + ), + _forward_step, + # Some problem with dimension tags currently in the TF-layers-dict backend... + # Anyway, we compare to the TF SelfAttentionLayer with attention_left_only=True below. + test_tensorflow=False, + ) + + torch.testing.assert_close(res[False].data["output"].raw_tensor, res[True].data["output"].raw_tensor) extern_data.reset_content() @@ -725,14 +839,86 @@ def _forward_step(*, model: rf.CausalSelfAttention, extern_data: TensorDict): ) net.construct_from_dict(net_dict) layer = net.get_default_output_layer() - layer.params["QKV"].load(res.data["qkv_weight"].raw_tensor, session=session) + layer.params["QKV"].load(res[False].data["qkv_weight"].raw_tensor, session=session) out = layer.output.copy_transpose([batch_dim, time_dim, value_dim]).copy_masked(0.0) out_tf_v = session.run( out.raw_tensor, feed_dict={ - net.extern_data.data["data"].placeholder: res.data["data"].raw_tensor, - net.extern_data.data["data"].dims[1].dyn_size_ext.raw_tensor: res.data["seq_len"].raw_tensor, + net.extern_data.data["data"].placeholder: res[False].data["data"].raw_tensor, + net.extern_data.data["data"].dims[1].dyn_size_ext.raw_tensor: res[False].data["seq_len"].raw_tensor, }, ) - 
numpy.testing.assert_almost_equal(res.data["output"].raw_tensor, out_tf_v, decimal=5) + numpy.testing.assert_almost_equal(res[False].data["output"].raw_tensor, out_tf_v, decimal=5) + + +# Check if pytest-benchmark is installed +try: + import pytest_benchmark +except ImportError: + pytest_benchmark = None + + +def benchmark_att_layer(benchmark, att_type, sdpa, seq_len): + import torch + + TorchBackend.ForceFallbackSDPA = sdpa == "returnn_fallback" + rf.select_backend_torch() + rf.init_forward_step_run_ctx() + + batch_dim = Dim(128 if seq_len <= 32 else 8, name="batch") + seq_dim = Dim(seq_len, name="time") + + num_heads = 8 + head_dim = 64 + model_dim_val = num_heads * head_dim + model_dim = Dim(model_dim_val, name="model") + with torch.amp.autocast("cuda", dtype=torch.bfloat16), rf.set_default_float_dtype_ctx("bfloat16"): + input_tensor = rf.random_normal(dims=[batch_dim, seq_dim, model_dim]) + + _cls = rf.SelfAttention if att_type == "self_attention" else rf.CausalSelfAttention + att_layer = _cls( + in_dim=model_dim, + proj_dim=model_dim, + key_dim_total=model_dim, + value_dim_total=model_dim, + num_heads=num_heads, + ) + + def _run_forward(): + with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]): + res = att_layer(input_tensor, axis=seq_dim) + torch.cuda.synchronize() if torch.cuda.is_available() else None + return res + + benchmark(_run_forward) + + +@pytest.mark.skipif(pytest_benchmark is None, reason="pytest-benchmark not installed") +@pytest.mark.parametrize("sdpa", ["torch_sdpa", "returnn_fallback"], ids=lambda sdpa: f"{sdpa=}") +@pytest.mark.parametrize("seq_len", [32, 1024], ids=lambda seq_len: f"{seq_len=}") +@pytest.mark.benchmark(warmup=True, disable_gc=True) +def test_benchmark_flashatt_self_attention(benchmark, sdpa, seq_len): + import torch + + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + benchmark.group = f"flashatt_self_attention({device}, {seq_len=})" + with 
rf.set_default_device_ctx(device): + benchmark_att_layer(benchmark, "self_attention", sdpa, seq_len) + + +@pytest.mark.skipif(pytest_benchmark is None, reason="pytest-benchmark not installed") +@pytest.mark.parametrize("sdpa", ["torch_sdpa", "returnn_fallback"], ids=lambda sdpa: f"{sdpa=}") +@pytest.mark.parametrize("seq_len", [32, 1024], ids=lambda seq_len: f"{seq_len=}") +@pytest.mark.benchmark(warmup=True, disable_gc=True) +def test_benchmark_flashatt_causal_self_attention(benchmark, sdpa, seq_len): + import torch + + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + benchmark.group = f"flashatt_causal_self_attention({device}, {seq_len=})" + with rf.set_default_device_ctx(device): + benchmark_att_layer(benchmark, "causal_self_attention", sdpa, seq_len)