Commit 17da9d2
authored
See jax-ml/jax#24909. Without setting this
flag, the precision will become very low, which seems to be a bug (the
documentation says the GPU uses tensorfloat32 or float32, but the
default behavior seems wrong...).
See
https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html
for what this option is.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
- **Chores**
- Updated environment configuration to set default matrix multiplication
precision to "tensorfloat32" for improved performance with JAX.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent a5b1b1f commit 17da9d2
1 file changed
Lines changed: 2 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
| 16 | + | |
15 | 17 | | |
16 | 18 | | |
17 | 19 | | |
| |||
0 commit comments