Skip to content

Commit 317fbfb

Browse files
committed
Fix: Eagerly initialize ML Diagnostics to prevent Protobuf descriptor collision crash
1 parent b2153a3 commit 317fbfb

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

src/maxtext/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@
3333
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")
3434
del os
3535

36+
import google_cloud_mldiagnostics # pylint: disable=unused-import
37+
from google_cloud_mldiagnostics.utils.libtpu_utils import libtpu_metric
38+
from maxtext.utils import max_logging
39+
40+
try:
41+
max_logging.info("ML diagnostics initialization")
42+
libtpu_metric._initialize()
43+
except Exception as e:
44+
max_logging.warning(f"ML diagnostics initialization failed: {e}")
45+
3646
from jax.sharding import Mesh
3747

3848
from maxtext.configs import pyconfig

0 commit comments

Comments
 (0)