[JAX] Collective Gemm test fixes#3115
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile SummaryThis PR fixes two categories of issues in the JAX Collective GEMM test suite: in the shell harness, it adds robust fallback paths for
Confidence Score: 5/5All changes are targeted test fixes with no impact on production library code; the logic is correct and the test harness improvements make it more robust. The Python changes correctly drop the process_id == 0 guard because gathered tensors are fully replicated via PartitionSpec(None) before assertions run, so every rank holds identical data. The shell script changes are purely defensive: fallback directories, per-test NCCL ID isolation, and proper signal handling. No production code paths are touched. No files require special attention; all changes are confined to the collective GEMM test examples directory. Important Files Changed
Sequence DiagramsequenceDiagram
participant SH as run_test_cgemm.sh
participant P0 as Process 0 (rank 0)
participant PN as Process N (rank N)
SH->>SH: mktemp CGEMM_NCCL_FILE_BASE_DIR
loop For each TEST_CASE
SH->>SH: mkdir CGEMM_NCCL_FILE_BASE_DIR/TEST_NAME
SH->>SH: export NVTE_JAX_NCCL_FILE_PATH
SH->>P0: "pytest --process-id=0"
SH->>PN: "pytest --process-id=N"
P0->>P0: compute + gather via PartitionSpec(None)
PN->>PN: compute + gather via PartitionSpec(None)
Note over P0,PN: All ranks now run assert_allclose
P0->>P0: assert_allclose(ref, output)
PN->>PN: assert_allclose(ref, output)
SH->>SH: wait and check log for result
end
SH->>SH: cleanup and rm CGEMM_NCCL_FILE_BASE_DIR
Reviews (2): Last reviewed commit: "Fix other cgemm tests" | Re-trigger Greptile |
| CGEMM_NCCL_FILE_DIR=$(mktemp -d /tmp/te_cgemm_nccl_ids.XXXXXX) | ||
| export NVTE_JAX_NCCL_FILE_PATH="${NVTE_JAX_NCCL_FILE_PATH:-$CGEMM_NCCL_FILE_DIR}" |
There was a problem hiding this comment.
mktemp -d is not guarded against failure. If /tmp is full or unavailable, CGEMM_NCCL_FILE_DIR would be empty/unset and NVTE_JAX_NCCL_FILE_PATH would export a garbage value, causing every subsequent test case to fail with an unintelligible NCCL ID file error.
| CGEMM_NCCL_FILE_DIR=$(mktemp -d /tmp/te_cgemm_nccl_ids.XXXXXX) | |
| export NVTE_JAX_NCCL_FILE_PATH="${NVTE_JAX_NCCL_FILE_PATH:-$CGEMM_NCCL_FILE_DIR}" | |
| CGEMM_NCCL_FILE_DIR=$(mktemp -d /tmp/te_cgemm_nccl_ids.XXXXXX) || { echo "ERROR: Failed to create temp dir for NCCL ID files"; exit 1; } | |
| export NVTE_JAX_NCCL_FILE_PATH="${NVTE_JAX_NCCL_FILE_PATH:-$CGEMM_NCCL_FILE_DIR}" |
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: