Skip to content

Commit f6fca0f

Browse files
Merge pull request #3806 from AI-Hypercomputer:xfgu-onduty
PiperOrigin-RevId: 910310785
2 parents 1e72989 + a935c39 commit f6fca0f

3 files changed

Lines changed: 34 additions & 25 deletions

File tree

src/maxtext/common/gcloud_stub.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,29 @@ def _import():
609609

610610
# ------------------------- TensorBoardX --------------------------
611611

612+
613+
class StubSummaryWriter:
614+
"""Stubbed TensorBoardX SummaryWriter replacement."""
615+
616+
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
617+
del args, kwargs
618+
619+
def add_text(self, *args, **kwargs):
620+
pass
621+
622+
def add_scalar(self, *args, **kwargs):
623+
pass
624+
625+
def add_histogram(self, *args, **kwargs):
626+
pass
627+
628+
def flush(self):
629+
pass
630+
631+
def close(self):
632+
pass
633+
634+
612635
try:
613636
if not is_decoupled(): # Only attempt real import when not decoupled
614637
from tensorboardX import writer # type: ignore # pylint: disable=import-outside-toplevel,unused-import
@@ -619,30 +642,10 @@ def _import():
619642
except Exception: # pragma: no cover - provide stub fallback # pylint: disable=broad-exception-caught
620643
_TENSORBOARDX_AVAILABLE = False
621644

622-
class _StubSummaryWriter:
623-
"""Stubbed TensorBoardX SummaryWriter replacement."""
624-
625-
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
626-
del args, kwargs
627-
628-
def add_text(self, *args, **kwargs):
629-
pass
630-
631-
def add_scalar(self, *args, **kwargs):
632-
pass
633-
634-
def add_histogram(self, *args, **kwargs):
635-
pass
636-
637-
def flush(self):
638-
pass
639-
640-
def close(self):
641-
pass
642-
643645
class writer: # pylint: disable=too-few-public-methods
644-
SummaryWriter = _StubSummaryWriter
646+
SummaryWriter = StubSummaryWriter
645647

646648

647649
__all__.append("writer")
648650
__all__.append("_TENSORBOARDX_AVAILABLE")
651+
__all__.append("StubSummaryWriter")

src/maxtext/common/metric_logger.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class MetricLogger:
9292
"""
9393

9494
def __init__(self, config, learning_rate_schedule):
95-
self.writer = max_utils.initialize_summary_writer(config.tensorboard_dir, config.run_name)
95+
self.writer = max_utils.initialize_summary_writer(config.tensorboard_dir, config.run_name, config.enable_tensorboard)
9696
self.config = config
9797
self.metadata = {}
9898
self.running_gcs_metrics = [] if config.gcs_metrics else None
@@ -295,6 +295,8 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):
295295

296296
def write_setup_info_to_tensorboard(self, params):
297297
"""Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
298+
if not self.config.enable_tensorboard:
299+
return
298300
num_model_parameters = max_utils.calculate_num_params_from_pytree(params)
299301
self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = maxtext_utils.calculate_tflops_training_per_device(self.config)
300302
self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device(self.config)

src/maxtext/utils/max_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
from maxtext.utils import elastic_utils
4646
from maxtext.common.gcloud_stub import is_decoupled
47-
from maxtext.common.gcloud_stub import writer, _TENSORBOARDX_AVAILABLE
47+
from maxtext.common.gcloud_stub import writer, _TENSORBOARDX_AVAILABLE, StubSummaryWriter
4848
from maxtext.utils import max_logging
4949
from maxtext.common.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN
5050

@@ -182,7 +182,7 @@ def summarize_size_from_pytree(params):
182182
return num_params, num_bytes, num_bytes / num_params
183183

184184

185-
def initialize_summary_writer(tensorboard_dir, run_name):
185+
def initialize_summary_writer(tensorboard_dir, run_name, enable_tensorboard=True):
186186
"""Return a tensorboardX SummaryWriter or a no-op stub.
187187
188188
In decoupled mode (no Google Cloud), this prefers a repo-local
@@ -191,6 +191,10 @@ def initialize_summary_writer(tensorboard_dir, run_name):
191191
if jax.process_index() != 0:
192192
return None
193193

194+
if not enable_tensorboard:
195+
max_logging.log("TensorBoard disabled; using no-op SummaryWriter.")
196+
return StubSummaryWriter()
197+
194198
if not _TENSORBOARDX_AVAILABLE:
195199
max_logging.log("tensorboardX not available; using no-op SummaryWriter.")
196200
return writer.SummaryWriter()

0 commit comments

Comments
 (0)