[JAX] Collective Gemm test fixes #3115

Open: jberchtold-nvidia wants to merge 2 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/fix-cgemm-test-setup

@jberchtold-nvidia

Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 10, 2026

Contributor

Greptile Summary

This PR fixes two categories of issues in the JAX Collective GEMM test suite: in the shell harness, it adds robust fallback paths for TE_PATH and XML_LOG_DIR, introduces per-test NCCL ID file isolation with proper cleanup, and improves signal handling; in the Python test files, it removes the args.process_id == 0 guard so result assertions run on every worker process rather than only the rank-0 process.

  • Shell harness (run_test_cgemm.sh): adds SCRIPT_DIR/REPO_ROOT detection for flexible path resolution, falls back to a mktemp temp dir when XML_LOG_DIR is not writable, creates a per-test NCCL ID subdirectory under a shared CGEMM_NCCL_FILE_BASE_DIR, cleans it up on exit, and splits the trap so INT/TERM correctly returns exit code 130.
  • Python tests (test_gemm.py, test_dense_grad.py, test_layernorm_mlp_grad.py): drops the args.process_id == 0 condition from the enable_result_check block so that all worker processes run assert_allclose, providing broader validation coverage since each process holds a full replica of the gathered tensors via PartitionSpec(None).
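The path fallbacks described for the shell harness can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: `pick_log_dir` and the `te_cgemm_logs` temp-dir prefix are hypothetical names, and the assumption that the repo root sits three levels above the script is inferred only from the file path `examples/jax/collective_gemm/run_test_cgemm.sh`.

```shell
# Sketch only: the real run_test_cgemm.sh may resolve these differently.

# Locate the script directory, then assume the repo root is three levels up
# (examples/jax/collective_gemm -> repo root).
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)"

# Keep a caller-provided TE_PATH; otherwise fall back to the detected root.
: "${TE_PATH:=$REPO_ROOT}"

# Hypothetical helper: print the requested log directory if it can be created
# and written to, otherwise print a freshly created temp directory.
pick_log_dir() {
    requested="$1"
    if [ -n "$requested" ] && mkdir -p "$requested" 2>/dev/null && [ -w "$requested" ]; then
        printf '%s\n' "$requested"
    else
        mktemp -d "${TMPDIR:-/tmp}/te_cgemm_logs.XXXXXX"
    fi
}

XML_LOG_DIR="$(pick_log_dir "${XML_LOG_DIR:-}")"
echo "TE_PATH=$TE_PATH"
echo "XML_LOG_DIR=$XML_LOG_DIR"
```

Printing the chosen directory from a helper keeps the fallback decision in one place, so the rest of the script only ever sees a directory that exists.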

Confidence Score: 5/5

All changes are targeted test fixes with no impact on production library code; the logic is correct and the test harness improvements make it more robust.

The Python changes correctly drop the process_id == 0 guard because gathered tensors are fully replicated via PartitionSpec(None) before assertions run, so every rank holds identical data. The shell script changes are purely defensive: fallback directories, per-test NCCL ID isolation, and proper signal handling. No production code paths are touched.

No files require special attention; all changes are confined to the collective GEMM test examples directory.

Important Files Changed

| Filename | Overview |
| --- | --- |
| examples/jax/collective_gemm/run_test_cgemm.sh | Adds robust TE_PATH/XML_LOG_DIR fallbacks, per-test NCCL ID directory isolation with cleanup, and corrected INT/TERM signal handling with exit code 130. |
| examples/jax/collective_gemm/test_gemm.py | Removes args.process_id == 0 guard so all processes run result assertions; correct since gathered tensors are fully replicated via PartitionSpec(None). |
| examples/jax/collective_gemm/test_dense_grad.py | Same process_id guard removal as test_gemm.py; assertions now run on all worker ranks. |
| examples/jax/collective_gemm/test_layernorm_mlp_grad.py | Same process_id guard removal; all ranks now validate layernorm MLP gradient correctness. |

Sequence Diagram

sequenceDiagram
    participant SH as run_test_cgemm.sh
    participant P0 as Process 0 (rank 0)
    participant PN as Process N (rank N)

    SH->>SH: mktemp CGEMM_NCCL_FILE_BASE_DIR
    loop For each TEST_CASE
        SH->>SH: mkdir CGEMM_NCCL_FILE_BASE_DIR/TEST_NAME
        SH->>SH: export NVTE_JAX_NCCL_FILE_PATH
        SH->>P0: "pytest --process-id=0"
        SH->>PN: "pytest --process-id=N"
        P0->>P0: compute + gather via PartitionSpec(None)
        PN->>PN: compute + gather via PartitionSpec(None)
        Note over P0,PN: All ranks now run assert_allclose
        P0->>P0: assert_allclose(ref, output)
        PN->>PN: assert_allclose(ref, output)
        SH->>SH: wait and check log for result
    end
    SH->>SH: cleanup and rm CGEMM_NCCL_FILE_BASE_DIR
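
The shell side of the diagram above (shared base directory, one NCCL ID subdirectory per test, cleanup on exit, exit code 130 on INT/TERM) might look something like the sketch below. The test names in `TEST_CASES` are placeholders taken from the changed file names, and the `echo` stands in for launching the per-rank pytest processes; the real harness is not reproduced here.

```shell
# Sketch only: placeholder test list, no processes are actually launched.
TEST_CASES="test_gemm test_dense_grad test_layernorm_mlp_grad"

# Shared base directory for all NCCL ID files of this run.
CGEMM_NCCL_FILE_BASE_DIR=$(mktemp -d "${TMPDIR:-/tmp}/te_cgemm_nccl_ids.XXXXXX") || {
    echo "ERROR: failed to create temp dir for NCCL ID files" >&2
    exit 1
}

cleanup() {
    rm -rf "$CGEMM_NCCL_FILE_BASE_DIR"
}

# Split traps: EXIT always cleans up; INT/TERM only set the exit code,
# and the EXIT trap still fires afterwards.
trap cleanup EXIT
trap 'exit 130' INT TERM

for TEST_NAME in $TEST_CASES; do
    # One subdirectory per test so ID files from one case cannot be
    # picked up by the next.
    NVTE_JAX_NCCL_FILE_PATH="$CGEMM_NCCL_FILE_BASE_DIR/$TEST_NAME"
    mkdir -p "$NVTE_JAX_NCCL_FILE_PATH"
    export NVTE_JAX_NCCL_FILE_PATH
    echo "would launch $TEST_NAME with NVTE_JAX_NCCL_FILE_PATH=$NVTE_JAX_NCCL_FILE_PATH"
done
```

Keeping the cleanup in the EXIT trap alone avoids running it twice when a signal arrives, since `exit 130` inside the INT/TERM trap triggers the EXIT trap on its way out.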

Reviews (2): Last reviewed commit: "Fix other cgemm tests"

Comment on lines +21 to +22
CGEMM_NCCL_FILE_DIR=$(mktemp -d /tmp/te_cgemm_nccl_ids.XXXXXX)
export NVTE_JAX_NCCL_FILE_PATH="${NVTE_JAX_NCCL_FILE_PATH:-$CGEMM_NCCL_FILE_DIR}"


P2 mktemp -d is not guarded against failure. If /tmp is full or unavailable, CGEMM_NCCL_FILE_DIR would be empty/unset and NVTE_JAX_NCCL_FILE_PATH would export a garbage value, causing every subsequent test case to fail with an unintelligible NCCL ID file error.

Suggested change
- CGEMM_NCCL_FILE_DIR=$(mktemp -d /tmp/te_cgemm_nccl_ids.XXXXXX)
- export NVTE_JAX_NCCL_FILE_PATH="${NVTE_JAX_NCCL_FILE_PATH:-$CGEMM_NCCL_FILE_DIR}"
+ CGEMM_NCCL_FILE_DIR=$(mktemp -d /tmp/te_cgemm_nccl_ids.XXXXXX) || { echo "ERROR: Failed to create temp dir for NCCL ID files"; exit 1; }
+ export NVTE_JAX_NCCL_FILE_PATH="${NVTE_JAX_NCCL_FILE_PATH:-$CGEMM_NCCL_FILE_DIR}"
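
The failure mode raised in this comment can be reproduced without filling /tmp. In the sketch below, `failing_mktemp` and `SOME_UNSET_OVERRIDE` are hypothetical stand-ins (for a failing `mktemp -d` call and for an unset `NVTE_JAX_NCCL_FILE_PATH`); it only illustrates that an unguarded command substitution leaves an empty value behind while the script carries on.

```shell
# Hypothetical stand-in for a mktemp -d call that fails (e.g. /tmp is full).
failing_mktemp() {
    return 1
}

# Unguarded pattern: the assignment takes the command's exit status, but
# nothing acts on it. The status is recorded with `||` here so the sketch
# also runs under `set -e`.
UNGUARDED_STATUS=0
UNGUARDED_DIR=$(failing_mktemp) || UNGUARDED_STATUS=$?
UNGUARDED_PATH="${SOME_UNSET_OVERRIDE:-$UNGUARDED_DIR}"
echo "unguarded: status=$UNGUARDED_STATUS path='$UNGUARDED_PATH'"

# Guarded pattern: the failure is handled at the point where it happens.
GUARDED_FAILED=0
if GUARDED_DIR=$(failing_mktemp); then
    echo "guarded: created $GUARDED_DIR"
else
    echo "guarded: mktemp failed, a real script would exit 1 here" >&2
    GUARDED_FAILED=1
fi
```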

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>