Skip to content

Commit 20b20c7

Browse files
build(deps-dev): bump jax from 0.5.0 to 0.6.2 (#5102)
Bumps [jax](https://github.com/jax-ml/jax) from 0.5.0 to 0.6.2. <details> <summary>Release notes</summary> <p><em>Sourced from <a href="https://github.com/jax-ml/jax/releases">jax's releases</a>.</em></p> <blockquote> <h2>JAX v0.6.2</h2> <ul> <li> <p>New features:</p> <ul> <li>Added <code>jax.tree.broadcast</code> which implements a pytree prefix broadcasting helper.</li> </ul> </li> <li> <p>Changes</p> <ul> <li>The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.</li> </ul> </li> </ul> <h2>JAX v0.6.1</h2> <ul> <li> <p>New features:</p> <ul> <li>Added <code>jax.lax.axis_size</code> which returns the size of the mapped axis given its name.</li> </ul> </li> <li> <p>Changes</p> <ul> <li>Additional checking for the versions of CUDA package dependencies was reenabled, having been accidentally disabled in a previous release.</li> <li>JAX nightly packages are now published to artifact registry. To install these packages, see the <a href="https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation">JAX installation guide</a>.</li> <li><code>jax.sharding.PartitionSpec</code> no longer inherits from a tuple.</li> <li><code>jax.ShapeDtypeStruct</code> is immutable now. Please use <code>.update</code> method to update your <code>ShapeDtypeStruct</code> instead of doing in-place updates.</li> </ul> </li> <li> <p>Deprecations</p> <ul> <li><code>jax.custom_derivatives.custom_jvp_call_jaxpr_p</code> is deprecated, and will be removed in JAX v0.7.0.</li> </ul> </li> </ul> <h2>JAX v0.6.0</h2> <ul> <li> <p>Breaking changes</p> <ul> <li><code>jax.numpy.array</code> no longer accepts <code>None</code>. This behavior was deprecated since November 2023 and is now removed.</li> <li>Removed the <code>config.jax_data_dependent_tracing_fallback</code> config option, which was added temporarily in v0.4.36 to allow users to opt out of the new &quot;stackless&quot; tracing machinery.</li> <li>Removed the <code>config.jax_eager_pmap</code> config option.</li> <li>Disallow the calling of <code>lower</code> and <code>trace</code> AOT APIs on the result of <code>jax.jit</code> if there have been subsequent wrappers applied. Previously this worked, but silently ignored the wrappers. The workaround is to apply <code>jax.jit</code> last among the wrappers, and similarly for <code>jax.pmap</code>. See <code>[#27873](https://github.com/jax-ml/jax/issues/27873)</code>.</li> <li>The <code>cuda12_pip</code> extra for <code>jax</code> has been removed; use <code>pip install jax[cuda12]</code> instead.</li> </ul> </li> <li> <p>Changes</p> <ul> <li>The minimum CuDNN version is v9.8.</li> <li>JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain supported.</li> <li>JAX package extras are now updated to use dash instead of underscore to align with PEP 685. For instance, if you were previously using <code>pip install jax[cuda12_local]</code></li> </ul> </li> </ul> <!-- raw HTML omitted --> </blockquote> <p>... (truncated)</p> </details> <details> <summary>Changelog</summary> <p><em>Sourced from <a href="https://github.com/jax-ml/jax/blob/main/CHANGELOG.md">jax's changelog</a>.</em></p> <blockquote> <h2>JAX 0.6.2 (June 17, 2025)</h2> <ul> <li> <p>New features:</p> <ul> <li>Added {func}<code>jax.tree.broadcast</code> which implements a pytree prefix broadcasting helper.</li> </ul> </li> <li> <p>Changes</p> <ul> <li>The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.</li> </ul> </li> </ul> <h2>JAX 0.6.1 (May 21, 2025)</h2> <ul> <li> <p>New features:</p> <ul> <li>Added {func}<code>jax.lax.axis_size</code> which returns the size of the mapped axis given its name.</li> </ul> </li> <li> <p>Changes</p> <ul> <li>Additional checking for the versions of CUDA package dependencies was re-enabled, having been accidentally disabled in a previous release.</li> <li>JAX nightly packages are now published to artifact registry. To install these packages, see the <a href="https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation">JAX installation guide</a>.</li> <li><code>jax.sharding.PartitionSpec</code> no longer inherits from a tuple.</li> <li><code>jax.ShapeDtypeStruct</code> is immutable now. Please use <code>.update</code> method to update your <code>ShapeDtypeStruct</code> instead of doing in-place updates.</li> </ul> </li> <li> <p>Deprecations</p> <ul> <li><code>jax.custom_derivatives.custom_jvp_call_jaxpr_p</code> is deprecated, and will be removed in JAX v0.7.0.</li> </ul> </li> </ul> <h2>JAX 0.6.0 (April 16, 2025)</h2> <ul> <li> <p>Breaking changes</p> <ul> <li>{func}<code>jax.numpy.array</code> no longer accepts <code>None</code>. This behavior was deprecated since November 2023 and is now removed.</li> <li>Removed the <code>config.jax_data_dependent_tracing_fallback</code> config option, which was added temporarily in v0.4.36 to allow users to opt out of the new &quot;stackless&quot; tracing machinery.</li> <li>Removed the <code>config.jax_eager_pmap</code> config option.</li> <li>Disallow the calling of <code>lower</code> and <code>trace</code> AOT APIs on the result of <code>jax.jit</code> if there have been subsequent wrappers applied. Previously this worked, but silently ignored the wrappers. The workaround is to apply <code>jax.jit</code> last among the wrappers, and similarly for <code>jax.pmap</code>. See {jax-issue}<code>[#27873](https://github.com/jax-ml/jax/issues/27873)</code>.</li> <li>The <code>cuda12_pip</code> extra for <code>jax</code> has been removed; use <code>pip install jax[cuda12]</code> instead.</li> </ul> </li> <li> <p>Changes</p> <ul> <li>The minimum CuDNN version is v9.8.</li> <li>JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain supported.</li> </ul> </li> </ul> <!-- raw HTML omitted --> </blockquote> <p>... (truncated)</p> </details> <details> <summary>Commits</summary> <ul> <li><a href="https://github.com/jax-ml/jax/commit/1ad05bb26105f23ee7728b36cca12901fe70e187"><code>1ad05bb</code></a> Add LLVM patch to fix AVX512 codegeneration problem</li> <li><a href="https://github.com/jax-ml/jax/commit/8f81490ad4ed60cf923e6f0ef7a3bc1d708c7636"><code>8f81490</code></a> Prepare for JAX release 0.6.2</li> <li><a href="https://github.com/jax-ml/jax/commit/e4de90e6d0eb630db1f6b530b5c55d9fa6123317"><code>e4de90e</code></a> Update XLA dependency to use revision</li> <li><a href="https://github.com/jax-ml/jax/commit/02688e18fc48e1cb3f5572b2257b540dbbce28dd"><code>02688e1</code></a> [Pallas][Mosaic GPU] Enable collective MMA from TMEM.</li> <li><a href="https://github.com/jax-ml/jax/commit/dc9ef6145bba53afbabdc7a5748c4afa1cd16025"><code>dc9ef61</code></a> Merge pull request <a href="https://redirect.github.com/jax-ml/jax/issues/29410">#29410</a> from DanisNone:nn-type</li> <li><a href="https://github.com/jax-ml/jax/commit/353e7fac82a76f558ecf663cf5e32684e9ca175f"><code>353e7fa</code></a> Merge pull request <a href="https://redirect.github.com/jax-ml/jax/issues/29516">#29516</a> from jakevdp:enable-x64-warning</li> <li><a href="https://github.com/jax-ml/jax/commit/3d37b0d727f9353c04670fb5f4a8799e79ce5bf4"><code>3d37b0d</code></a> Merge pull request <a href="https://redirect.github.com/jax-ml/jax/issues/29504">#29504</a> from vfdev-5:tsan-ft-removed-fixed-suppression</li> <li><a href="https://github.com/jax-ml/jax/commit/c2cc9f9cc9fd559f85195bace3c511f435070709"><code>c2cc9f9</code></a> [pallas] <code>AbstractMemoryRef</code> now implements all functional update methods via...</li> <li><a href="https://github.com/jax-ml/jax/commit/0fd082136ea62123f0996b3dd5d286e3d0b02680"><code>0fd0821</code></a> [Pallas TPU] Add flag to enable using registers to keep track of slot info</li> <li><a href="https://github.com/jax-ml/jax/commit/f22896ac23928b05488bf4524818fb241fc3817b"><code>f22896a</code></a> jax.experimental.enable_x64: add warning to docstring</li> <li>Additional commits viewable in <a href="https://github.com/jax-ml/jax/compare/jax-v0.5.0...jax-v0.6.2">compare view</a></li> </ul> </details> <br /> [![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=jax&package-manager=pip&previous-version=0.5.0&new-version=0.6.2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) --- <details> <summary>Dependabot commands and options</summary> <br /> You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show <dependency name> ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) </details> --------- Signed-off-by: dependabot[bot] <support@github.com> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 7e50200 commit 20b20c7

2 files changed

Lines changed: 1 addition & 7 deletions

File tree

.github/workflows/test_cuda.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,6 @@ jobs:
3636
with:
3737
useLocalCache: true
3838
useCloudCache: false
39-
- run: |
40-
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb \
41-
&& sudo dpkg -i cuda-keyring_1.0-1_all.deb \
42-
&& sudo apt-get update \
43-
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
44-
if: false # skip as we use nvidia image
4539
- run: python -m pip install -U uv
4640
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax "jax[cuda12]"
4741
- run: |

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ pin_pytorch_gpu = [
169169
"torch>=2.7,<2.10",
170170
]
171171
pin_jax = [
172-
"jax==0.5.0;python_version>='3.10'",
172+
"jax==0.6.2;python_version>='3.10'",
173173
]
174174

175175
[tool.setuptools_scm]

0 commit comments

Comments
 (0)