Skip to content

Commit cf4d7d9

Browse files
committed
MAINT: bump JAX to 0.7.2
1 parent 7333098 commit cf4d7d9

2 files changed

Lines changed: 8 additions & 16 deletions

File tree

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,13 @@ dask-core = ">=2025.12.0" # No distributed, tornado, etc.
153153
sparse = ">=0.17.0"
154154

155155
[tool.pixi.feature.backends.target.linux-64.dependencies]
156-
# On CPU use >=0.7.0
157-
# On GPU, use 0.6.0 (0.6.2 and 0.7.0 both segfault); see jaxlib pin below.
158-
jax = ">=0.6.0"
156+
jax = ">=0.7.2"
159157

160158
[tool.pixi.feature.backends.target.osx-64.dependencies]
161-
jax = ">=0.6.0"
159+
jax = ">=0.7.2"
162160

163161
[tool.pixi.feature.backends.target.osx-arm64.dependencies]
164-
jax = ">=0.6.0"
162+
jax = ">=0.7.2"
165163

166164
[tool.pixi.feature.backends.target.win-64.dependencies]
167165
# jax = "*" # unavailable
@@ -175,23 +173,17 @@ jax = ">=0.6.0"
175173
[tool.pixi.feature.cuda-backends]
176174
system-requirements = { cuda = "12" }
177175

178-
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
176+
[tool.pixi.feature.cuda-backends.target.linux.dependencies]
179177
cupy = ">=13.6.0"
180-
# JAX 0.6.2 and 0.7.0 segfault on CUDA
181-
jaxlib = { version = ">=0.6.0,!=0.6.2,!=0.7.0", build = "cuda12*" }
178+
jaxlib = { version = ">=0.7.2", build = "cuda12*" }
182179
pytorch = { version = ">=2.9.1", build = "cuda12*" }
183180

184-
[tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
181+
[tool.pixi.feature.cuda-backends.target.osx.dependencies]
185182
# cupy = "*" # unavailable
186183
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
187184
# pytorch = { version = "*", build = "cuda12*" } # unavailable
188185

189-
[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
190-
# cupy = "*" # unavailable
191-
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
192-
# pytorch = { version = "*", build = "cuda12*" } # unavailable
193-
194-
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
186+
[tool.pixi.feature.cuda-backends.target.win.dependencies]
195187
cupy = ">=13.6.0"
196188
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
197189
pytorch = { version = ">=2.9.1", build = "cuda12*" }

0 commit comments

Comments
 (0)