|
 To do so, turn on verbose and run 'pytest tests/models/test_causal_self_attention.py -s'
 """

-import os
-import subprocess
-import sys
-import textwrap
 from copy import deepcopy
-from pathlib import Path

 import pytest
 import torch
@@ -25,8 +20,6 @@
 torch.manual_seed(0)

 FLASH_ATTN_V4_AVAILABLE = is_flash_attn_v4_available()
-REPO_ROOT = Path(__file__).resolve().parents[2]
-SRC_ROOT = REPO_ROOT / "src"


 def _get_random_input_seq(embedding_shape):
@@ -287,142 +280,85 @@ def test_qk_norm(n_head_q, n_head_kv, n_embd, attention_impl): |

 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
 @pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
-def test_dao_flash_v4_forward_mha_subprocess():
-    result = _run_fa4_subprocess(
-        """
-        import torch
-        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
-
-        q = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
-        k = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
-        v = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
-        out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl='dao_flash_v4')
-        torch.cuda.synchronize()
-        assert tuple(out.shape) == (2, 12, 4, 32)
-        print('ok')
-        """
-    )
-    assert result.stdout.strip().endswith("ok")
+def test_dao_flash_v4_forward_mha():
+    q = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device="cuda")
+    k = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device="cuda")
+    v = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device="cuda")
+
+    out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl="dao_flash_v4")
+
+    torch.cuda.synchronize()
+    assert tuple(out.shape) == (2, 12, 4, 32)


 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
 @pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
-def test_dao_flash_v4_forward_gqa_subprocess():
-    result = _run_fa4_subprocess(
-        """
-        import torch
-        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
-
-        q = torch.rand(2, 8, 12, 32, dtype=torch.bfloat16, device='cuda')
-        k = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device='cuda')
-        v = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device='cuda')
-        out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl='dao_flash_v4')
-        torch.cuda.synchronize()
-        assert tuple(out.shape) == (2, 12, 8, 32)
-        print('ok')
-        """
-    )
-    assert result.stdout.strip().endswith("ok")
+def test_dao_flash_v4_forward_gqa():
+    q = torch.rand(2, 8, 12, 32, dtype=torch.bfloat16, device="cuda")
+    k = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device="cuda")
+    v = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device="cuda")
+
+    out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl="dao_flash_v4")
+
+    torch.cuda.synchronize()
+    assert tuple(out.shape) == (2, 12, 8, 32)


 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
 @pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
-def test_dao_flash_v4_qk_norm_subprocess():
-    result = _run_fa4_subprocess(
-        """
-        import torch
-        from modalities.models.gpt2.gpt2_model import (
-            AttentionConfig,
-            CausalSelfAttention,
-            LayerNorms,
-            LayerNormWrapperConfig,
-            PytorchRMSLayerNormConfig,
-        )
-
-        torch.manual_seed(0)
-        attention_config_no_norm = AttentionConfig(qkv_transforms=[])
-        attention_config_with_norm = AttentionConfig(
-            qkv_transforms=[],
-            qk_norm_config=LayerNormWrapperConfig(
-                norm_type=LayerNorms.pytorch_rms_norm,
-                config=PytorchRMSLayerNormConfig(normalized_shape=8),
-            ),
-        )
-
-        torch.manual_seed(0)
-        layer_no_norm = CausalSelfAttention(
-            4, 4, 32, attention_config_no_norm, 'dao_flash_v4', False, 0.0
-        ).cuda().bfloat16()
-        torch.manual_seed(0)
-        layer_with_norm = CausalSelfAttention(
-            4, 4, 32, attention_config_with_norm, 'dao_flash_v4', False, 0.0
-        ).cuda().bfloat16()
-        x = torch.rand((2, 9, 32), dtype=torch.bfloat16, device='cuda')
-        out_no_norm = layer_no_norm(x)
-        out_with_norm = layer_with_norm(x)
-        torch.cuda.synchronize()
-        assert out_no_norm.shape == out_with_norm.shape == (2, 9, 32)
-        assert not torch.allclose(out_no_norm, out_with_norm, atol=1e-6)
-        print('ok')
-        """
+def test_dao_flash_v4_qk_norm():
+    torch.manual_seed(0)
+    attention_config_no_norm = AttentionConfig(qkv_transforms=[])
+    attention_config_with_norm = AttentionConfig(
+        qkv_transforms=[],
+        qk_norm_config=LayerNormWrapperConfig(
+            norm_type=LayerNorms.pytorch_rms_norm,
+            config=PytorchRMSLayerNormConfig(normalized_shape=8),
+        ),
     )
-    assert result.stdout.strip().endswith("ok")
+
+    torch.manual_seed(0)
+    layer_no_norm = (
+        CausalSelfAttention(4, 4, 32, attention_config_no_norm, "dao_flash_v4", False, 0.0).cuda().bfloat16()
+    )
+    torch.manual_seed(0)
+    layer_with_norm = (
+        CausalSelfAttention(4, 4, 32, attention_config_with_norm, "dao_flash_v4", False, 0.0).cuda().bfloat16()
+    )
+    x = torch.rand((2, 9, 32), dtype=torch.bfloat16, device="cuda")
+
+    out_no_norm = layer_no_norm(x)
+    out_with_norm = layer_with_norm(x)
+
+    torch.cuda.synchronize()
+    assert out_no_norm.shape == out_with_norm.shape == (2, 9, 32)
+    assert not torch.allclose(out_no_norm, out_with_norm, atol=1e-6)


 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
 @pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
-def test_dao_flash_v4_backward_approximate_equality_subprocess():
-    result = _run_fa4_subprocess(
-        """
-        import torch
-        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
-
-        query_ref = torch.rand((2, 8, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
-        key_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
-        value_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
-
-        query_fa4 = query_ref.detach().clone().requires_grad_(True)
-        key_fa4 = key_ref.detach().clone().requires_grad_(True)
-        value_fa4 = value_ref.detach().clone().requires_grad_(True)
-
-        output_ref = CausalSelfAttention.execute_attention(
-            query_ref, key_ref, value_ref, dropout=0.0, attention_impl='pytorch_flash'
-        )
-        output_fa4 = CausalSelfAttention.execute_attention(
-            query_fa4, key_fa4, value_fa4, dropout=0.0, attention_impl='dao_flash_v4'
-        )
-        torch.testing.assert_close(output_ref, output_fa4, atol=2.5e-3, rtol=0.016)
-
-        output_ref.float().sum().backward()
-        output_fa4.float().sum().backward()
-        torch.cuda.synchronize()
-
-        torch.testing.assert_close(query_ref.grad, query_fa4.grad, atol=5e-3, rtol=0.02)
-        torch.testing.assert_close(key_ref.grad, key_fa4.grad, atol=5e-3, rtol=0.02)
-        torch.testing.assert_close(value_ref.grad, value_fa4.grad, atol=5e-3, rtol=0.02)
-        print('ok')
-        """
+def test_dao_flash_v4_backward_approximate_equality():
+    query_ref = torch.rand((2, 8, 12, 64), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    key_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    value_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+
+    query_fa4 = query_ref.detach().clone().requires_grad_(True)
+    key_fa4 = key_ref.detach().clone().requires_grad_(True)
+    value_fa4 = value_ref.detach().clone().requires_grad_(True)
+
+    output_ref = CausalSelfAttention.execute_attention(
+        query_ref, key_ref, value_ref, dropout=0.0, attention_impl="pytorch_flash"
     )
-    assert result.stdout.strip().endswith("ok")
-
-
-def _run_fa4_subprocess(code: str) -> subprocess.CompletedProcess[str]:
-    """Run flash attention 4 related code in a subprocess to isolate FA4's CUDA context
-    and avoid conflicts with other tests.
-    The code should print 'ok' if it runs successfully.
-    The function returns the CompletedProcess object,
-    which contains stdout and stderr for further inspection if needed.
-    TODO: This might be an A100 / SM80-specific issue, so we can consider removing this subprocess isolation
-    if we confirm that FA4 works well on newer architectures without it.
-    """
-    env = os.environ.copy()
-    existing_pythonpath = env.get("PYTHONPATH")
-    env["PYTHONPATH"] = f"{SRC_ROOT}:{existing_pythonpath}" if existing_pythonpath else str(SRC_ROOT)
-    return subprocess.run(
-        [sys.executable, "-c", textwrap.dedent(code)],
-        cwd=REPO_ROOT,
-        env=env,
-        check=True,
-        capture_output=True,
-        text=True,
+    output_fa4 = CausalSelfAttention.execute_attention(
+        query_fa4, key_fa4, value_fa4, dropout=0.0, attention_impl="dao_flash_v4"
     )
+
+    torch.testing.assert_close(output_ref, output_fa4, atol=2.5e-3, rtol=0.016)
+
+    output_ref.float().sum().backward()
+    output_fa4.float().sum().backward()
+    torch.cuda.synchronize()
+
+    torch.testing.assert_close(query_ref.grad, query_fa4.grad, atol=5e-3, rtol=0.02)
+    torch.testing.assert_close(key_ref.grad, key_fa4.grad, atol=5e-3, rtol=0.02)
+    torch.testing.assert_close(value_ref.grad, value_fa4.grad, atol=5e-3, rtol=0.02)