Fix mesh creation to use local devices for single-host benchmarks#125
Open
simrankaurb wants to merge 1 commit into
Open
Fix mesh creation to use local devices for single-host benchmarks#125simrankaurb wants to merge 1 commit into
simrankaurb wants to merge 1 commit into
Conversation
Collaborator
|
Thanks @simrankaurb , Have we verified the functionality and metrics correctness on this change? Also, should we also cover the HBM and H2D/D2H? |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Enable independent single-host execution for GEMM and Collectives benchmarks
This PR enables single-host TPU benchmarks (GEMM and single-node collectives) to execute independently on individual hosts within a multi-host slice, without requiring coordination from other hosts.
Previously, mesh creation for these benchmarks defaulted to global devices (
jax.devices()). When running diagnostics on a single host while other hosts in the slice were idle, JAX would attempt to coordinate execution across the entire global mesh, resulting in execution blocking.To allow independent single-host execution:
benchmark_utils.pyto scope the device mesh tolocal_devices()and uselocal_device_count()for single-host sharding strategies.benchmark_collectives.pyto check the requestedici_size. If the required devices fit within a single node (e.g.,ici_size <= 8), the mesh is strictly scoped tolocal_devices(), only falling back to globaljax.devices()for multi-host workloads (e.g.,ici_size: 16).This allows diagnostic workloads to run independently on active nodes, while preserving standard global mesh behavior for multi-host benchmarks.