Skip to content

Commit 8b7058d

Browse files
committed
Removed commented-out code to make the file more readable
1 parent 5afe30e commit 8b7058d

1 file changed

Lines changed: 0 additions & 359 deletions

File tree

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 0 additions & 359 deletions
Original file line number | Diff line number | Diff line change
@@ -1,362 +1,3 @@
1-
# # Copyright 2023–2026 Google LLC
2-
# #
3-
# # Licensed under the Apache License, Version 2.0 (the "License");
4-
# # you may not use this file except in compliance with the License.
5-
# # You may obtain a copy of the License at
6-
# #
7-
# # https://www.apache.org/licenses/LICENSE-2.0
8-
# #
9-
# # Unless required by applicable law or agreed to in writing, software
10-
# # distributed under the License is distributed on an "AS IS" BASIS,
11-
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# # See the License for the specific language governing permissions and
13-
# # limitations under the License.
14-
15-
# """
16-
# This module provides functionality to save top-k teacher logits
17-
# for distillation purposes in MaxText.
18-
# """
19-
20-
# import os
21-
# import subprocess
22-
# from concurrent.futures import ThreadPoolExecutor
23-
24-
# # Force the pure Python protobuf implementation to avoid UPB compatibility issues with TFDS
25-
# os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
26-
# from typing import Sequence
27-
# import argparse
28-
# import time
29-
# import sys
30-
# from etils import epath
31-
# import tensorflow as tf
32-
# import re
33-
# import numpy as np
34-
35-
# import jax
36-
# import jax.numpy as jnp
37-
# from flax import nnx
38-
# import functools
39-
# from itertools import islice
40-
41-
# from absl import app
42-
# from maxtext.configs import pyconfig
43-
# from maxtext.utils import model_creation_utils
44-
# from maxtext.input_pipeline import input_pipeline_interface
45-
# from maxtext.utils import maxtext_utils
46-
# from maxtext.utils import max_logging
47-
48-
# from jax.experimental import multihost_utils
49-
# from array_record.python import array_record_module
50-
51-
52-
# def get_top_k_logits(logits: jax.Array, k: int):
53-
# """Extracts the top-k values and their vocabulary indices"""
54-
# top_k_values, top_k_indices = jax.lax.top_k(logits, k)
55-
# return top_k_values, top_k_indices
56-
57-
58-
# def get_start_step(config, local_args):
59-
# """Determines the starting step for the generation process."""
60-
# if jax.process_index() != 0:
61-
# return 0
62-
63-
# output_dir = local_args.gcs_upload_path if local_args.gcs_upload_path else local_args.local_tmp_dir
64-
# output_path = epath.Path(output_dir)
65-
# if not output_path.exists():
66-
# output_path.mkdir(parents=True, exist_ok=True)
67-
# return 0
68-
69-
# existing_files = list(output_path.glob("teacher_top_k_part_*.array_record"))
70-
# if not existing_files:
71-
# return 0
72-
73-
# max_part_num = max(
74-
# (int(m.group(1)) for f in existing_files if (m := re.search(r"part_(\d+)_host", f.name))),
75-
# default=-1,
76-
# )
77-
78-
# if max_part_num == -1:
79-
# return 0
80-
81-
# start_step = max_part_num * local_args.steps_per_file
82-
# max_logging.log(f"Found existing data, resuming from step {start_step}")
83-
# return start_step
84-
85-
86-
# def create_tf_example(example_dict):
87-
# """Converts a dictionary of single-example numpy arrays to a tf.train.Example."""
88-
# features = {}
89-
# for key, val in example_dict.items():
90-
# if key == "sequence_hash":
91-
# features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=[val]))
92-
# continue
93-
94-
# flat_val = np.asarray(val).flatten()
95-
96-
# if flat_val.dtype in [np.float32, np.float64, np.float16, jnp.bfloat16]:
97-
# features[key] = tf.train.Feature(float_list=tf.train.FloatList(value=flat_val.astype(np.float32)))
98-
# elif flat_val.dtype in [np.int32, np.int64]:
99-
# features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=flat_val.astype(np.int64)))
100-
# else:
101-
# raise ValueError(f"Unsupported dtype {flat_val.dtype} for key {key}")
102-
103-
# return tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
104-
105-
106-
# def background_process_and_write(writer, tokens, vals, idx, opt_data):
107-
# """Executes entirely on a background CPU thread so the TPU never waits."""
108-
# with tf.profiler.experimental.Trace("background_local_disk_write"):
109-
# tokens_np = np.array(tokens)
110-
# vals_np = np.array(vals)
111-
# idx_np = np.array(idx)
112-
# opt_data_np = {k: np.array(v) for k, v in opt_data.items()}
113-
114-
# batch_size = tokens_np.shape[0]
115-
# for i in range(batch_size):
116-
# seq_bytes = tokens_np[i].tobytes()
117-
# example_dict = {
118-
# "inputs": tokens_np[i],
119-
# "top_k_logits": vals_np[i],
120-
# "top_k_indices": idx_np[i],
121-
# "sequence_hash": hash(seq_bytes),
122-
# }
123-
# for key, val_np in opt_data_np.items():
124-
# example_dict[key] = val_np[i]
125-
126-
# writer.write(create_tf_example(example_dict))
127-
128-
129-
# def background_upload(local_path, gcs_path, process_index):
130-
# """Executes a highly optimized concurrent upload via gcloud."""
131-
# # Swapped to TF Trace context
132-
# with tf.profiler.experimental.Trace("gcs_upload_and_cleanup"):
133-
# try:
134-
# subprocess.run(["gcloud", "storage", "cp", local_path, gcs_path], check=True, capture_output=True)
135-
# os.remove(local_path)
136-
# if process_index == 0:
137-
# max_logging.log(f"Background upload complete: {gcs_path}")
138-
# except subprocess.CalledProcessError as e:
139-
# if process_index == 0:
140-
# max_logging.log(f"Upload failed for {local_path}: {e.stderr.decode()}")
141-
142-
143-
# @nnx.jit(static_argnames=("k",))
144-
# def teacher_step(model, batch, k):
145-
# """Runs a forward pass through the teacher model and extracts top-k logits."""
146-
# logits = model(
147-
# decoder_input_tokens=batch["inputs"],
148-
# decoder_positions=batch["inputs_position"],
149-
# decoder_segment_ids=batch.get("inputs_segmentation"),
150-
# decoder_target_tokens=batch.get("targets"),
151-
# decoder_target_mask=batch.get("targets_segmentation"),
152-
# enable_dropout=False,
153-
# )
154-
# return get_top_k_logits(logits, k=k)
155-
156-
157-
# def generate_and_save_data(config, local_args):
158-
# """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS"""
159-
# k_val = local_args.top_k
160-
# optional_keys = local_args.optional_keys
161-
# gcs_upload_path = local_args.gcs_upload_path
162-
# local_tmp_dir = local_args.local_tmp_dir
163-
# steps_per_file = local_args.steps_per_file
164-
165-
# writer = None
166-
# local_output_path = None
167-
168-
# if not os.path.exists(local_tmp_dir):
169-
# os.makedirs(local_tmp_dir, exist_ok=True)
170-
171-
# upload_executor = ThreadPoolExecutor(max_workers=4)
172-
# # FIX: Restrict to 1 worker to ensure sequential writing to the ArrayRecord file
173-
# write_executor = ThreadPoolExecutor(max_workers=1)
174-
175-
# devices = jax.devices()
176-
# devices_array = maxtext_utils.create_device_mesh(config, devices)
177-
# mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
178-
179-
# if jax.process_index() == 0:
180-
# max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
181-
182-
# teacher_model = model_creation_utils.from_pretrained(config, mesh=mesh)
183-
# train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)
184-
185-
# start_step = get_start_step(config, local_args)
186-
# start_step = int(multihost_utils.broadcast_one_to_all(jnp.array(start_step)))
187-
188-
# multihost_utils.sync_global_devices("start_generation_loop")
189-
190-
# with mesh:
191-
# if jax.process_index() == 0:
192-
# max_logging.log(f"Starting Distributed Top-K generation loop for {config.steps - start_step} steps...")
193-
194-
# loop_start = time.time()
195-
196-
# for step, batch in enumerate(islice(train_iter, start_step, config.steps), start=start_step):
197-
# step_start = time.time()
198-
199-
# # --- 1. PROFILER SETUP ---
200-
# is_profiling_step = (
201-
# config.profiler == "xplane"
202-
# and step == config.skip_first_n_steps_for_profiler
203-
# )
204-
205-
# is_profiling_stop_step = (
206-
# config.profiler == "xplane"
207-
# and step == config.skip_first_n_steps_for_profiler + config.profiler_steps
208-
# )
209-
210-
# if is_profiling_step and jax.process_index() == 0:
211-
# max_logging.log(f"Recording Host-Only XProf trace for step {step} using TF API...")
212-
# options = tf.profiler.experimental.ProfilerOptions(host_tracer_level=2, device_tracer_level=0)
213-
# tf.profiler.experimental.start(config.tensorboard_dir, options=options)
214-
215-
# if step % steps_per_file == 0:
216-
# if writer:
217-
# write_executor.shutdown(wait=True)
218-
# writer.close()
219-
# if gcs_upload_path:
220-
# gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path))
221-
# if jax.process_index() == 0:
222-
# max_logging.log(f"Queueing distributed background uploads for Step {step}...")
223-
# # Swapped to TF Trace context
224-
# with tf.profiler.experimental.Trace("submit_to_gcs_upload"):
225-
# upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
226-
227-
# # FIX: Re-initialize the writer with 1 worker
228-
# write_executor = ThreadPoolExecutor(max_workers=1)
229-
230-
# file_index = step // steps_per_file
231-
# filename = f"teacher_top_k_part_{file_index:05d}_host_{jax.process_index():03d}.array_record"
232-
# local_output_path = os.path.join(local_tmp_dir, filename)
233-
# writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1")
234-
235-
# tokens = batch["inputs"]
236-
237-
# # --- TRACE 1: Model Forward Pass & Network Gather ---
238-
# # Swapped to TF Trace context
239-
# with tf.profiler.experimental.Trace("teacher_forward_and_gather"):
240-
# top_k_vals, top_k_idx = teacher_step(teacher_model, batch, k_val)
241-
242-
# global_tokens = jax.experimental.multihost_utils.process_allgather(tokens, tiled=True)
243-
# global_vals = jax.experimental.multihost_utils.process_allgather(top_k_vals, tiled=True)
244-
# global_idx = jax.experimental.multihost_utils.process_allgather(top_k_idx, tiled=True)
245-
246-
# optional_data = {}
247-
# for key in optional_keys:
248-
# if key in batch:
249-
# optional_data[key] = jax.experimental.multihost_utils.process_allgather(batch[key], tiled=True)
250-
251-
# if writer:
252-
# global_tokens_np = np.array(global_tokens)
253-
# global_vals_np = np.array(global_vals)
254-
# global_idx_np = np.array(global_idx)
255-
# optional_data_np = {k: np.array(v) for k, v in optional_data.items()}
256-
257-
# global_batch_size = global_tokens_np.shape[0]
258-
# local_batch_size = global_batch_size // jax.process_count()
259-
# start_idx = jax.process_index() * local_batch_size
260-
# end_idx = start_idx + local_batch_size
261-
262-
# local_tokens_np = global_tokens_np[start_idx:end_idx]
263-
# local_vals_np = global_vals_np[start_idx:end_idx]
264-
# local_idx_np = global_idx_np[start_idx:end_idx]
265-
# local_opt_data_np = {k: v[start_idx:end_idx] for k, v in optional_data_np.items()}
266-
267-
# # --- TRACE 2: Local Disk Writing ---
268-
# # FIX: Submit to the background thread instead of blocking the main thread
269-
# with tf.profiler.experimental.Trace("local_disk_write_submit"):
270-
# write_executor.submit(
271-
# background_process_and_write,
272-
# writer,
273-
# local_tokens_np,
274-
# local_vals_np,
275-
# local_idx_np,
276-
# local_opt_data_np
277-
# )
278-
279-
# if step % 50 == 0 and jax.process_index() == 0:
280-
# max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")
281-
282-
# multihost_utils.sync_global_devices(f"step_{step}_complete")
283-
284-
# # --- 2. STOP PROFILER ---
285-
# if is_profiling_stop_step:
286-
# if jax.process_index() == 0:
287-
# max_logging.log(f"Stopping XProf profiler and uploading clean host trace...")
288-
# tf.profiler.experimental.stop()
289-
290-
# if jax.process_index() == 0:
291-
# max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s")
292-
293-
# multihost_utils.sync_global_devices("loop_finished")
294-
295-
# if writer:
296-
# if write_executor:
297-
# write_executor.shutdown(wait=True)
298-
# writer.close()
299-
300-
# if gcs_upload_path:
301-
# gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path))
302-
# # Swapped to TF Trace context
303-
# with tf.profiler.experimental.Trace("submit_to_gcs_upload"):
304-
# upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
305-
306-
# if upload_executor:
307-
# if jax.process_index() == 0:
308-
# max_logging.log("Waiting for all background uploads to finish across all hosts...")
309-
# upload_executor.shutdown(wait=True)
310-
# if jax.process_index() == 0:
311-
# max_logging.log("All GCS uploads complete.")
312-
313-
# multihost_utils.sync_global_devices("upload_complete")
314-
315-
# if jax.process_index() == 0:
316-
# max_logging.log("Waiting 15 seconds for XProf to save the trace...")
317-
# time.sleep(15)
318-
319-
320-
# def main(argv: Sequence[str], local_args):
321-
# global_config = pyconfig.initialize(argv)
322-
# teacher_overrides = global_config.teacher_overrides
323-
324-
# teacher_config = pyconfig.initialize(argv, **teacher_overrides)
325-
326-
# generate_and_save_data(teacher_config, local_args)
327-
328-
329-
# if __name__ == "__main__":
330-
# parser = argparse.ArgumentParser()
331-
# parser.add_argument("--top_k", type=int, default=128)
332-
# parser.add_argument(
333-
# "--optional_keys",
334-
# type=str,
335-
# nargs="*",
336-
# default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
337-
# )
338-
# parser.add_argument("--gcs_upload_path", type=str, default=None)
339-
# parser.add_argument("--local_tmp_dir", type=str, default="/tmp")
340-
# parser.add_argument("--steps_per_file", type=int, default=2)
341-
# local_arg, remaining_args = parser.parse_known_args()
342-
343-
# main_wrapper = functools.partial(main, local_args=local_arg)
344-
# app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)
345-
346-
# Copyright 2023–2026 Google LLC
347-
#
348-
# Licensed under the Apache License, Version 2.0 (the "License");
349-
# you may not use this file except in compliance with the License.
350-
# You may obtain a copy of the License at
351-
#
352-
# https://www.apache.org/licenses/LICENSE-2.0
353-
#
354-
# Unless required by applicable law or agreed to in writing, software
355-
# distributed under the License is distributed on an "AS IS" BASIS,
356-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
357-
# See the License for the specific language governing permissions and
358-
# limitations under the License.
359-
3601
"""
3612
This module provides functionality to save top-k teacher logits
3623
for distillation purposes in MaxText.

0 commit comments

Comments
 (0)