|
1 | | -# # Copyright 2023–2026 Google LLC |
2 | | -# # |
3 | | -# # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | -# # you may not use this file except in compliance with the License. |
5 | | -# # You may obtain a copy of the License at |
6 | | -# # |
7 | | -# # https://www.apache.org/licenses/LICENSE-2.0 |
8 | | -# # |
9 | | -# # Unless required by applicable law or agreed to in writing, software |
10 | | -# # distributed under the License is distributed on an "AS IS" BASIS, |
11 | | -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | -# # See the License for the specific language governing permissions and |
13 | | -# # limitations under the License. |
14 | | - |
15 | | -# """ |
16 | | -# This module provides functionality to save top-k teacher logits |
17 | | -# for distillation purposes in MaxText. |
18 | | -# """ |
19 | | - |
20 | | -# import os |
21 | | -# import subprocess |
22 | | -# from concurrent.futures import ThreadPoolExecutor |
23 | | - |
24 | | -# # Force the pure Python protobuf implementation to avoid UPB compatibility issues with TFDS |
25 | | -# os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" |
26 | | -# from typing import Sequence |
27 | | -# import argparse |
28 | | -# import time |
29 | | -# import sys |
30 | | -# from etils import epath |
31 | | -# import tensorflow as tf |
32 | | -# import re |
33 | | -# import numpy as np |
34 | | - |
35 | | -# import jax |
36 | | -# import jax.numpy as jnp |
37 | | -# from flax import nnx |
38 | | -# import functools |
39 | | -# from itertools import islice |
40 | | - |
41 | | -# from absl import app |
42 | | -# from maxtext.configs import pyconfig |
43 | | -# from maxtext.utils import model_creation_utils |
44 | | -# from maxtext.input_pipeline import input_pipeline_interface |
45 | | -# from maxtext.utils import maxtext_utils |
46 | | -# from maxtext.utils import max_logging |
47 | | - |
48 | | -# from jax.experimental import multihost_utils |
49 | | -# from array_record.python import array_record_module |
50 | | - |
51 | | - |
52 | | -# def get_top_k_logits(logits: jax.Array, k: int): |
53 | | -# """Extracts the top-k values and their vocabulary indices.""" |
54 | | -# top_k_values, top_k_indices = jax.lax.top_k(logits, k) |
55 | | -# return top_k_values, top_k_indices |
56 | | - |
57 | | - |
58 | | -# def get_start_step(config, local_args): |
59 | | -# """Determines the starting step for the generation process.""" |
60 | | -# if jax.process_index() != 0: |
61 | | -# return 0 |
62 | | - |
63 | | -# output_dir = local_args.gcs_upload_path if local_args.gcs_upload_path else local_args.local_tmp_dir |
64 | | -# output_path = epath.Path(output_dir) |
65 | | -# if not output_path.exists(): |
66 | | -# output_path.mkdir(parents=True, exist_ok=True) |
67 | | -# return 0 |
68 | | - |
69 | | -# existing_files = list(output_path.glob("teacher_top_k_part_*.array_record")) |
70 | | -# if not existing_files: |
71 | | -# return 0 |
72 | | - |
73 | | -# max_part_num = max( |
74 | | -# (int(m.group(1)) for f in existing_files if (m := re.search(r"part_(\d+)_host", f.name))), |
75 | | -# default=-1, |
76 | | -# ) |
77 | | - |
78 | | -# if max_part_num == -1: |
79 | | -# return 0 |
80 | | - |
81 | | -# start_step = max_part_num * local_args.steps_per_file |
82 | | -# max_logging.log(f"Found existing data, resuming from step {start_step}") |
83 | | -# return start_step |
84 | | - |
85 | | - |
86 | | -# def create_tf_example(example_dict): |
87 | | -# """Converts a dictionary of single-example numpy arrays to a tf.train.Example.""" |
88 | | -# features = {} |
89 | | -# for key, val in example_dict.items(): |
90 | | -# if key == "sequence_hash": |
91 | | -# features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=[val])) |
92 | | -# continue |
93 | | - |
94 | | -# flat_val = np.asarray(val).flatten() |
95 | | - |
96 | | -# if flat_val.dtype in [np.float32, np.float64, np.float16, jnp.bfloat16]: |
97 | | -# features[key] = tf.train.Feature(float_list=tf.train.FloatList(value=flat_val.astype(np.float32))) |
98 | | -# elif flat_val.dtype in [np.int32, np.int64]: |
99 | | -# features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=flat_val.astype(np.int64))) |
100 | | -# else: |
101 | | -# raise ValueError(f"Unsupported dtype {flat_val.dtype} for key {key}") |
102 | | - |
103 | | -# return tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString() |
104 | | - |
105 | | - |
106 | | -# def background_process_and_write(writer, tokens, vals, idx, opt_data): |
107 | | -# """Executes entirely on a background CPU thread so the TPU never waits.""" |
108 | | -# with tf.profiler.experimental.Trace("background_local_disk_write"): |
109 | | -# tokens_np = np.array(tokens) |
110 | | -# vals_np = np.array(vals) |
111 | | -# idx_np = np.array(idx) |
112 | | -# opt_data_np = {k: np.array(v) for k, v in opt_data.items()} |
113 | | - |
114 | | -# batch_size = tokens_np.shape[0] |
115 | | -# for i in range(batch_size): |
116 | | -# seq_bytes = tokens_np[i].tobytes() |
117 | | -# example_dict = { |
118 | | -# "inputs": tokens_np[i], |
119 | | -# "top_k_logits": vals_np[i], |
120 | | -# "top_k_indices": idx_np[i], |
121 | | -# "sequence_hash": hash(seq_bytes), |
122 | | -# } |
123 | | -# for key, val_np in opt_data_np.items(): |
124 | | -# example_dict[key] = val_np[i] |
125 | | - |
126 | | -# writer.write(create_tf_example(example_dict)) |
127 | | - |
128 | | - |
129 | | -# def background_upload(local_path, gcs_path, process_index): |
130 | | -# """Uploads local_path to gcs_path via `gcloud storage cp` and deletes the local file once the copy succeeds.""" |
131 | | -# # Swapped to TF Trace context |
132 | | -# with tf.profiler.experimental.Trace("gcs_upload_and_cleanup"): |
133 | | -# try: |
134 | | -# subprocess.run(["gcloud", "storage", "cp", local_path, gcs_path], check=True, capture_output=True) |
135 | | -# os.remove(local_path) |
136 | | -# if process_index == 0: |
137 | | -# max_logging.log(f"Background upload complete: {gcs_path}") |
138 | | -# except subprocess.CalledProcessError as e: |
139 | | -# if process_index == 0: |
140 | | -# max_logging.log(f"Upload failed for {local_path}: {e.stderr.decode()}") |
141 | | - |
142 | | - |
143 | | -# @nnx.jit(static_argnames=("k",)) |
144 | | -# def teacher_step(model, batch, k): |
145 | | -# """Runs a forward pass through the teacher model and extracts top-k logits.""" |
146 | | -# logits = model( |
147 | | -# decoder_input_tokens=batch["inputs"], |
148 | | -# decoder_positions=batch["inputs_position"], |
149 | | -# decoder_segment_ids=batch.get("inputs_segmentation"), |
150 | | -# decoder_target_tokens=batch.get("targets"), |
151 | | -# decoder_target_mask=batch.get("targets_segmentation"), |
152 | | -# enable_dropout=False, |
153 | | -# ) |
154 | | -# return get_top_k_logits(logits, k=k) |
155 | | - |
156 | | - |
157 | | -# def generate_and_save_data(config, local_args): |
158 | | -# """Generates top-k logits from the teacher model and saves them locally, optionally uploading to GCS.""" |
159 | | -# k_val = local_args.top_k |
160 | | -# optional_keys = local_args.optional_keys |
161 | | -# gcs_upload_path = local_args.gcs_upload_path |
162 | | -# local_tmp_dir = local_args.local_tmp_dir |
163 | | -# steps_per_file = local_args.steps_per_file |
164 | | - |
165 | | -# writer = None |
166 | | -# local_output_path = None |
167 | | - |
168 | | -# if not os.path.exists(local_tmp_dir): |
169 | | -# os.makedirs(local_tmp_dir, exist_ok=True) |
170 | | - |
171 | | -# upload_executor = ThreadPoolExecutor(max_workers=4) |
172 | | -# # FIX: Restrict to 1 worker to ensure sequential writing to the ArrayRecord file |
173 | | -# write_executor = ThreadPoolExecutor(max_workers=1) |
174 | | - |
175 | | -# devices = jax.devices() |
176 | | -# devices_array = maxtext_utils.create_device_mesh(config, devices) |
177 | | -# mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) |
178 | | - |
179 | | -# if jax.process_index() == 0: |
180 | | -# max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...") |
181 | | - |
182 | | -# teacher_model = model_creation_utils.from_pretrained(config, mesh=mesh) |
183 | | -# train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh) |
184 | | - |
185 | | -# start_step = get_start_step(config, local_args) |
186 | | -# start_step = int(multihost_utils.broadcast_one_to_all(jnp.array(start_step))) |
187 | | - |
188 | | -# multihost_utils.sync_global_devices("start_generation_loop") |
189 | | - |
190 | | -# with mesh: |
191 | | -# if jax.process_index() == 0: |
192 | | -# max_logging.log(f"Starting Distributed Top-K generation loop for {config.steps - start_step} steps...") |
193 | | - |
194 | | -# loop_start = time.time() |
195 | | - |
196 | | -# for step, batch in enumerate(islice(train_iter, start_step, config.steps), start=start_step): |
197 | | -# step_start = time.time() |
198 | | - |
199 | | -# # --- 1. PROFILER SETUP --- |
200 | | -# is_profiling_step = ( |
201 | | -# config.profiler == "xplane" |
202 | | -# and step == config.skip_first_n_steps_for_profiler |
203 | | -# ) |
204 | | - |
205 | | -# is_profiling_stop_step = ( |
206 | | -# config.profiler == "xplane" |
207 | | -# and step == config.skip_first_n_steps_for_profiler + config.profiler_steps |
208 | | -# ) |
209 | | - |
210 | | -# if is_profiling_step and jax.process_index() == 0: |
211 | | -# max_logging.log(f"Recording Host-Only XProf trace for step {step} using TF API...") |
212 | | -# options = tf.profiler.experimental.ProfilerOptions(host_tracer_level=2, device_tracer_level=0) |
213 | | -# tf.profiler.experimental.start(config.tensorboard_dir, options=options) |
214 | | - |
215 | | -# if step % steps_per_file == 0: |
216 | | -# if writer: |
217 | | -# write_executor.shutdown(wait=True) |
218 | | -# writer.close() |
219 | | -# if gcs_upload_path: |
220 | | -# gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path)) |
221 | | -# if jax.process_index() == 0: |
222 | | -# max_logging.log(f"Queueing distributed background uploads for Step {step}...") |
223 | | -# # Swapped to TF Trace context |
224 | | -# with tf.profiler.experimental.Trace("submit_to_gcs_upload"): |
225 | | -# upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index()) |
226 | | - |
227 | | -# # FIX: Re-initialize the write executor with 1 worker |
228 | | -# write_executor = ThreadPoolExecutor(max_workers=1) |
229 | | - |
230 | | -# file_index = step // steps_per_file |
231 | | -# filename = f"teacher_top_k_part_{file_index:05d}_host_{jax.process_index():03d}.array_record" |
232 | | -# local_output_path = os.path.join(local_tmp_dir, filename) |
233 | | -# writer = array_record_module.ArrayRecordWriter(local_output_path, "group_size:1") |
234 | | - |
235 | | -# tokens = batch["inputs"] |
236 | | - |
237 | | -# # --- TRACE 1: Model Forward Pass & Network Gather --- |
238 | | -# # Swapped to TF Trace context |
239 | | -# with tf.profiler.experimental.Trace("teacher_forward_and_gather"): |
240 | | -# top_k_vals, top_k_idx = teacher_step(teacher_model, batch, k_val) |
241 | | - |
242 | | -# global_tokens = jax.experimental.multihost_utils.process_allgather(tokens, tiled=True) |
243 | | -# global_vals = jax.experimental.multihost_utils.process_allgather(top_k_vals, tiled=True) |
244 | | -# global_idx = jax.experimental.multihost_utils.process_allgather(top_k_idx, tiled=True) |
245 | | - |
246 | | -# optional_data = {} |
247 | | -# for key in optional_keys: |
248 | | -# if key in batch: |
249 | | -# optional_data[key] = jax.experimental.multihost_utils.process_allgather(batch[key], tiled=True) |
250 | | - |
251 | | -# if writer: |
252 | | -# global_tokens_np = np.array(global_tokens) |
253 | | -# global_vals_np = np.array(global_vals) |
254 | | -# global_idx_np = np.array(global_idx) |
255 | | -# optional_data_np = {k: np.array(v) for k, v in optional_data.items()} |
256 | | - |
257 | | -# global_batch_size = global_tokens_np.shape[0] |
258 | | -# local_batch_size = global_batch_size // jax.process_count() |
259 | | -# start_idx = jax.process_index() * local_batch_size |
260 | | -# end_idx = start_idx + local_batch_size |
261 | | - |
262 | | -# local_tokens_np = global_tokens_np[start_idx:end_idx] |
263 | | -# local_vals_np = global_vals_np[start_idx:end_idx] |
264 | | -# local_idx_np = global_idx_np[start_idx:end_idx] |
265 | | -# local_opt_data_np = {k: v[start_idx:end_idx] for k, v in optional_data_np.items()} |
266 | | - |
267 | | -# # --- TRACE 2: Local Disk Writing --- |
268 | | -# # FIX: Submit to the background thread instead of blocking the main thread |
269 | | -# with tf.profiler.experimental.Trace("local_disk_write_submit"): |
270 | | -# write_executor.submit( |
271 | | -# background_process_and_write, |
272 | | -# writer, |
273 | | -# local_tokens_np, |
274 | | -# local_vals_np, |
275 | | -# local_idx_np, |
276 | | -# local_opt_data_np |
277 | | -# ) |
278 | | - |
279 | | -# if step % 50 == 0 and jax.process_index() == 0: |
280 | | -# max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s") |
281 | | - |
282 | | -# multihost_utils.sync_global_devices(f"step_{step}_complete") |
283 | | - |
284 | | -# # --- 2. STOP PROFILER --- |
285 | | -# if is_profiling_stop_step: |
286 | | -# if jax.process_index() == 0: |
287 | | -# max_logging.log(f"Stopping XProf profiler and uploading clean host trace...") |
288 | | -# tf.profiler.experimental.stop() |
289 | | - |
290 | | -# if jax.process_index() == 0: |
291 | | -# max_logging.log(f"Generation loop finished in {time.time() - loop_start:.2f}s") |
292 | | - |
293 | | -# multihost_utils.sync_global_devices("loop_finished") |
294 | | - |
295 | | -# if writer: |
296 | | -# if write_executor: |
297 | | -# write_executor.shutdown(wait=True) |
298 | | -# writer.close() |
299 | | - |
300 | | -# if gcs_upload_path: |
301 | | -# gcs_file_path = os.path.join(gcs_upload_path, os.path.basename(local_output_path)) |
302 | | -# # Swapped to TF Trace context |
303 | | -# with tf.profiler.experimental.Trace("submit_to_gcs_upload"): |
304 | | -# upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index()) |
305 | | - |
306 | | -# if upload_executor: |
307 | | -# if jax.process_index() == 0: |
308 | | -# max_logging.log("Waiting for all background uploads to finish across all hosts...") |
309 | | -# upload_executor.shutdown(wait=True) |
310 | | -# if jax.process_index() == 0: |
311 | | -# max_logging.log("All GCS uploads complete.") |
312 | | - |
313 | | -# multihost_utils.sync_global_devices("upload_complete") |
314 | | - |
315 | | -# if jax.process_index() == 0: |
316 | | -# max_logging.log("Waiting 15 seconds for XProf to save the trace...") |
317 | | -# time.sleep(15) |
318 | | - |
319 | | - |
320 | | -# def main(argv: Sequence[str], local_args): |
321 | | -# global_config = pyconfig.initialize(argv) |
322 | | -# teacher_overrides = global_config.teacher_overrides |
323 | | - |
324 | | -# teacher_config = pyconfig.initialize(argv, **teacher_overrides) |
325 | | - |
326 | | -# generate_and_save_data(teacher_config, local_args) |
327 | | - |
328 | | - |
329 | | -# if __name__ == "__main__": |
330 | | -# parser = argparse.ArgumentParser() |
331 | | -# parser.add_argument("--top_k", type=int, default=128) |
332 | | -# parser.add_argument( |
333 | | -# "--optional_keys", |
334 | | -# type=str, |
335 | | -# nargs="*", |
336 | | -# default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"], |
337 | | -# ) |
338 | | -# parser.add_argument("--gcs_upload_path", type=str, default=None) |
339 | | -# parser.add_argument("--local_tmp_dir", type=str, default="/tmp") |
340 | | -# parser.add_argument("--steps_per_file", type=int, default=2) |
341 | | -# local_arg, remaining_args = parser.parse_known_args() |
342 | | - |
343 | | -# main_wrapper = functools.partial(main, local_args=local_arg) |
344 | | -# app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args) |
345 | | - |
346 | | -# Copyright 2023–2026 Google LLC |
347 | | -# |
348 | | -# Licensed under the Apache License, Version 2.0 (the "License"); |
349 | | -# you may not use this file except in compliance with the License. |
350 | | -# You may obtain a copy of the License at |
351 | | -# |
352 | | -# https://www.apache.org/licenses/LICENSE-2.0 |
353 | | -# |
354 | | -# Unless required by applicable law or agreed to in writing, software |
355 | | -# distributed under the License is distributed on an "AS IS" BASIS, |
356 | | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
357 | | -# See the License for the specific language governing permissions and |
358 | | -# limitations under the License. |
359 | | - |
360 | 1 | """ |
361 | 2 | This module provides functionality to save top-k teacher logits |
362 | 3 | for distillation purposes in MaxText. |
|
0 commit comments