2020import os
2121import queue
2222import enum
23+ import wandb
2324
2425import numpy as np
2526
@@ -99,8 +100,16 @@ def __init__(self, config, learning_rate_schedule):
99100 self .learning_rate_schedule = learning_rate_schedule
100101 self .cumulative_eval_metrics = {"scalar" : defaultdict (float )}
101102 self .buffered_train_metrics = None
103+
102104 if self .config .managed_mldiagnostics :
103105 ManagedMLDiagnostics (config ) # Initialize the MLRun instance.
106+
107+ if self .config .enable_wandb and jax .process_index () == 0 :
108+ wandb .init (
109+ project = config .wandb_project_name ,
110+ name = config .wandb_run_name ,
111+ resume = "allow" ,
112+ ) # Initialize wandb logger.
104113
105114 def reset_eval_metrics (self ):
106115 """Resets the cumulative metrics dictionary for a new evaluation run."""
@@ -122,6 +131,9 @@ def write_metrics(self, metrics, step, is_training=True):
122131
123132 if self .config .managed_mldiagnostics :
124133 self .write_metrics_to_managed_mldiagnostics (metrics , step )
134+
135+ if self .config .enable_wandb and jax .process_index () == 0 :
136+ self .write_metrics_to_wandb (metrics , step )
125137
126138 def log_metrics (self , metrics , step , is_training ):
127139 """Logs metrics via max_logging."""
@@ -267,6 +279,16 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):
267279 mapped_metric_name = _METRICS_TO_MANAGED .get (metric_name , metric_name )
268280 mldiag .metrics .record (mapped_metric_name , value , step = int (step ))
269281
282+ def write_metrics_to_wandb (self , metrics , step ):
283+ """Write metrics to weights and biases (wandb)."""
284+ flat_metrics = {}
285+ for key , val in metrics .get ("scalar" , {}).items ():
286+ flat_metrics [key ] = float (val )
287+ for key , val in metrics .get ("scalars" , {}).items ():
288+ for subkey , subval in val .items ():
289+ flat_metrics [f"{ key } /{ subkey } " ] = float (subval )
290+ wandb .log (flat_metrics , step = step )
291+
270292 def write_setup_info_to_tensorboard (self , params ):
271293 """Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
272294 num_model_parameters = max_utils .calculate_num_params_from_pytree (params )
0 commit comments