
Commit 003fa34

refactor test
1 parent 2d12f46 commit 003fa34

2 files changed

Lines changed: 98 additions & 164 deletions


Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.nn.functional as F

from diffusers.models.attention_dispatch import (
    _CAN_USE_FLASH_ATTN,
    AttentionBackendName,
    dispatch_attention_fn,
)


# A mask with non-contiguous valid tokens.
_NON_PREFIX_MASK = torch.tensor(
    [
        [True, True, True, False, False, True, True, True, True, True],
        [True, False, False, False, True, True, True, True, True, True],
    ],
    dtype=torch.bool,
)


def _make_qkv(batch_size, seq_len, num_heads, head_dim, dtype=torch.float32):
    g = torch.Generator().manual_seed(42)
    q = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype)
    k = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype)
    v = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype)
    return q, k, v


def _sdpa_ref(q, k, v, bool_mask_2d=None):
    if bool_mask_2d is not None:
        additive_mask = torch.zeros_like(bool_mask_2d, dtype=q.dtype)
        additive_mask = additive_mask.masked_fill(~bool_mask_2d, float("-inf"))
        additive_mask = additive_mask[:, None, None, :]  # (batch_size, 1, 1, seq_len_kv)
    else:
        additive_mask = None
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
    return out.permute(0, 2, 1, 3)


@pytest.mark.skipif(not _CAN_USE_FLASH_ATTN, reason="flash-attn is required for these tests")
class TestFlashAttention:
    """Flash attention backend must produce results consistent with the SDPA reference when attn_mask is given."""

    def test_no_mask_matches_sdpa_reference(self):
        """FLASH backend output must match SDPA reference without any masking."""
        batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32
        device = torch.device("cuda")
        q, k, v = (
            t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim)
        )
        ref = _sdpa_ref(q, k, v)
        out = dispatch_attention_fn(q, k, v, attn_mask=None, backend=AttentionBackendName.FLASH)

        assert torch.allclose(ref, out, atol=1e-2), f"Max diff: {(ref - out).abs().max():.2e}"

    def test_mask_matches_sdpa_reference(self):
        """FLASH backend output must match SDPA reference with attention mask."""
        batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32
        device = torch.device("cuda")
        q, k, v = (
            t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim)
        )
        mask = _NON_PREFIX_MASK.to(device)

        ref = _sdpa_ref(q, k, v, mask)
        out = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH)

        assert torch.allclose(ref, out, atol=1e-2), f"Max diff: {(ref - out).abs().max():.2e}"

    def test_4d_bool_mask_equivalent_to_2d(self):
        """4D bool mask (batch_size, 1, 1, seq_len) must normalize to the same result as the 2D mask."""
        batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32
        device = torch.device("cuda")
        q, k, v = (
            t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim)
        )
        mask = _NON_PREFIX_MASK.to(device)

        out_2d = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH)
        out_4d = dispatch_attention_fn(q, k, v, attn_mask=mask[:, None, None, :], backend=AttentionBackendName.FLASH)

        assert torch.allclose(out_2d, out_4d, atol=1e-3)

tests/others/test_flash_attention_mask.py

Lines changed: 0 additions & 164 deletions
This file was deleted.

0 commit comments
