File tree Expand file tree Collapse file tree
trainers/post_train/distillation Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1070,15 +1070,13 @@ class Distillation(BaseModel):
10701070 default_factory = dict ,
10711071 description = "Overrides specific to the Teacher model (e.g., {'num_query_heads': 64})." ,
10721072 )
1073-
1073+
10741074 # --- Offline Distillation Fields ---
10751075 offline_distillation : bool = Field (
1076- False ,
1077- description = "If True, enables offline distillation using pre-generated teacher data."
1076+ False , description = "If True, enables offline distillation using pre-generated teacher data."
10781077 )
10791078 offline_data_dir : Optional [str ] = Field (
1080- None ,
1081- description = "GCS or local path to the pre-generated ArrayRecord teacher data."
1079+ None , description = "GCS or local path to the pre-generated ArrayRecord teacher data."
10821080 )
10831081
10841082 # --- Loss Params ---
Original file line number Diff line number Diff line change 1818model structures with Tunix's training interfaces.
1919"""
2020
21- import os
2221import pickle
2322import tensorflow as tf
2423from array_record .python import array_record_module
@@ -82,11 +81,7 @@ class OfflineArrayRecordIterator:
8281 """Reads the pre-generated global top-k logits file."""
8382
8483 def __init__ (self , data_dir : str , epochs : int = 100 ):
85- # Check if the user passed a directory or a direct file path
86- if tf .io .gfile .isdir (data_dir ):
87- self .filepath = os .path .join (data_dir , "teacher_top_k_global.array_record" )
88- else :
89- self .filepath = data_dir
84+ self .filepath = data_dir
9085
9186 if not tf .io .gfile .exists (self .filepath ):
9287 raise FileNotFoundError (f"Offline distillation file not found: { self .filepath } " )
Original file line number Diff line number Diff line change 32323. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose
3333 a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
3434"""
35- import argparse
36- import functools
37- import sys
38-
3935from typing import Sequence , Callable
4036from absl import app
4137from flax import nnx
@@ -641,4 +637,4 @@ def main(argv: Sequence[str]) -> None:
641637
642638
643639if __name__ == "__main__" :
644- app .run (main )
640+ app .run (main )
You can’t perform that action at this time.
0 commit comments