Replace `jax.sharding.use_mesh` with `jax.set_mesh`. `jax.set_mesh` can act as either a global setter or a context manager.
#134