Skip to content

Commit 25d2252

Browse files
authored
MAINT: bump JAX to 0.7.2 (#578)
1 parent 7333098 commit 25d2252

3 files changed

Lines changed: 11 additions & 18 deletions

File tree

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,13 @@ dask-core = ">=2025.12.0" # No distributed, tornado, etc.
153153
sparse = ">=0.17.0"
154154

155155
[tool.pixi.feature.backends.target.linux-64.dependencies]
156-
# On CPU use >=0.7.0
157-
# On GPU, use 0.6.0 (0.6.2 and 0.7.0 both segfault); see jaxlib pin below.
158-
jax = ">=0.6.0"
156+
jax = ">=0.7.2"
159157

160158
[tool.pixi.feature.backends.target.osx-64.dependencies]
161-
jax = ">=0.6.0"
159+
jax = ">=0.7.2"
162160

163161
[tool.pixi.feature.backends.target.osx-arm64.dependencies]
164-
jax = ">=0.6.0"
162+
jax = ">=0.7.2"
165163

166164
[tool.pixi.feature.backends.target.win-64.dependencies]
167165
# jax = "*" # unavailable
@@ -175,23 +173,17 @@ jax = ">=0.6.0"
175173
[tool.pixi.feature.cuda-backends]
176174
system-requirements = { cuda = "12" }
177175

178-
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
176+
[tool.pixi.feature.cuda-backends.target.linux.dependencies]
179177
cupy = ">=13.6.0"
180-
# JAX 0.6.2 and 0.7.0 segfault on CUDA
181-
jaxlib = { version = ">=0.6.0,!=0.6.2,!=0.7.0", build = "cuda12*" }
178+
jaxlib = { version = ">=0.7.2", build = "cuda12*" }
182179
pytorch = { version = ">=2.9.1", build = "cuda12*" }
183180

184-
[tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
181+
[tool.pixi.feature.cuda-backends.target.osx.dependencies]
185182
# cupy = "*" # unavailable
186183
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
187184
# pytorch = { version = "*", build = "cuda12*" } # unavailable
188185

189-
[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
190-
# cupy = "*" # unavailable
191-
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
192-
# pytorch = { version = "*", build = "cuda12*" } # unavailable
193-
194-
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
186+
[tool.pixi.feature.cuda-backends.target.win.dependencies]
195187
cupy = ">=13.6.0"
196188
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
197189
pytorch = { version = ">=2.9.1", build = "cuda12*" }

tests/test_funcs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,7 @@ def test_complex(self, xp: ModuleType):
521521
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
522522
xp_assert_close(actual, expect)
523523

524+
@pytest.mark.xfail_xp_backend(Backend.JAX_GPU, reason="jax#32296")
524525
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="jax#32296")
525526
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#877")
526527
def test_empty(self, xp: ModuleType):
@@ -989,14 +990,14 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
989990
assert get_device(res) == device
990991

991992
def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device):
992-
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device)
993+
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device, dtype=xp.float64)
993994
b = 1
994995
res = isclose(a, b)
995996
assert get_device(res) == device
996997
xp_assert_equal(res, xp.asarray([False, False, False, False, True]))
997998

998999
a = 0.1
999-
b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device)
1000+
b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device, dtype=xp.float64)
10001001
res = isclose(a, b)
10011002
assert get_device(res) == device
10021003
xp_assert_equal(res, xp.asarray([False, False, False, False, True]))

0 commit comments

Comments
 (0)