Skip to content

Commit 9a7b2e4

Browse files
committed
Removed redundant offline_distillation flag and relied on offline_data_dir to know when to run offfline vs online distillation
1 parent 1fc0699 commit 9a7b2e4

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,10 +1071,7 @@ class Distillation(BaseModel):
10711071
description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
10721072
)
10731073

1074-
# --- Offline Distillation Fields ---
1075-
offline_distillation: bool = Field(
1076-
False, description="If True, enables offline distillation using pre-generated teacher data."
1077-
)
1074+
# --- Offline Distillation Field ---
10781075
offline_data_dir: Optional[str] = Field(
10791076
None, description="GCS or local path to the pre-generated ArrayRecord teacher data."
10801077
)

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,12 +616,14 @@ def main(argv: Sequence[str]) -> None:
616616
student_overrides = global_config.student_overrides
617617
student_config = pyconfig.initialize(argv, **student_overrides)
618618

619+
is_offline = bool(global_config.offline_data_dir)
620+
619621
# 3. Initialize TEACHER Config
620622
# We isolate the Teacher from Student CLI arguments (like pruning params).
621623
teacher_overrides = global_config.teacher_overrides
622624

623625
# Ensure load_parameters_path is set in overrides
624-
if not global_config.offline_distillation and not teacher_overrides.get("load_parameters_path"):
626+
if not is_offline and not teacher_overrides.get("load_parameters_path"):
625627
raise ValueError(
626628
"Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
627629
"in your config or arguments."
@@ -633,7 +635,7 @@ def main(argv: Sequence[str]) -> None:
633635
teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)
634636

635637
# 4. Run Training
636-
train_distill(student_config, teacher_config, global_config.offline_distillation, global_config.offline_data_dir)
638+
train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
637639

638640

639641
if __name__ == "__main__":

0 commit comments

Comments
 (0)