
Commit 27fc168

cyanguwa, pre-commit-ci[bot], and greptile-apps[bot] authored
[Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell (NVIDIA#2584)
* update FE to 1.17
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* add determinism flag
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* add determinism to test
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* add determinism to qa/
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* move bias/dbias/versioning/dropout logic to C API
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update qa/L0_pytorch_unittest/test.sh
  make .xml file specific to deterministic tests in qa/
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* add determinism to Jax extension
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* add determinism to Jax tests
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update tests/jax/test_fused_attn.py
  fix typo
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* Update transformer_engine/common/fused_attn/fused_attn.cpp
  fix indentation
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix the AI fixes
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix Jax extension call
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* minor fixes based on comments
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* fix selection logic and fwd arg
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix version check in Jax test
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* fix pytorch CI failures
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* fix Jax CI failures
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix non-/determinism logic and CI
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix formatting
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update transformer_engine/common/fused_attn/fused_attn.cpp
  fix and/or logic
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* update to 9.18.1 for requirement
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* reduce Jax CI tests for determinism
  Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent dfdd382 commit 27fc168

13 files changed

Lines changed: 299 additions & 44 deletions


3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 102 files

qa/L0_jax_unittest/test.sh

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 mkdir -p "$XML_LOG_DIR"
 
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py"
 
 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
 NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
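
Both QA scripts rerun the attention tests with `NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. As a minimal sketch of how that variable maps to a boolean, the snippet below mirrors the `_deterministic` expression that this commit adds to tests/jax/test_fused_attn.py. The helper name `is_deterministic` and its `env` parameter are hypothetical and exist only so the behavior can be checked without touching the real environment.

```python
import os


def is_deterministic(env=None):
    """Return True when non-deterministic algorithms are disallowed.

    Hypothetical helper: it wraps the same expression the test module uses,
    `not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))`.
    """
    env = os.environ if env is None else env
    # An unset variable defaults to "1", i.e. non-determinism is allowed.
    return not bool(int(env.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))


if __name__ == "__main__":
    print(is_deterministic({}))
    print(is_deterministic({"NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0"}))
```

Note that the value goes through `int(...)`, so a non-numeric setting such as `false` would raise `ValueError` rather than be read as "off".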

tests/jax/test_fused_attn.py

Lines changed: 203 additions & 9 deletions
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
 """Tests for fused attention"""
+import os
 from enum import Enum, auto
 from dataclasses import dataclass, field
 from functools import partial
@@ -49,6 +50,9 @@
 from distributed_test_base import assert_equal_collectives
 from utils import assert_allclose, print_debug_tensor_stats
 
+# Get determinism
+_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
+
 
 @pytest.fixture(autouse=True, scope="module")
 def init():
@@ -413,15 +417,25 @@ def _check_configs(self):
             pytest.skip(
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )
-        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
-        if (
-            get_device_compute_capability(0) >= 100
-            and self.dropout_prob == 0.1
-            and self.attn_bias_type is not AttnBiasType.NO_BIAS
-        ):
-            pytest.skip(
-                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
-            )
+
+        if get_device_compute_capability(0) >= 100 and self.is_training:
+            if FusedAttnHelper.is_non_deterministic_allowed() and (
+                (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
+                or get_cudnn_version() < 90700
+            ):
+                pytest.skip(
+                    "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
+                    " dropout"
+                )
+            if not FusedAttnHelper.is_non_deterministic_allowed() and (
+                self.dropout_prob != 0.0
+                or self.attn_bias_type != AttnBiasType.NO_BIAS
+                or get_cudnn_version() < 91801
+            ):
+                pytest.skip(
+                    "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
+                    " dropout"
+                )
         # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
         # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
         if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
@@ -1269,6 +1283,7 @@ def check_dqkv(primitive, reference, pad, idx):
         pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
     ],
 )
+@pytest.mark.skipif(_deterministic, reason="Test non-determinism only")
 class TestFusedAttn:
     """
     Fused attention tester
@@ -1392,3 +1407,182 @@ def test_backward(
             seq_desc_format,
         )
         runner.test_backward()
+
+
+@pytest.mark.parametrize(
+    "attn_mask_type",
+    [
+        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
+        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
+        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
+        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
+        pytest.param(
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "softmax_type",
+    [
+        pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
+    ],
+)
+@pytest.mark.parametrize(
+    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
+    [
+        # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
+        pytest.param(
+            2,
+            1024,
+            2048,
+            12,
+            6,
+            128,
+            64,
+            jnp.bfloat16,
+            QKVLayout.BSHD_BSHD_BSHD,
+            id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE",
+        ),
+        pytest.param(
+            2,
+            1024,
+            2048,
+            12,
+            6,
+            128,
+            64,
+            jnp.bfloat16,
+            QKVLayout.THD_THD_THD,
+            id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "dropout_prob",
+    [
+        pytest.param(0.0, id="DROP_0.0"),
+    ],
+)
+@pytest.mark.parametrize(
+    "swa",
+    [
+        pytest.param(False, id="NO_SWA"),
+    ],
+)
+@pytest.mark.parametrize(
+    "seq_desc_format",
+    [
+        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
+    ],
+)
+@pytest.mark.skipif(not _deterministic, reason="Test determinism only")
+class TestFusedAttnWithDeterminism:
+    """
+    Fused attention tester with determinism
+    """
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "is_training",
+        [
+            pytest.param(True, id="TRAINING"),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "attn_bias_type, bias_shape",
+        [
+            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
+            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
+        ],
+    )
+    def _test_forward(
+        b,
+        s_q,
+        s_kv,
+        h_q,
+        h_kv,
+        d_qk,
+        d_v,
+        attn_bias_type,
+        attn_mask_type,
+        softmax_type,
+        dropout_prob,
+        dtype,
+        is_training,
+        qkv_layout,
+        bias_shape,
+        swa,
+        seq_desc_format,
+    ):
+        """
+        Test forward with parameterized configs
+        This test is not intended to run automatically during CI as it is time-consuming
+        It is kept for development and debugging
+        """
+        TestFusedAttn._test_forward(
+            b,
+            s_q,
+            s_kv,
+            h_q,
+            h_kv,
+            d_qk,
+            d_v,
+            attn_bias_type,
+            attn_mask_type,
+            softmax_type,
+            dropout_prob,
+            dtype,
+            is_training,
+            qkv_layout,
+            bias_shape,
+            swa,
+            seq_desc_format,
+        )
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "attn_bias_type, bias_shape",
+        [
+            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
+            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
+        ],
+    )
+    def test_backward(
+        b,
+        s_q,
+        s_kv,
+        h_q,
+        h_kv,
+        d_qk,
+        d_v,
+        attn_bias_type,
+        attn_mask_type,
+        softmax_type,
+        dropout_prob,
+        dtype,
+        qkv_layout,
+        bias_shape,
+        swa,
+        seq_desc_format,
+    ):
+        """
+        Test backward with parameterized configs
+        """
+        TestFusedAttn.test_backward(
+            b,
+            s_q,
+            s_kv,
+            h_q,
+            h_kv,
+            d_qk,
+            d_v,
+            attn_bias_type,
+            attn_mask_type,
+            softmax_type,
+            dropout_prob,
+            dtype,
+            qkv_layout,
+            bias_shape,
+            swa,
+            seq_desc_format,
+        )
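
The sm100+ skip conditions that this commit adds to `_check_configs` can be restated as a pure predicate, which makes the two branches easier to compare. This is a sketch and not code from the commit: the function name `should_skip_sm100` and its plain-value parameters are hypothetical stand-ins for `get_device_compute_capability(0)`, `self.is_training`, `FusedAttnHelper.is_non_deterministic_allowed()`, the dropout and bias settings, and `get_cudnn_version()`.

```python
def should_skip_sm100(
    compute_capability,
    is_training,
    nondeterministic_allowed,
    dropout_prob,
    has_bias,
    cudnn_version,
):
    """Hypothetical restatement of the skip logic shown in the diff.

    cudnn_version uses the integer encoding seen in the diff
    (90700 for cuDNN 9.7, 91801 for cuDNN 9.18.1).
    """
    # The new checks apply only to training runs on sm100 and newer.
    if compute_capability < 100 or not is_training:
        return False
    if nondeterministic_allowed:
        # Non-deterministic bprop: needs cuDNN 9.7+, and bias combined
        # with dropout is unsupported.
        return (dropout_prob != 0.0 and has_bias) or cudnn_version < 90700
    # Deterministic bprop: needs cuDNN 9.18.1+, and neither bias nor
    # dropout is supported.
    return dropout_prob != 0.0 or has_bias or cudnn_version < 91801


if __name__ == "__main__":
    # Deterministic mode, no bias, no dropout, cuDNN 9.18.1: the test runs.
    print(should_skip_sm100(100, True, False, 0.0, False, 91801))
    # Deterministic mode with a bias: the test is skipped.
    print(should_skip_sm100(100, True, False, 0.0, True, 91801))
```

Read this way, the deterministic path is strictly narrower than the non-deterministic one: any bias or any dropout triggers a skip, whereas the non-deterministic path skips only when both are present.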
