Describe the bug
When trying to run colabfold_batch on a machine with the new NVIDIA RTX 5090 (Blackwell architecture, compute capability 12.0), the process crashes immediately at Query 1. It ends in a segmentation fault right after an MLIR optimization-pass log message. It appears that the pre-compiled jaxlib in the current pixi environment does not yet support Blackwell's compute capability.
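To check which compute capability the driver reports for each GPU, a small helper can parse the output of `nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader`. This is a minimal sketch and does not come from ColabFold. It assumes a driver recent enough to support the `compute_cap` query field, and it treats compute capability 10.0 and above as Blackwell or newer (data-center Blackwell reports 10.x, and GeForce RTX 50 cards such as the RTX 5090 report 12.0). The names `parse_compute_caps`, `is_blackwell_or_newer` and `query_gpus` are hypothetical.

```python
import shutil
import subprocess
from typing import List, Tuple

# Assumption: Blackwell GPUs report compute capability 10.x (data center)
# or 12.x (GeForce RTX 50 series; the RTX 5090 reports 12.0).
BLACKWELL_MIN_CC = (10, 0)

GpuInfo = Tuple[str, Tuple[int, int]]


def parse_compute_caps(csv_text: str) -> List[GpuInfo]:
    """Parse `nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader` output."""
    gpus = []
    for line in csv_text.strip().splitlines():
        if not line.strip():
            continue
        # The compute capability is the last comma-separated field, e.g. "12.0".
        name, cap = (field.strip() for field in line.rsplit(",", 1))
        major, minor = (int(part) for part in cap.split("."))
        gpus.append((name, (major, minor)))
    return gpus


def is_blackwell_or_newer(cc: Tuple[int, int]) -> bool:
    """Tuple comparison: (12, 0) >= (10, 0) is True, (8, 9) >= (10, 0) is False."""
    return cc >= BLACKWELL_MIN_CC


def query_gpus() -> List[GpuInfo]:
    """Run nvidia-smi if it exists; return [] if it is missing or the query fails."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return []
    return parse_compute_caps(result.stdout)


if __name__ == "__main__":
    for name, cc in query_gpus():
        tag = "Blackwell or newer" if is_blackwell_or_newer(cc) else "pre-Blackwell"
        print(f"{name}: compute capability {cc[0]}.{cc[1]} ({tag})")
```

A matching GPU only shows that the hardware is new. Whether the installed jaxlib can target it depends on which CUDA version jaxlib was built against.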
Expected behavior
The model should compile the XLA graph and start predicting structures on the GPU, as it does on Ampere/Ada architectures.
To Reproduce
- Install localcolabfold using the standard pixi instructions.
- Run a standard batch prediction:
  pixi run colabfold_batch msa_batch results --num-models 1 --num-recycle 3
- Observe the crash.
Logs / Error Messages
2026-06-28 18:44:43,611 Running colabfold 1.6.1
2026-06-28 18:44:49,812 Running on GPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1782665109.960637 20354 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
Segmentation fault (core dumped)
Attempting to disable MLIR (export XLA_FLAGS="--xla_gpu_enable_mlir_graph_optimization=false") results in an unknown flag error.
System details
- OS: Windows 11 WSL2 (Ubuntu 24.04.4 LTS) / Linux
- GPU: NVIDIA RTX 5090 (Blackwell)
- ColabFold version: 1.6.1 (local, via pixi)
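To add the exact jax/jaxlib versions to these details, a stdlib-only snippet run inside the pixi environment can list the installed package versions. This is an illustrative sketch. The package list is an assumption: `jax-cuda12-plugin` and `jax-cuda12-pjrt` are the names used by pip-installed CUDA 12 builds of JAX, and a pixi/conda environment may package the CUDA backend under different names. `installed_versions` is a hypothetical helper.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed package names; conda/pixi builds may name the CUDA backend differently.
PACKAGES = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt", "colabfold")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running it through the environment (for example `pixi run python check_versions.py`, where the file name is hypothetical) ensures that it reports the packages colabfold_batch actually imports.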
Temporary Workaround (for others)
Forcing CPU execution (export JAX_PLATFORMS="cpu" and export CUDA_VISIBLE_DEVICES="") works correctly. This indicates that the problem is specific to the XLA GPU backend rather than to ColabFold itself.
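The same workaround can be scripted so that the variables apply only to the ColabFold run and not to the whole shell session. This is a minimal sketch. It assumes that `pixi run` forwards the parent process environment to the command, as the export-based workaround implies. `cpu_only_env` and `run_colabfold_cpu` are hypothetical helpers, not part of ColabFold.

```python
import os
import subprocess
from typing import Mapping, Optional, Sequence


def cpu_only_env(base: Optional[Mapping[str, str]] = None) -> dict:
    """Return a copy of the environment with the GPU hidden from JAX."""
    env = dict(os.environ if base is None else base)  # copy; never mutate the caller's mapping
    env["JAX_PLATFORMS"] = "cpu"       # restrict JAX to the CPU backend
    env["CUDA_VISIBLE_DEVICES"] = ""   # hide every CUDA device from the process
    return env


def run_colabfold_cpu(args: Sequence[str]) -> subprocess.CompletedProcess:
    """Run colabfold_batch through pixi with the CPU-only environment."""
    return subprocess.run(
        ["pixi", "run", "colabfold_batch", *args],
        env=cpu_only_env(),
        check=True,  # raise CalledProcessError if colabfold_batch fails
    )


# Usage (same arguments as the reproduction command):
# run_colabfold_cpu(["msa_batch", "results", "--num-models", "1", "--num-recycle", "3"])
```

CPU inference is much slower than GPU inference, so this is only a stopgap until a Blackwell-capable jaxlib is available.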
Are there any plans to update the jax/jaxlib dependencies to versions compiled with CUDA 12.8+ to support Blackwell?
Best regards
Karol007