2121import sys
2222import queue
2323import enum
24+ import wandb
2425
2526import numpy as np
2627
@@ -109,6 +110,13 @@ def __init__(self, config, learning_rate_schedule):
109110 self ._pending_eval_step_count = 0
110111 if self .config .managed_mldiagnostics :
111112 ManagedMLDiagnostics (config ) # Initialize the MLRun instance.
113+
114+ if self .config .enable_wandb and jax .process_index () == 0 :
115+ wandb .init (
116+ project = config .wandb_project_name ,
117+ name = config .wandb_run_name ,
118+ resume = "allow" ,
119+ ) # Initialize wandb logger.
112120
113121 def reset_eval_metrics (self ):
114122 """Resets the cumulative metrics dictionary for a new evaluation run."""
@@ -131,6 +139,9 @@ def write_metrics(self, metrics, step, metric_type="train"):
131139
132140 if self .config .managed_mldiagnostics :
133141 self .write_metrics_to_managed_mldiagnostics (metrics , step )
142+
143+ if self .config .enable_wandb and jax .process_index () == 0 :
144+ self .write_metrics_to_wandb (metrics , step )
134145
135146 if metric_type == "train" :
136147 self ._maybe_abort_after_write_metrics (metrics )
@@ -326,6 +337,16 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):
326337 mapped_metric_name = _METRICS_TO_MANAGED .get (metric_name , metric_name )
327338 mldiag .metrics .record (mapped_metric_name , value , step = int (step ))
328339
340+ def write_metrics_to_wandb (self , metrics , step ):
341+ """Write metrics to weights and biases (wandb)."""
342+ flat_metrics = {}
343+ for key , val in metrics .get ("scalar" , {}).items ():
344+ flat_metrics [key ] = float (val )
345+ for key , val in metrics .get ("scalars" , {}).items ():
346+ for subkey , subval in val .items ():
347+ flat_metrics [f"{ key } /{ subkey } " ] = float (subval )
348+ wandb .log (flat_metrics , step = step )
349+
329350 def write_setup_info_to_tensorboard (self , params ):
330351 """Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
331352 num_model_parameters = max_utils .calculate_num_params_from_pytree (params )
0 commit comments