Skip to content

feat: Add GPU support to experimental JAX inference framework#272

Closed
tohaowu wants to merge 1 commit into
AI-Hypercomputer:mainfrom
tohaowu:feat/gpu-support-jax-experimental
Closed

feat: Add GPU support to experimental JAX inference framework#272
tohaowu wants to merge 1 commit into
AI-Hypercomputer:mainfrom
tohaowu:feat/gpu-support-jax-experimental

Conversation

@tohaowu

@tohaowu tohaowu commented May 21, 2025

Copy link
Copy Markdown

This commit introduces GPU support for the JAX-based experimental inference framework located in experimental/jax.

Key changes include:

  • Modified experimental/jax/requirements.txt to use jax[cuda-pip] allowing JAX to utilize NVIDIA GPUs.
  • Refined experimental/jax/inference/parallel/mesh.py to correctly handle GPU devices during mesh creation, ensuring robust platform detection alongside existing TPU support.
  • Verified that experimental/jax/inference/runtime/offline_inference.py correctly uses jax.devices() and is compatible with the new GPU handling in the mesh creation logic.
  • Updated experimental/jax/README.md to include instructions for setting up JAX with GPU support and to reflect that NVIDIA GPUs are now a supported backend.
  • Added a new test script experimental/jax/inference/entrypoint/run_gpu_test.py and instructions for you to verify GPU functionality with a small number of prompts.

These changes allow you, if you have compatible NVIDIA GPUs and CUDA setups, to run the experimental JAX inference framework, expanding its usability beyond TPUs.

This commit introduces GPU support for the JAX-based experimental inference framework located in `experimental/jax`.

Key changes include:

- Modified `experimental/jax/requirements.txt` to use `jax[cuda-pip]` allowing JAX to utilize NVIDIA GPUs.
- Refined `experimental/jax/inference/parallel/mesh.py` to correctly handle GPU devices during mesh creation, ensuring robust platform detection alongside existing TPU support.
- Verified that `experimental/jax/inference/runtime/offline_inference.py` correctly uses `jax.devices()` and is compatible with the new GPU handling in the mesh creation logic.
- Updated `experimental/jax/README.md` to include instructions for setting up JAX with GPU support and to reflect that NVIDIA GPUs are now a supported backend.
- Added a new test script `experimental/jax/inference/entrypoint/run_gpu_test.py` and instructions for you to verify GPU functionality with a small number of prompts.

These changes allow you, if you have compatible NVIDIA GPUs and CUDA setups, to run the experimental JAX inference framework, expanding its usability beyond TPUs.
@tohaowu tohaowu requested a review from vipannalla as a code owner May 21, 2025 21:12
@tohaowu tohaowu closed this May 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant