1+ # Copyright 2023–2026 Google LLC
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # https://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
115"""
216This module provides functionality to save top-k teacher logits
317for distillation purposes in MaxText.
@@ -81,12 +95,11 @@ def create_tf_example(example_dict):
8195 if flat_val .dtype in [np .float32 , np .float64 , np .float16 , jnp .bfloat16 ]:
8296 if flat_val .dtype != np .float32 :
8397 flat_val = flat_val .astype (np .float32 )
84- # Use .tolist() for extremely fast Protobuf C++ ingestion
98+ # Use .tolist() for fast Protobuf C++ ingestion
8599 features [key ] = tf .train .Feature (float_list = tf .train .FloatList (value = flat_val .tolist ()))
86100 elif flat_val .dtype in [np .int32 , np .int64 ]:
87101 if flat_val .dtype != np .int64 :
88102 flat_val = flat_val .astype (np .int64 )
89- # Use .tolist() for extremely fast Protobuf C++ ingestion
90103 features [key ] = tf .train .Feature (int64_list = tf .train .Int64List (value = flat_val .tolist ()))
91104 else :
92105 raise ValueError (f"Unsupported dtype { flat_val .dtype } for key { key } " )
@@ -96,51 +109,46 @@ def create_tf_example(example_dict):
96109
def background_process_and_write(writer, tokens, vals, idx, opt_data, serialization_executor):
  """Serialize one batch into TF Example records and write them in batch order.

  Meant to be submitted to a background executor so that the caller does not
  block on serialization or on the writer.

  Args:
    writer: Object exposing `write(record)`; called once per example.
    tokens: Array-like whose leading axis indexes the examples of the batch.
      Row `i` is stored under "inputs" and also feeds "sequence_hash".
    vals: Array-like indexed like `tokens`; row `i` is stored under
      "top_k_logits".
    idx: Array-like indexed like `tokens`; row `i` is stored under
      "top_k_indices".
    opt_data: Mapping of extra feature name -> array-like indexed like
      `tokens`. Row `i` of each value is stored under its own key.
    serialization_executor: Executor whose `map` is used to run
      `create_tf_example` over the per-example dictionaries.
  """
  # Function-scope import so this block does not depend on the file header.
  import hashlib  # pylint: disable=import-outside-toplevel

  # Convert every input to a NumPy array exactly once, up front.
  tokens_np = np.asarray(tokens)
  vals_np = np.asarray(vals)
  idx_np = np.asarray(idx)
  opt_data_np = {k: np.asarray(v) for k, v in opt_data.items()}

  batch_size = tokens_np.shape[0]
  example_dicts = []

  # Prepare the per-example dictionaries sequentially.
  for i in range(batch_size):
    # The builtin hash() of bytes is salted per interpreter process, so it
    # yields different values on different hosts/runs for identical tokens.
    # Use a fixed digest instead: 8 bytes, read as signed so the value stays
    # within the int64 range.
    digest = hashlib.blake2b(tokens_np[i].tobytes(), digest_size=8).digest()
    example_dict = {
        "inputs": tokens_np[i],
        "top_k_logits": vals_np[i],
        "top_k_indices": idx_np[i],
        "sequence_hash": int.from_bytes(digest, byteorder="little", signed=True),
    }
    for key, val_np in opt_data_np.items():
      example_dict[key] = val_np[i]
    example_dicts.append(example_dict)

  # Serialize through the executor. The results are collected in full before
  # any write, so an exception raised during serialization surfaces before a
  # single record of this batch reaches the writer.
  serialized_records = list(serialization_executor.map(create_tf_example, example_dicts))

  # Write the serialized records sequentially, preserving batch order.
  for record in serialized_records:
    writer.write(record)
130140
131141
def background_upload(local_path, gcs_path, process_index):
  """Copy a local file to GCS with `gcloud storage cp`, then delete the local copy.

  Args:
    local_path: Path of the local file to upload; removed after a successful
      copy.
    gcs_path: Destination passed to `gcloud storage cp`.
    process_index: Index of the calling process. The success message is only
      logged when it is 0; failures are logged for every process.

  A non-zero exit status from `gcloud` is logged and swallowed, and the local
  file is left in place. Any other exception (for example from launching
  `gcloud` or from `os.remove`) propagates to the caller.
  """
  try:
    subprocess.run(["gcloud", "storage", "cp", local_path, gcs_path], check=True, capture_output=True)
  except subprocess.CalledProcessError as e:
    # Report failures from every process: a failed upload on a non-zero
    # process would otherwise go completely unnoticed.
    # errors="replace" keeps undecodable stderr bytes from raising here.
    stderr_text = e.stderr.decode(errors="replace") if e.stderr else ""
    max_logging.log(f"Upload failed for {local_path} (process {process_index}): {stderr_text}")
    return

  # Only reached after a successful copy, so the local file is safe to drop.
  os.remove(local_path)
  if process_index == 0:
    max_logging.log(f"Background upload complete: {gcs_path}")
144152
145153
146154@nnx .jit (static_argnames = ("k" ,))
@@ -201,22 +209,6 @@ def generate_and_save_data(config, local_args):
201209 for step , batch in enumerate (islice (train_iter , start_step , config .steps ), start = start_step ):
202210 step_start = time .time ()
203211
204- # --- 1. PROFILER SETUP ---
205- is_profiling_step = (
206- config .profiler == "xplane"
207- and step == config .skip_first_n_steps_for_profiler
208- )
209-
210- is_profiling_stop_step = (
211- config .profiler == "xplane"
212- and step == config .skip_first_n_steps_for_profiler + config .profiler_steps
213- )
214-
215- if is_profiling_step and jax .process_index () == 0 :
216- max_logging .log (f"Recording Host-Only XProf trace for step { step } using TF API..." )
217- options = tf .profiler .experimental .ProfilerOptions (host_tracer_level = 2 , device_tracer_level = 0 )
218- tf .profiler .experimental .start (config .tensorboard_dir , options = options )
219-
220212 if step % steps_per_file == 0 :
221213 if writer :
222214 write_executor .shutdown (wait = True )
@@ -225,9 +217,7 @@ def generate_and_save_data(config, local_args):
225217 gcs_file_path = os .path .join (gcs_upload_path , os .path .basename (local_output_path ))
226218 if jax .process_index () == 0 :
227219 max_logging .log (f"Queueing distributed background uploads for Step { step } ..." )
228- # Swapped to TF Trace context
229- with tf .profiler .experimental .Trace ("submit_to_gcs_upload" ):
230- upload_executor .submit (background_upload , local_output_path , gcs_file_path , jax .process_index ())
220+ upload_executor .submit (background_upload , local_output_path , gcs_file_path , jax .process_index ())
231221
232222 # Re-initialize the writer with 1 worker
233223 write_executor = ThreadPoolExecutor (max_workers = 1 )
@@ -239,19 +229,17 @@ def generate_and_save_data(config, local_args):
239229
240230 tokens = batch ["inputs" ]
241231
242- # --- TRACE 1: Model Forward Pass & Network Gather ---
243- # Swapped to TF Trace context
244- with tf .profiler .experimental .Trace ("teacher_forward_and_gather" ):
245- top_k_vals , top_k_idx = teacher_step (teacher_model , batch , k_val )
232+ # --- Model Forward Pass & Network Gather ---
233+ top_k_vals , top_k_idx = teacher_step (teacher_model , batch , k_val )
246234
247- global_tokens = jax .experimental .multihost_utils .process_allgather (tokens , tiled = True )
248- global_vals = jax .experimental .multihost_utils .process_allgather (top_k_vals , tiled = True )
249- global_idx = jax .experimental .multihost_utils .process_allgather (top_k_idx , tiled = True )
235+ global_tokens = jax .experimental .multihost_utils .process_allgather (tokens , tiled = True )
236+ global_vals = jax .experimental .multihost_utils .process_allgather (top_k_vals , tiled = True )
237+ global_idx = jax .experimental .multihost_utils .process_allgather (top_k_idx , tiled = True )
250238
251- optional_data = {}
252- for key in optional_keys :
253- if key in batch :
254- optional_data [key ] = jax .experimental .multihost_utils .process_allgather (batch [key ], tiled = True )
239+ optional_data = {}
240+ for key in optional_keys :
241+ if key in batch :
242+ optional_data [key ] = jax .experimental .multihost_utils .process_allgather (batch [key ], tiled = True )
255243
256244 if writer :
257245 global_tokens_np = np .array (global_tokens )
@@ -269,30 +257,23 @@ def generate_and_save_data(config, local_args):
269257 local_idx_np = global_idx_np [start_idx :end_idx ]
270258 local_opt_data_np = {k : v [start_idx :end_idx ] for k , v in optional_data_np .items ()}
271259
272- # --- TRACE 2: Local Disk Writing ---
260+ # --- Local Disk Writing ---
273261 # Submit to the background thread with the serialization_executor
274- with tf .profiler .experimental .Trace ("local_disk_write_submit" ):
275- write_executor .submit (
276- background_process_and_write ,
277- writer ,
278- local_tokens_np ,
279- local_vals_np ,
280- local_idx_np ,
281- local_opt_data_np ,
282- serialization_executor
283- )
262+ write_executor .submit (
263+ background_process_and_write ,
264+ writer ,
265+ local_tokens_np ,
266+ local_vals_np ,
267+ local_idx_np ,
268+ local_opt_data_np ,
269+ serialization_executor
270+ )
284271
285272 if step % 50 == 0 and jax .process_index () == 0 :
286273 max_logging .log (f"Successfully processed step { step } in { time .time () - step_start :.4f} s" )
287274
288275 multihost_utils .sync_global_devices (f"step_{ step } _complete" )
289276
290- # --- 2. STOP PROFILER ---
291- if is_profiling_stop_step :
292- if jax .process_index () == 0 :
293- max_logging .log (f"Stopping XProf profiler and uploading clean host trace..." )
294- tf .profiler .experimental .stop ()
295-
296277 if jax .process_index () == 0 :
297278 max_logging .log (f"Generation loop finished in { time .time () - loop_start :.2f} s" )
298279
@@ -307,9 +288,7 @@ def generate_and_save_data(config, local_args):
307288
308289 if gcs_upload_path :
309290 gcs_file_path = os .path .join (gcs_upload_path , os .path .basename (local_output_path ))
310- # Swapped to TF Trace context
311- with tf .profiler .experimental .Trace ("submit_to_gcs_upload" ):
312- upload_executor .submit (background_upload , local_output_path , gcs_file_path , jax .process_index ())
291+ upload_executor .submit (background_upload , local_output_path , gcs_file_path , jax .process_index ())
313292
314293 if upload_executor :
315294 if jax .process_index () == 0 :
@@ -319,10 +298,6 @@ def generate_and_save_data(config, local_args):
319298 max_logging .log ("All GCS uploads complete." )
320299
321300 multihost_utils .sync_global_devices ("upload_complete" )
322-
323- if jax .process_index () == 0 :
324- max_logging .log ("Waiting 15 seconds for XProf to save the trace..." )
325- time .sleep (15 )
326301
327302
328303def main (argv : Sequence [str ], local_args ):
@@ -345,8 +320,8 @@ def main(argv: Sequence[str], local_args):
345320 )
346321 parser .add_argument ("--gcs_upload_path" , type = str , default = None )
347322 parser .add_argument ("--local_tmp_dir" , type = str , default = "/tmp" )
348- parser .add_argument ("--steps_per_file" , type = int , default = 2 )
323+ parser .add_argument ("--steps_per_file" , type = int , default = 50 )
349324 local_arg , remaining_args = parser .parse_known_args ()
350325
351326 main_wrapper = functools .partial (main , local_args = local_arg )
352- app .run (main_wrapper , argv = [sys .argv [0 ]] + remaining_args )
327+ app .run (main_wrapper , argv = [sys .argv [0 ]] + remaining_args )
0 commit comments