Skip to content

Commit 0c256a7

Browse files
committed
added/removed comments where needed
1 parent 6d16107 commit 0c256a7

2 files changed

Lines changed: 14 additions & 12 deletions

File tree

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,11 @@ def create_tf_example(example_dict):
8989
features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=[val]))
9090
continue
9191

92-
# Use .ravel() to avoid the memory copy that .flatten() does
9392
flat_val = np.asarray(val).ravel()
9493

9594
if flat_val.dtype in [np.float32, np.float64, np.float16, jnp.bfloat16]:
9695
if flat_val.dtype != np.float32:
9796
flat_val = flat_val.astype(np.float32)
98-
# Use .tolist() for fast Protobuf C++ ingestion
9997
features[key] = tf.train.Feature(float_list=tf.train.FloatList(value=flat_val.tolist()))
10098
elif flat_val.dtype in [np.int32, np.int64]:
10199
if flat_val.dtype != np.int64:
@@ -179,11 +177,11 @@ def generate_and_save_data(config, local_args):
179177
if not os.path.exists(local_tmp_dir):
180178
os.makedirs(local_tmp_dir, exist_ok=True)
181179

182-
upload_executor = ThreadPoolExecutor(max_workers=4)
180+
upload_executor = ThreadPoolExecutor(max_workers=local_args.upload_workers)
183181
# Restrict to 1 worker to ensure sequential writing to the ArrayRecord file
184182
write_executor = ThreadPoolExecutor(max_workers=1)
185183
# Dedicated executor for CPU-bound protobuf serialization, kept separate from the single-worker writer
186-
serialization_executor = ThreadPoolExecutor(max_workers=16)
184+
serialization_executor = ThreadPoolExecutor(max_workers=local_args.serialization_workers)
187185

188186
devices = jax.devices()
189187
devices_array = maxtext_utils.create_device_mesh(config, devices)
@@ -219,7 +217,9 @@ def generate_and_save_data(config, local_args):
219217
max_logging.log(f"Queueing distributed background uploads for Step {step}...")
220218
upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
221219

222-
# Re-initialize the writer with 1 worker
220+
# Re-initialize the writer thread pool. We restrict it to exactly 1 worker
221+
# to ensure sequential, in-order writing to the ArrayRecord file. A new pool
222+
# is needed because the previous one was shut down to flush all pending writes.
223223
write_executor = ThreadPoolExecutor(max_workers=1)
224224

225225
file_index = step // steps_per_file
@@ -321,6 +321,10 @@ def main(argv: Sequence[str], local_args):
321321
parser.add_argument("--gcs_upload_path", type=str, default=None)
322322
parser.add_argument("--local_tmp_dir", type=str, default="/tmp")
323323
parser.add_argument("--steps_per_file", type=int, default=50)
324+
parser.add_argument("--upload_workers", type=int, default=4, help="Number of workers for GCS uploads.")
325+
parser.add_argument(
326+
"--serialization_workers", type=int, default=16, help="Number of workers for protobuf serialization."
327+
)
324328
local_arg, remaining_args = parser.parse_known_args()
325329

326330
main_wrapper = functools.partial(main, local_args=local_arg)

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,13 +378,11 @@ def _log_metrics(self, loss, step=None, additional_metrics=None, **kwargs):
378378
tflops_per_sec = None
379379
if step_time_delta is not None and step_time_delta > 0:
380380
tflops_per_sec = self._tflops_combined / step_time_delta
381-
tflops_metrics.update(
382-
{
383-
"perf/per_device_tflops_per_sec": tflops_per_sec,
384-
"perf/per_device_tflops_per_sec_student": self._tflops_student / step_time_delta,
385-
"perf/per_device_tflops_per_sec_teacher": self._tflops_teacher / step_time_delta,
386-
}
387-
)
381+
tflops_metrics.update({
382+
"perf/per_device_tflops_per_sec": tflops_per_sec,
383+
"perf/per_device_tflops_per_sec_student": self._tflops_student / step_time_delta,
384+
"perf/per_device_tflops_per_sec_teacher": self._tflops_teacher / step_time_delta,
385+
})
388386
for name, value in tflops_metrics.items():
389387
self.metrics_logger.log(self.metrics_prefix, name, value, self._mode, step)
390388

0 commit comments

Comments
 (0)