Skip to content

Commit 6d16107

Browse files
committed
updated code formatting
1 parent 95db5de commit 6d16107

2 files changed

Lines changed: 10 additions & 11 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,6 @@ def compute_loss(
520520
ce_teacher_per_pos = jnp.zeros(s_logits.shape[:-1])
521521
kl_t1_sum = jnp.array(0.0)
522522

523-
524523
else:
525524
# --- DENSE KL DIVERGENCE (Online Mode) ---
526525
t_p_T = jax.nn.softmax(t_logits / temperature, axis=-1)

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def generate_and_save_data(config, local_args):
218218
if jax.process_index() == 0:
219219
max_logging.log(f"Queueing distributed background uploads for Step {step}...")
220220
upload_executor.submit(background_upload, local_output_path, gcs_file_path, jax.process_index())
221-
221+
222222
# Re-initialize the writer with 1 worker
223223
write_executor = ThreadPoolExecutor(max_workers=1)
224224

@@ -260,13 +260,13 @@ def generate_and_save_data(config, local_args):
260260
# --- Local Disk Writing ---
261261
# Submit to the background thread with the serialization_executor
262262
write_executor.submit(
263-
background_process_and_write,
264-
writer,
265-
local_tokens_np,
266-
local_vals_np,
267-
local_idx_np,
263+
background_process_and_write,
264+
writer,
265+
local_tokens_np,
266+
local_vals_np,
267+
local_idx_np,
268268
local_opt_data_np,
269-
serialization_executor
269+
serialization_executor,
270270
)
271271

272272
if step % 50 == 0 and jax.process_index() == 0:
@@ -303,9 +303,9 @@ def generate_and_save_data(config, local_args):
303303
def main(argv: Sequence[str], local_args):
304304
global_config = pyconfig.initialize(argv)
305305
teacher_overrides = global_config.teacher_overrides
306-
306+
307307
teacher_config = pyconfig.initialize(argv, **teacher_overrides)
308-
308+
309309
generate_and_save_data(teacher_config, local_args)
310310

311311

@@ -324,4 +324,4 @@ def main(argv: Sequence[str], local_args):
324324
local_arg, remaining_args = parser.parse_known_args()
325325

326326
main_wrapper = functools.partial(main, local_args=local_arg)
327-
app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)
327+
app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)

0 commit comments

Comments
 (0)