Skip to content

Commit 699ca4f

Browse files
committed
fix merge
1 parent ce6fce8 commit 699ca4f

3 files changed

Lines changed: 2 additions & 24 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818
model structures with Tunix's training interfaces.
1919
"""
2020

21-
<<<<<<< Updated upstream
2221
import pickle
2322
import tensorflow as tf
2423
from array_record.python import array_record_module
2524

26-
=======
2725
import abc
28-
>>>>>>> Stashed changes
2926
from typing import Any, Iterator, Optional, List, Callable
3027

3128
import flax

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,6 @@ def train_distill(
530530
raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config)
531531

532532
# 8. Configure Input Mapping
533-
<<<<<<< Updated upstream
534533
def custom_gen_model_input_fn(batch):
535534
inputs_dict = {
536535
"input_tokens": batch.input_tokens,
@@ -560,20 +559,6 @@ def custom_gen_model_input_fn(batch):
560559
return inputs_dict
561560

562561
trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn)
563-
=======
564-
trainer = trainer.with_gen_model_input_fn(
565-
lambda batch: {
566-
"input_tokens": batch.input_tokens,
567-
"positions": batch.positions,
568-
"attention_mask": batch.input_mask,
569-
"decoder_segment_ids": batch.decoder_segment_ids,
570-
"targets": batch.targets, # Passed to strategy (create_labels)
571-
"targets_position": batch.targets_position, # Passed to strategy (create_labels)
572-
"targets_segmentation": batch.targets_segmentation, # Passed to strategy (create_labels)
573-
"cache": None,
574-
}
575-
)
576-
>>>>>>> Stashed changes
577562

578563
# 9. Create Iterator Wrappers (Use Utils)
579564
train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter)

tests/post_training/unit/train_distill_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,10 @@ def test_monitored_strategy_sft(self):
359359

360360
def _test_monitored_strategy(self, sft_mode: bool):
361361
"""Verifies the strategy calculates metrics and returns the correct tuple."""
362-
mock_config = mock.Mock()
363-
mock_config.vocab_size = 4
364362
strategy = distillation_utils.CombinedDistillationStrategy(
365363
student_forward_fn=lambda m, **k: None,
366364
teacher_forward_fn=lambda m, **k: None,
367-
vocab_size=mock_config.vocab_size,
365+
vocab_size=4,
368366
temperature=1.0,
369367
alpha=0.5,
370368
beta_feature=1.0,
@@ -413,12 +411,10 @@ def _test_monitored_strategy(self, sft_mode: bool):
413411

414412
def verify_strategy_compute_eval_loss(self):
415413
"""Covers MonitoredLogitStrategy.compute_eval_loss."""
416-
mock_config = mock.Mock()
417-
mock_config.vocab_size = 4
418414
strategy = distillation_utils.CombinedDistillationStrategy(
419415
student_forward_fn=mock.Mock(),
420416
teacher_forward_fn=mock.Mock(),
421-
vocab_size=mock_config.vocab_size,
417+
vocab_size=4,
422418
# student_config=mock_config,
423419
temperature=1.0,
424420
alpha=0.5,

0 commit comments

Comments
 (0)