Skip to content

Commit 9e48a23

Browse files
committed
Disable JAX compilation cache when dump_hlo is True
1 parent c3dc904 commit 9e48a23

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

src/maxtext/configs/pyconfig.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,14 @@ def _initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConf
541541
if "pytest" not in sys.modules:
542542
max_utils.maybe_initialize_jax_distributed_system(pydantic_kwargs)
543543
if pydantic_kwargs.get("jax_cache_dir"):
544-
from jax.experimental.compilation_cache import compilation_cache # pylint: disable=import-outside-toplevel
544+
if pydantic_kwargs.get("dump_hlo"):
545+
max_logging.warning(
546+
"JAX compilation cache is disabled because dump_hlo is True. HLO dumping requires recompilation."
547+
)
548+
else:
549+
from jax.experimental.compilation_cache import compilation_cache # pylint: disable=import-outside-toplevel
545550

546-
compilation_cache.set_cache_dir(os.path.expanduser(pydantic_kwargs["jax_cache_dir"]))
551+
compilation_cache.set_cache_dir(os.path.expanduser(pydantic_kwargs["jax_cache_dir"]))
547552

548553
pydantic_config = types.MaxTextConfig(**pydantic_kwargs)
549554
config = HyperParameters(pydantic_config)

0 commit comments

Comments
 (0)