Skip to content

Commit 17da9d2

Browse files
authored
fix(jax): set default_matmul_precision to tensorfloat32 (deepmodeling#4726)
See jax-ml/jax#24909. Without setting this flag, the precision will become very low, which seems to be a bug (the documentation says the GPU uses tensorfloat32 or float32, but the default behavior seems wrong...). See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Chores** - Updated environment configuration to set default matrix multiplication precision to "tensorfloat32" for improved performance with JAX. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent a5b1b1f commit 17da9d2

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

deepmd/jax/env.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
jax.config.update("jax_enable_x64", True)
1414
# jax.config.update("jax_debug_nans", True)
15+
# https://github.com/jax-ml/jax/issues/24909
16+
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
1517

1618
if os.environ.get("DP_DTYPE_PROMOTION_STRICT") == "1":
1719
jax.config.update("jax_numpy_dtype_promotion", "strict")

0 commit comments

Comments
 (0)