Commit 71bc205
test(attention): Removed subprocess isolation for fa4 tests.
Subprocess isolation is not necessary on H100 systems; it is better to simply skip these tests on A100 systems.
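As an alternative to subprocess isolation, one way to skip the FA4 tests on pre-Hopper GPUs such as the A100 is a compute-capability-based skip marker. The sketch below is illustrative only and is not part of this commit; the helper name and the SM 9.0 cutoff are assumptions.

import pytest
import torch


def _is_hopper_or_newer() -> bool:
    # Hypothetical helper: FA4 targets Hopper (SM 9.0) and newer, so gate on the
    # CUDA compute capability of device 0. Returns False when no GPU is present.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(0)
    return major >= 9


# A reusable marker that could replace per-test subprocess isolation by
# skipping FA4 tests on A100 (SM 8.0) and older GPUs.
requires_hopper = pytest.mark.skipif(
    not _is_hopper_or_newer(),
    reason="FA4 tests require a Hopper-class GPU (e.g. H100); skipped on A100.",
)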
1 parent 1e84358

1 file changed: tests/models/test_causal_self_attention.py
Lines changed: 67 additions & 131 deletions

@@ -3,12 +3,7 @@
 To do so, turn on verbose and run 'pytest tests/models/test_causal_self_attention.py -s'
 """

-import os
-import subprocess
-import sys
-import textwrap
 from copy import deepcopy
-from pathlib import Path

 import pytest
 import torch
@@ -25,8 +20,6 @@
 torch.manual_seed(0)

 FLASH_ATTN_V4_AVAILABLE = is_flash_attn_v4_available()
-REPO_ROOT = Path(__file__).resolve().parents[2]
-SRC_ROOT = REPO_ROOT / "src"


 def _get_random_input_seq(embedding_shape):
@@ -287,142 +280,85 @@ def test_qk_norm(n_head_q, n_head_kv, n_embd, attention_impl):

 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
 @pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
-def test_dao_flash_v4_forward_mha_subprocess():
-    result = _run_fa4_subprocess(
-        """
-        import torch
-        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
-
-        q = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
-        k = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
-        v = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
-        out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl='dao_flash_v4')
-        torch.cuda.synchronize()
-        assert tuple(out.shape) == (2, 12, 4, 32)
-        print('ok')
-        """
-    )
-    assert result.stdout.strip().endswith("ok")
+def test_dao_flash_v4_forward_mha():
+    q = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device="cuda")
+    k = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device="cuda")
+    v = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device="cuda")
+
+    out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl="dao_flash_v4")
+
+    torch.cuda.synchronize()
+    assert tuple(out.shape) == (2, 12, 4, 32)


 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
 @pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
-def test_dao_flash_v4_forward_gqa_subprocess():
-    result = _run_fa4_subprocess(
-        """
-        import torch
-        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
-
-        q = torch.rand(2, 8, 12, 32, dtype=torch.bfloat16, device='cuda')
-        k = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device='cuda')
-        v = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device='cuda')
-        out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl='dao_flash_v4')
-        torch.cuda.synchronize()
-        assert tuple(out.shape) == (2, 12, 8, 32)
-        print('ok')
-        """
-    )
-    assert result.stdout.strip().endswith("ok")
+def test_dao_flash_v4_forward_gqa():
+    q = torch.rand(2, 8, 12, 32, dtype=torch.bfloat16, device="cuda")
+    k = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device="cuda")
+    v = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device="cuda")
+
+    out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl="dao_flash_v4")
+
+    torch.cuda.synchronize()
+    assert tuple(out.shape) == (2, 12, 8, 32)


 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
 @pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
-def test_dao_flash_v4_qk_norm_subprocess():
-    result = _run_fa4_subprocess(
-        """
-        import torch
-        from modalities.models.gpt2.gpt2_model import (
-            AttentionConfig,
-            CausalSelfAttention,
-            LayerNorms,
-            LayerNormWrapperConfig,
-            PytorchRMSLayerNormConfig,
-        )
-
-        torch.manual_seed(0)
-        attention_config_no_norm = AttentionConfig(qkv_transforms=[])
-        attention_config_with_norm = AttentionConfig(
-            qkv_transforms=[],
-            qk_norm_config=LayerNormWrapperConfig(
-                norm_type=LayerNorms.pytorch_rms_norm,
-                config=PytorchRMSLayerNormConfig(normalized_shape=8),
-            ),
-        )
-
-        torch.manual_seed(0)
-        layer_no_norm = CausalSelfAttention(
-            4, 4, 32, attention_config_no_norm, 'dao_flash_v4', False, 0.0
-        ).cuda().bfloat16()
-        torch.manual_seed(0)
-        layer_with_norm = CausalSelfAttention(
-            4, 4, 32, attention_config_with_norm, 'dao_flash_v4', False, 0.0
-        ).cuda().bfloat16()
-        x = torch.rand((2, 9, 32), dtype=torch.bfloat16, device='cuda')
-        out_no_norm = layer_no_norm(x)
-        out_with_norm = layer_with_norm(x)
-        torch.cuda.synchronize()
-        assert out_no_norm.shape == out_with_norm.shape == (2, 9, 32)
-        assert not torch.allclose(out_no_norm, out_with_norm, atol=1e-6)
-        print('ok')
-        """
+def test_dao_flash_v4_qk_norm():
+    torch.manual_seed(0)
+    attention_config_no_norm = AttentionConfig(qkv_transforms=[])
+    attention_config_with_norm = AttentionConfig(
+        qkv_transforms=[],
+        qk_norm_config=LayerNormWrapperConfig(
+            norm_type=LayerNorms.pytorch_rms_norm,
+            config=PytorchRMSLayerNormConfig(normalized_shape=8),
+        ),
     )
-    assert result.stdout.strip().endswith("ok")
+
+    torch.manual_seed(0)
+    layer_no_norm = (
+        CausalSelfAttention(4, 4, 32, attention_config_no_norm, "dao_flash_v4", False, 0.0).cuda().bfloat16()
+    )
+    torch.manual_seed(0)
+    layer_with_norm = (
+        CausalSelfAttention(4, 4, 32, attention_config_with_norm, "dao_flash_v4", False, 0.0).cuda().bfloat16()
+    )
+    x = torch.rand((2, 9, 32), dtype=torch.bfloat16, device="cuda")
+
+    out_no_norm = layer_no_norm(x)
+    out_with_norm = layer_with_norm(x)
+
+    torch.cuda.synchronize()
+    assert out_no_norm.shape == out_with_norm.shape == (2, 9, 32)
+    assert not torch.allclose(out_no_norm, out_with_norm, atol=1e-6)


 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
 @pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
-def test_dao_flash_v4_backward_approximate_equality_subprocess():
-    result = _run_fa4_subprocess(
-        """
-        import torch
-        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
-
-        query_ref = torch.rand((2, 8, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
-        key_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
-        value_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
-
-        query_fa4 = query_ref.detach().clone().requires_grad_(True)
-        key_fa4 = key_ref.detach().clone().requires_grad_(True)
-        value_fa4 = value_ref.detach().clone().requires_grad_(True)
-
-        output_ref = CausalSelfAttention.execute_attention(
-            query_ref, key_ref, value_ref, dropout=0.0, attention_impl='pytorch_flash'
-        )
-        output_fa4 = CausalSelfAttention.execute_attention(
-            query_fa4, key_fa4, value_fa4, dropout=0.0, attention_impl='dao_flash_v4'
-        )
-        torch.testing.assert_close(output_ref, output_fa4, atol=2.5e-3, rtol=0.016)
-
-        output_ref.float().sum().backward()
-        output_fa4.float().sum().backward()
-        torch.cuda.synchronize()
-
-        torch.testing.assert_close(query_ref.grad, query_fa4.grad, atol=5e-3, rtol=0.02)
-        torch.testing.assert_close(key_ref.grad, key_fa4.grad, atol=5e-3, rtol=0.02)
-        torch.testing.assert_close(value_ref.grad, value_fa4.grad, atol=5e-3, rtol=0.02)
-        print('ok')
-        """
+def test_dao_flash_v4_backward_approximate_equality():
+    query_ref = torch.rand((2, 8, 12, 64), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    key_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    value_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+
+    query_fa4 = query_ref.detach().clone().requires_grad_(True)
+    key_fa4 = key_ref.detach().clone().requires_grad_(True)
+    value_fa4 = value_ref.detach().clone().requires_grad_(True)
+
+    output_ref = CausalSelfAttention.execute_attention(
+        query_ref, key_ref, value_ref, dropout=0.0, attention_impl="pytorch_flash"
     )
-    assert result.stdout.strip().endswith("ok")
-
-
-def _run_fa4_subprocess(code: str) -> subprocess.CompletedProcess[str]:
-    """Run flash attention 4 related code in a subprocess to isolate FA4's CUDA context
-    and avoid conflicts with other tests.
-    The code should print 'ok' if it runs successfully.
-    The function returns the CompletedProcess object,
-    which contains stdout and stderr for further inspection if needed.
-    TODO: This might be an A100 / SM80-specific issue, so we can consider removing this subprocess isolation
-    if we confirm that FA4 works well on newer architectures without it.
-    """
-    env = os.environ.copy()
-    existing_pythonpath = env.get("PYTHONPATH")
-    env["PYTHONPATH"] = f"{SRC_ROOT}:{existing_pythonpath}" if existing_pythonpath else str(SRC_ROOT)
-    return subprocess.run(
-        [sys.executable, "-c", textwrap.dedent(code)],
-        cwd=REPO_ROOT,
-        env=env,
-        check=True,
-        capture_output=True,
-        text=True,
+    output_fa4 = CausalSelfAttention.execute_attention(
+        query_fa4, key_fa4, value_fa4, dropout=0.0, attention_impl="dao_flash_v4"
     )
+
+    torch.testing.assert_close(output_ref, output_fa4, atol=2.5e-3, rtol=0.016)
+
+    output_ref.float().sum().backward()
+    output_fa4.float().sum().backward()
+    torch.cuda.synchronize()
+
+    torch.testing.assert_close(query_ref.grad, query_fa4.grad, atol=5e-3, rtol=0.02)
+    torch.testing.assert_close(key_ref.grad, key_fa4.grad, atol=5e-3, rtol=0.02)
+    torch.testing.assert_close(value_ref.grad, value_fa4.grad, atol=5e-3, rtol=0.02)
