Skip to content

Commit 1fc0699

Browse files
committed
Removed the need to hard-code the ArrayRecord file path; the path is now read directly from the config
1 parent 6bb64d0 commit 1fc0699

3 files changed

Lines changed: 5 additions & 16 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,15 +1070,13 @@ class Distillation(BaseModel):
10701070
default_factory=dict,
10711071
description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
10721072
)
1073-
1073+
10741074
# --- Offline Distillation Fields ---
10751075
offline_distillation: bool = Field(
1076-
False,
1077-
description="If True, enables offline distillation using pre-generated teacher data."
1076+
False, description="If True, enables offline distillation using pre-generated teacher data."
10781077
)
10791078
offline_data_dir: Optional[str] = Field(
1080-
None,
1081-
description="GCS or local path to the pre-generated ArrayRecord teacher data."
1079+
None, description="GCS or local path to the pre-generated ArrayRecord teacher data."
10821080
)
10831081

10841082
# --- Loss Params ---

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
model structures with Tunix's training interfaces.
1919
"""
2020

21-
import os
2221
import pickle
2322
import tensorflow as tf
2423
from array_record.python import array_record_module
@@ -82,11 +81,7 @@ class OfflineArrayRecordIterator:
8281
"""Reads the pre-generated global top-k logits file."""
8382

8483
def __init__(self, data_dir: str, epochs: int = 100):
85-
# Check if the user passed a directory or a direct file path
86-
if tf.io.gfile.isdir(data_dir):
87-
self.filepath = os.path.join(data_dir, "teacher_top_k_global.array_record")
88-
else:
89-
self.filepath = data_dir
84+
self.filepath = data_dir
9085

9186
if not tf.io.gfile.exists(self.filepath):
9287
raise FileNotFoundError(f"Offline distillation file not found: {self.filepath}")

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@
3232
3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose
3333
a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
3434
"""
35-
import argparse
36-
import functools
37-
import sys
38-
3935
from typing import Sequence, Callable
4036
from absl import app
4137
from flax import nnx
@@ -641,4 +637,4 @@ def main(argv: Sequence[str]) -> None:
641637

642638

643639
if __name__ == "__main__":
644-
app.run(main)
640+
app.run(main)

0 commit comments

Comments (0)