|
2 | 2 | # |
3 | 3 | # See LICENSE for license information. |
4 | 4 |
|
5 | | -set -e |
#######################################
# Report a fatal error and abort the script.
# Arguments: $1 - human-readable description of what failed
# Outputs:   writes "Error: <message>" to stderr
# Exits:     always, with status 1
#######################################
error_exit() {
    # Diagnostics go to stderr so they are not mixed into captured stdout.
    printf 'Error: %s\n' "$1" >&2
    exit 1
}
| 9 | + |
#######################################
# Record a failed sub-test. Does not exit, so the caller can keep running
# the remaining sub-tests and report all failures together.
# Globals:   RET (set to 1), FAILED_CASES (name appended after a space)
# Arguments: $1 - name of the sub-test that failed
# Outputs:   writes "Error: sub-test failed: <name>" to stderr
#######################################
test_fail() {
    RET=1
    FAILED_CASES="$FAILED_CASES $1"
    printf 'Error: sub-test failed: %s\n' "$1" >&2
}

# Failure bookkeeping updated by test_fail:
#   RET          - 0 until a sub-test fails, then 1.
#   FAILED_CASES - names of failed sub-tests, each preceded by a space.
RET=0
FAILED_CASES=""
6 | 18 |
|
# Environment-overridable locations:
#   TE_PATH     - directory whose tests/ tree is passed to pytest.
#   XML_LOG_DIR - directory that receives the JUnit XML reports.
# The expansions are quoted so the resulting value is not word-split or
# glob-expanded when the no-op ':' command evaluates its argument.
: "${TE_PATH:=/opt/transformerengine}"
: "${XML_LOG_DIR:=/logs}"
# The script does not run under `set -e`, so failures are checked explicitly.
mkdir -p "$XML_LOG_DIR" || error_exit "Failed to create $XML_LOG_DIR"

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

# Limit parallel build jobs to avoid overwhelming system resources
export MAX_JOBS=32
|
41 | 53 | fi |
42 | 54 |
|
43 | 55 | # Run tests |
44 | | - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py |
| 56 | + NUM_GPUS=$(nvidia-smi -L | wc -l) |
| 57 | + echo "Detected $NUM_GPUS GPU(s)" |
| 58 | + if [ "$NUM_GPUS" -ge 5 ]; then |
| 59 | + CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 )) |
| 60 | + CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS) |
| 61 | + echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)" |
| 62 | + |
| 63 | + CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ |
| 64 | + --junitxml=$XML_LOG_DIR/pytest.xml \ |
| 65 | + $TE_PATH/tests/pytorch/attention/test_attention.py & |
| 66 | + PID_ATTN=$! |
| 67 | + |
| 68 | + CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ |
| 69 | + --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml \ |
| 70 | + $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & |
| 71 | + PID_CP=$! |
45 | 72 |
|
| 73 | + wait $PID_ATTN || test_fail "test_attention.py" |
| 74 | + wait $PID_CP || test_fail "test_attention_with_cp.py" |
| 75 | + else |
| 76 | + echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)" |
| 77 | + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" |
| 78 | + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" |
| 79 | + fi |
46 | 80 | done |
| 81 | + |
# Final verdict: exit non-zero if any sub-test was recorded as failed.
# RET defaults to 0 here so an unset value is treated as "no failures"
# instead of making the numeric comparison itself error out.
if [ "${RET:-0}" -ne 0 ]; then
    # FAILED_CASES already starts with a space, hence none after the colon.
    echo "Error in the following test cases:$FAILED_CASES" >&2
    exit 1
fi
echo "All tests passed"
exit 0
0 commit comments