@@ -153,15 +153,13 @@ dask-core = ">=2025.12.0" # No distributed, tornado, etc.
153153sparse = " >=0.17.0"
154154
155155[tool .pixi .feature .backends .target .linux-64 .dependencies ]
156- # On CPU use >=0.7.0
157- # On GPU, use 0.6.0 (0.6.2 and 0.7.0 both segfault); see jaxlib pin below.
158- jax = " >=0.6.0"
156+ jax = " >=0.7.2"
159157
160158[tool .pixi .feature .backends .target .osx-64 .dependencies ]
161- jax = " >=0.6.0 "
159+ jax = " >=0.7.2 "
162160
163161[tool .pixi .feature .backends .target .osx-arm64 .dependencies ]
164- jax = " >=0.6.0 "
162+ jax = " >=0.7.2 "
165163
166164[tool .pixi .feature .backends .target .win-64 .dependencies ]
167165# jax = "*" # unavailable
@@ -175,23 +173,17 @@ jax = ">=0.6.0"
175173[tool .pixi .feature .cuda-backends ]
176174system-requirements = { cuda = " 12" }
177175
178- [tool .pixi .feature .cuda-backends .target .linux-64 .dependencies ]
176+ [tool .pixi .feature .cuda-backends .target .linux .dependencies ]
179177cupy = " >=13.6.0"
180- # JAX 0.6.2 and 0.7.0 segfault on CUDA
181- jaxlib = { version = " >=0.6.0,!=0.6.2,!=0.7.0" , build = " cuda12*" }
178+ jaxlib = { version = " >=0.7.2" , build = " cuda12*" }
182179pytorch = { version = " >=2.9.1" , build = " cuda12*" }
183180
184- [tool .pixi .feature .cuda-backends .target .osx-64 .dependencies ]
181+ [tool .pixi .feature .cuda-backends .target .osx .dependencies ]
185182# cupy = "*" # unavailable
186183# jaxlib = { version = "*", build = "cuda12*" } # unavailable
187184# pytorch = { version = "*", build = "cuda12*" } # unavailable
188185
189- [tool .pixi .feature .cuda-backends .target .osx-arm64 .dependencies ]
190- # cupy = "*" # unavailable
191- # jaxlib = { version = "*", build = "cuda12*" } # unavailable
192- # pytorch = { version = "*", build = "cuda12*" } # unavailable
193-
194- [tool .pixi .feature .cuda-backends .target .win-64 .dependencies ]
186+ [tool .pixi .feature .cuda-backends .target .win .dependencies ]
195187cupy = " >=13.6.0"
196188# jaxlib = { version = "*", build = "cuda12*" } # unavailable
197189pytorch = { version = " >=2.9.1" , build = " cuda12*" }
0 commit comments