diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 44771ecb05..8d29325708 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -21,6 +21,7 @@ import sys import queue import enum +import wandb import numpy as np @@ -109,6 +110,13 @@ def __init__(self, config, learning_rate_schedule): self._pending_eval_step_count = 0 if self.config.managed_mldiagnostics: ManagedMLDiagnostics(config) # Initialize the MLRun instance. + + if self.config.enable_wandb and jax.process_index() == 0: + wandb.init( + project=config.wandb_project_name, + name=config.wandb_run_name, + resume="allow", + ) # Initialize wandb logger. def reset_eval_metrics(self): """Resets the cumulative metrics dictionary for a new evaluation run.""" @@ -131,6 +139,9 @@ def write_metrics(self, metrics, step, metric_type="train"): if self.config.managed_mldiagnostics: self.write_metrics_to_managed_mldiagnostics(metrics, step) + + if self.config.enable_wandb and jax.process_index() == 0: + self.write_metrics_to_wandb(metrics, step) if metric_type == "train": self._maybe_abort_after_write_metrics(metrics) @@ -326,6 +337,16 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step): mapped_metric_name = _METRICS_TO_MANAGED.get(metric_name, metric_name) mldiag.metrics.record(mapped_metric_name, value, step=int(step)) + def write_metrics_to_wandb(self, metrics, step): + """Write metrics to weights and biases (wandb).""" + flat_metrics = {} + for key, val in metrics.get("scalar", {}).items(): + flat_metrics[key] = float(val) + for key, val in metrics.get("scalars", {}).items(): + for subkey, subval in val.items(): + flat_metrics[f"{key}/{subkey}"] = float(subval) + wandb.log(flat_metrics, step=step) + def write_setup_info_to_tensorboard(self, params): """Writes setup information like train config params, num model params, and XLA flags to TensorBoard.""" num_model_parameters = max_utils.calculate_num_params_from_pytree(params) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 2570f5c915..c808c4cd65 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -101,6 +101,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. if empty, # if true save metrics such as loss and tflops to gcs in {base_output_directory}/{run_name}/metrics/ gcs_metrics: false +enable_wandb: False +wandb_project_name: "" +wandb_run_name: "" + # if true save config to gcs in {base_output_directory}/{run_name}/ save_config_to_gcs: false diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index 0dfb76f29c..2518a33af6 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -222,7 +222,7 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]: for key, value in raw_keys.items(): if key not in valid_fields: logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key)) - raise ValueError(f"{key!r} not in {', '.join(map(repr, valid_fields))}.") + raise ValueError(f"{key!r} not in {", ".join(map(repr, valid_fields))}.") new_value = value if isinstance(new_value, str) and new_value.lower() == "none": diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 4bb16d65c6..02ef1701bc 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1742,6 +1742,9 @@ class Metrics(BaseModel): False, description="Whether to enable Tunix-managed metrics measurement. The metrics will be uploaded to tensorboard.", ) + enable_wandb: bool = Field(False, description="Enable Weights & Biases logging.") + wandb_project_name: str = Field("maxtext", description="Weights & Biases project name.") + wandb_run_name: str = Field("", description="Weights & Biases run name. If empty, a default name is generated.") class ManagedMLDiagnostics(BaseModel):