Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions examples/jax/collective_gemm/run_test_cgemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,24 @@

NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}

SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
REPO_ROOT=$(cd "$SCRIPT_DIR/../../.." && pwd)

: ${TE_PATH:=/opt/transformerengine}
if [ ! -d "$TE_PATH" ] && [ -f "$REPO_ROOT/tests/jax/pytest.ini" ]; then
TE_PATH="$REPO_ROOT"
fi

: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
if ! mkdir -p "$XML_LOG_DIR" 2>/dev/null; then
XML_LOG_DIR=$(mktemp -d /tmp/te_cgemm_xml.XXXXXX)
echo "XML_LOG_DIR is not writable; using $XML_LOG_DIR"
fi

CGEMM_NCCL_FILE_BASE_DIR=
if [ -z "${NVTE_JAX_NCCL_FILE_PATH:-}" ]; then
CGEMM_NCCL_FILE_BASE_DIR=$(mktemp -d /tmp/te_cgemm_nccl_ids.XXXXXX)
fi

# Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support***"
Expand Down Expand Up @@ -77,10 +92,14 @@ cleanup() {
kill -KILL "$pid" 2>/dev/null || true
fi
done
if [ -n "$CGEMM_NCCL_FILE_BASE_DIR" ] && [ -d "$CGEMM_NCCL_FILE_BASE_DIR" ]; then
rm -rf "$CGEMM_NCCL_FILE_BASE_DIR"
fi
}

# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM
trap cleanup EXIT
trap 'cleanup; exit 130' INT TERM

# Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do
Expand All @@ -89,6 +108,10 @@ for TEST_CASE in "${TEST_CASES[@]}"; do

# Extract just the test method name for log/xml file naming
TEST_NAME=$(echo "$TEST_CASE" | awk -F'::' '{print $NF}')
if [ -n "$CGEMM_NCCL_FILE_BASE_DIR" ]; then
export NVTE_JAX_NCCL_FILE_PATH="$CGEMM_NCCL_FILE_BASE_DIR/$TEST_NAME"
mkdir -p "$NVTE_JAX_NCCL_FILE_PATH"
fi

# Clear PIDs array for this test case
PIDS=()
Expand Down
2 changes: 1 addition & 1 deletion examples/jax/collective_gemm/test_dense_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def run_dense_grad_tests(args, mesh=None):
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)

if args.enable_result_check and args.process_id == 0:
if args.enable_result_check:
tol_dtype = get_tolerance_dtype(quantizer_set)
assert_allclose(ref_output, output, dtype=tol_dtype)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
Expand Down
2 changes: 1 addition & 1 deletion examples/jax/collective_gemm/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def run_gemm_tests(args, mesh=None):
jax.block_until_ready(gathered_ref_output)
jax.block_until_ready(gathered_output)

if args.enable_result_check and args.process_id == 0:
if args.enable_result_check:
# CGEMM + RS + BF16 uses TE's reduce_bf16 kernel (sequential left-to-right in FP32).
# With catastrophic cancellation the output is near zero while the absolute diff can
# reach 1 ULP of the partial GEMM magnitude (~0.0625 for typical transformer
Expand Down
2 changes: 1 addition & 1 deletion examples/jax/collective_gemm/test_layernorm_mlp_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None):
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)

if args.enable_result_check and args.process_id == 0:
if args.enable_result_check:
tol_dtype = get_tolerance_dtype(quantizer_sets[0])
assert_allclose(ref_output, output, dtype=tol_dtype)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
Expand Down