Skip to content

Commit 7f6d9ea

Browse files
Maxtwin xm support.
This is to support g3 internal distillation runs. Contains minor import and type hint fixes between github and g3 only environments. PiperOrigin-RevId: 925513738
1 parent df77ec3 commit 7f6d9ea

2 files changed

Lines changed: 2 additions & 3 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from maxtext.utils import max_logging
3333
from maxtext.utils import maxtext_utils
34-
# Reuse MaxText's native checkpointing logic
34+
# Reuse MaxText's native checkpointing logic.
3535
from maxtext.common.checkpointing import GrainCheckpointHandler, GrainCheckpointSave, GrainCheckpointRestore
3636
from tunix.sft import checkpoint_manager as tunix_checkpoint_manager
3737
from tunix.sft import peft_trainer

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,6 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
213213
(positions, segment_ids) are passed to the model.
214214
"""
215215

216-
checkpoint_manager: distillation_utils.MaxTextCheckpointManager | None
217-
218216
def __init__(
219217
self,
220218
model,
@@ -234,6 +232,7 @@ def __init__(
234232
super().__init__(model=model, optimizer=dummy_optimizer, training_config=training_config, **kwargs)
235233

236234
self.strategy = strategy
235+
self.checkpoint_manager: distillation_utils.MaxTextCheckpointManager = None
237236

238237
# Per-step per-device TFLOPs (constants for the run): student fwd+bwd + teacher fwd-only.
239238
(

0 commit comments

Comments
 (0)