Skip to content

Commit 571c86c

Browse files
committed
Add wandb support
1 parent 7c68a9d commit 571c86c

4 files changed

Lines changed: 29 additions & 1 deletion

File tree

src/maxtext/common/metric_logger.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import sys
2222
import queue
2323
import enum
24+
import wandb
2425

2526
import numpy as np
2627

@@ -109,6 +110,13 @@ def __init__(self, config, learning_rate_schedule):
109110
self._pending_eval_step_count = 0
110111
if self.config.managed_mldiagnostics:
111112
ManagedMLDiagnostics(config) # Initialize the MLRun instance.
113+
114+
if self.config.enable_wandb and jax.process_index() == 0:
115+
wandb.init(
116+
project=config.wandb_project_name,
117+
name=config.wandb_run_name,
118+
resume="allow",
119+
) # Initialize wandb logger.
112120

113121
def reset_eval_metrics(self):
114122
"""Resets the cumulative metrics dictionary for a new evaluation run."""
@@ -131,6 +139,9 @@ def write_metrics(self, metrics, step, metric_type="train"):
131139

132140
if self.config.managed_mldiagnostics:
133141
self.write_metrics_to_managed_mldiagnostics(metrics, step)
142+
143+
if self.config.enable_wandb and jax.process_index() == 0:
144+
self.write_metrics_to_wandb(metrics, step)
134145

135146
if metric_type == "train":
136147
self._maybe_abort_after_write_metrics(metrics)
@@ -326,6 +337,16 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):
326337
mapped_metric_name = _METRICS_TO_MANAGED.get(metric_name, metric_name)
327338
mldiag.metrics.record(mapped_metric_name, value, step=int(step))
328339

340+
def write_metrics_to_wandb(self, metrics, step):
341+
"""Write metrics to weights and biases (wandb)."""
342+
flat_metrics = {}
343+
for key, val in metrics.get("scalar", {}).items():
344+
flat_metrics[key] = float(val)
345+
for key, val in metrics.get("scalars", {}).items():
346+
for subkey, subval in val.items():
347+
flat_metrics[f"{key}/{subkey}"] = float(subval)
348+
wandb.log(flat_metrics, step=step)
349+
329350
def write_setup_info_to_tensorboard(self, params):
330351
"""Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
331352
num_model_parameters = max_utils.calculate_num_params_from_pytree(params)

src/maxtext/configs/base.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. if empty,
101101
# if true save metrics such as loss and tflops to gcs in {base_output_directory}/{run_name}/metrics/
102102
gcs_metrics: false
103103

104+
enable_wandb: False
105+
wandb_project_name: ""
106+
wandb_run_name: ""
107+
104108
# if true save config to gcs in {base_output_directory}/{run_name}/
105109
save_config_to_gcs: false
106110

src/maxtext/configs/pyconfig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
222222
for key, value in raw_keys.items():
223223
if key not in valid_fields:
224224
logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key))
225-
raise ValueError(f"{key!r} not in {', '.join(map(repr, valid_fields))}.")
225+
raise ValueError(f"{key!r} not in {", ".join(map(repr, valid_fields))}.")
226226

227227
new_value = value
228228
if isinstance(new_value, str) and new_value.lower() == "none":

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,6 +1742,9 @@ class Metrics(BaseModel):
17421742
False,
17431743
description="Whether to enable Tunix-managed metrics measurement. The metrics will be uploaded to tensorboard.",
17441744
)
1745+
enable_wandb: bool = Field(False, description="Enable Weights & Biases logging.")
1746+
wandb_project_name: str = Field("maxtext", description="Weights & Biases project name.")
1747+
wandb_run_name: str = Field("", description="Weights & Biases run name. If empty, a default name is generated.")
17451748

17461749

17471750
class ManagedMLDiagnostics(BaseModel):

0 commit comments

Comments
 (0)