import queue

import jax
import jax.numpy as jnp
import numpy as np

from maxdiffusion import max_utils, max_logging
2223
@@ -68,10 +69,31 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
6869 metrics ["scalar" ].update ({"learning/current_learning_rate" : lr })
6970
7071
# Hand-off queue between the training loop and the TensorBoard writer thread:
# write_metrics() puts (metrics, step) tuples on it and
# _tensorboard_writer_worker() takes them off; a None item tells the worker
# to stop.
_metrics_queue = queue.Queue()

# Metrics (and their step) held back by write_metrics() so that they are
# written on a later call rather than immediately. Both start as None,
# meaning "nothing buffered yet".
_buffered_step = None
_buffered_metrics = None
7375
7476
def _tensorboard_writer_worker(writer, config):
    """Drain the metrics queue and write each entry to TensorBoard.

    Intended to run in a separate thread. Blocks on ``_metrics_queue`` until
    an item arrives; each item is either a ``(metrics, step)`` tuple or
    ``None``, which tells the worker to exit.

    Args:
      writer: Summary writer exposing ``add_scalar``, ``add_scalars`` and
        ``flush``. Only used on JAX process 0.
      config: Configuration object; ``config.log_period`` sets how often
        (in steps) the writer is flushed.
    """
    # The process index does not change during a run, so look it up once
    # instead of once per queue item.
    is_primary_process = jax.process_index() == 0

    while True:
        item = _metrics_queue.get()
        try:
            if item is None:
                # Shutdown sentinel.
                break

            metrics, step = item
            if not is_primary_process:
                # Still consume the item so the queue does not grow, but only
                # process 0 writes to TensorBoard.
                continue

            for metric_name, value in metrics.get("scalar", {}).items():
                # np.array() turns the (possibly device-backed) value into a
                # host array before it is handed to the writer.
                writer.add_scalar(metric_name, np.array(value), step)
            for metric_name, values in metrics.get("scalars", {}).items():
                writer.add_scalars(metric_name, values, step)

            # NOTE(review): flushing is limited to process 0, matching
            # write_metrics_to_tensorboard; the original nesting of this
            # statement could not be recovered -- confirm.
            if step % config.log_period == 0:
                writer.flush()
        finally:
            # Mark every retrieved item (including the sentinel) as handled
            # so that _metrics_queue.join() can return.
            _metrics_queue.task_done()
96+
7597def write_metrics (writer , local_metrics_file , running_gcs_metrics , metrics , step , config ):
7698 """Entry point for all metrics writing in Train's Main.
7799 TODO: would be better as a Class in the future (that initialized all state!)
@@ -81,16 +103,18 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
81103 The logic is that this ensures that Jax is able to queues train_steps and we
82104 don't block when turning "lazy" Jax arrays into real Python numbers.
83105 """
84- global _buffered_step , _buffered_metrics
106+ global _buffered_step , _buffered_metrics , _metrics_queue
85107
108+ if metrics :
109+ _metrics_queue .put ((metrics , step ))
86110 if _buffered_metrics is not None :
111+ if config .metrics_file :
112+ max_utils .write_metrics_locally (_buffered_metrics , _buffered_step , config , local_metrics_file )
113+
87114 if _buffered_step is None :
88115 raise ValueError (f"When writing metrics, { _buffered_step = } was none" )
89116 write_metrics_to_tensorboard (writer , _buffered_metrics , _buffered_step , config )
90117
91- if config .metrics_file :
92- max_utils .write_metrics_locally (_buffered_metrics , _buffered_step , config , local_metrics_file )
93-
94118 if config .gcs_metrics and jax .process_index () == 0 :
95119 running_gcs_metrics = max_utils .write_metrics_for_gcs (_buffered_metrics , _buffered_step , config , running_gcs_metrics )
96120
@@ -100,13 +124,6 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
100124
101125def write_metrics_to_tensorboard (writer , metrics , step , config ):
102126 """Writes metrics to tensorboard"""
103- if jax .process_index () == 0 :
104- for metric_name in metrics .get ("scalar" , []):
105- writer .add_scalar (metric_name , np .array (metrics ["scalar" ][metric_name ]), step )
106- for metric_name in metrics .get ("scalars" , []):
107- writer .add_scalars (metric_name , metrics ["scalars" ][metric_name ], step )
108-
109- full_log = step % config .log_period == 0
110127 if jax .process_index () == 0 :
111128 max_logging .log (
112129 "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}" .format (
@@ -116,6 +133,13 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
116133 float (metrics ["scalar" ]["learning/loss" ]),
117134 )
118135 )
136+ if jax .process_index () == 0 :
137+ for metric_name in metrics .get ("scalar" , []):
138+ writer .add_scalar (metric_name , np .array (metrics ["scalar" ][metric_name ]), step )
139+ for metric_name in metrics .get ("scalars" , []):
140+ writer .add_scalars (metric_name , metrics ["scalars" ][metric_name ], step )
141+
142+ full_log = step % config .log_period == 0
119143
120144 if full_log and jax .process_index () == 0 :
121145 max_logging .log (f"To see full metrics 'tensorboard --logdir={ config .tensorboard_dir } '" )
0 commit comments