Skip to content

Commit 22de737

Browse files
committed
Add wandb support
1 parent c4b5e64 commit 22de737

4 files changed

Lines changed: 30 additions & 1 deletion

File tree

src/maxtext/common/metric_logger.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import queue
2222
import enum
23+
import wandb
2324

2425
import numpy as np
2526

@@ -99,8 +100,16 @@ def __init__(self, config, learning_rate_schedule):
99100
self.learning_rate_schedule = learning_rate_schedule
100101
self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
101102
self.buffered_train_metrics = None
103+
102104
if self.config.managed_mldiagnostics:
103105
ManagedMLDiagnostics(config) # Initialize the MLRun instance.
106+
107+
if self.config.enable_wandb and jax.process_index() == 0:
108+
wandb.init(
109+
project=config.wandb_project_name,
110+
name=config.wandb_run_name,
111+
resume="allow",
112+
) # Initialize wandb logger.
104113

105114
def reset_eval_metrics(self):
106115
"""Resets the cumulative metrics dictionary for a new evaluation run."""
@@ -122,6 +131,9 @@ def write_metrics(self, metrics, step, is_training=True):
122131

123132
if self.config.managed_mldiagnostics:
124133
self.write_metrics_to_managed_mldiagnostics(metrics, step)
134+
135+
if self.config.enable_wandb and jax.process_index() == 0:
136+
self.write_metrics_to_wandb(metrics, step)
125137

126138
def log_metrics(self, metrics, step, is_training):
127139
"""Logs metrics via max_logging."""
@@ -267,6 +279,16 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):
267279
mapped_metric_name = _METRICS_TO_MANAGED.get(metric_name, metric_name)
268280
mldiag.metrics.record(mapped_metric_name, value, step=int(step))
269281

282+
def write_metrics_to_wandb(self, metrics, step):
  """Write metrics to Weights & Biases (wandb).

  Flattens the metrics mapping into a single-level dict and logs it with
  `wandb.log`:
    * each entry of `metrics["scalar"]` is logged under its own key;
    * each entry of `metrics["scalars"]` is a nested mapping whose items are
      logged under "<key>/<subkey>".
  All values are converted with `float()`.

  Args:
    metrics: Mapping that may contain "scalar" and/or "scalars" sections;
      a missing section is treated as empty.
    step: Step the metrics belong to. Converted to a Python int before
      logging.
  """
  flat_metrics = {key: float(val) for key, val in metrics.get("scalar", {}).items()}
  for key, val in metrics.get("scalars", {}).items():
    for subkey, subval in val.items():
      flat_metrics[f"{key}/{subkey}"] = float(subval)
  # Cast to a plain int, as the managed-mldiagnostics writer does, in case
  # `step` arrives as an array scalar rather than a Python int.
  wandb.log(flat_metrics, step=int(step))
291+
270292
def write_setup_info_to_tensorboard(self, params):
271293
"""Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
272294
num_model_parameters = max_utils.calculate_num_params_from_pytree(params)

src/maxtext/configs/base.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. if empty,
9393
# if true save metrics such as loss and tflops to gcs in {base_output_directory}/{run_name}/metrics/
9494
gcs_metrics: false
9595

96+
enable_wandb: False
97+
wandb_project_name: ""
98+
wandb_run_name: ""
99+
96100
# if true save config to gcs in {base_output_directory}/{run_name}/
97101
save_config_to_gcs: false
98102

src/maxtext/configs/pyconfig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
170170
for key, value in raw_keys.items():
171171
if key not in valid_fields:
172172
logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key))
173-
raise ValueError(f"{key!r} not in {', '.join(map(repr, valid_fields))}.")
173+
raise ValueError(f"{key!r} not in {", ".join(map(repr, valid_fields))}.")
174174

175175
new_value = value
176176
if isinstance(new_value, str) and new_value.lower() == "none":

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,9 @@ class Metrics(BaseModel):
14711471
False,
14721472
description="Whether to enable Tunix-managed metrics measurement. The metrics will be uploaded to tensorboard.",
14731473
)
1474+
enable_wandb: bool = Field(False, description="Enable Weights & Biases logging.")
1475+
wandb_project_name: str = Field("maxtext", description="Weights & Biases project name.")
1476+
wandb_run_name: str = Field("", description="Weights & Biases run name. If empty, a default name is generated.")
14741477

14751478

14761479
class ManagedMLDiagnostics(BaseModel):

0 commit comments

Comments
 (0)