Skip to content

Commit 044cbf7

Browse files
committed
Added optional keys to the config to make the code cleaner.
1 parent 3d26d02 commit 044cbf7

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,13 @@ class Distillation(BaseModel):
10521052
# --- Loss Params ---
10531053
distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
10541054
distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
1055+
1056+
# --- Teacher topk distillation ---
1057+
teacher_logits_optional_keys: list[str] = Field(
1058+
default=["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"],
1059+
description="Optional keys to save from teacher logits"
1060+
)
1061+
10551062

10561063

10571064
class TrainingLoop(BaseModel):
@@ -1809,7 +1816,6 @@ class MaxTextConfig(
18091816
# Reinforcement Learning
18101817
RLHardware,
18111818
VLLM,
1812-
RL,
18131819
RLDataset,
18141820
RLEvaluation,
18151821
Reward,

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def generate_and_save_data(config, k_val):
7272
gathered_idx = multihost_utils.process_allgather(top_k_idx, tiled=True)
7373
gathered_tokens = multihost_utils.process_allgather(tokens, tiled=True)
7474

75-
optional_keys = ["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"]
75+
optional_keys = config.teacher_logits_optional_keys
7676
gathered_optionals = {
7777
key: multihost_utils.process_allgather(batch[key], tiled=True) for key in optional_keys if key in batch
7878
}

0 commit comments

Comments (0)