Commit b341212

gagika authored and Google-ML-Automation committed
[tunix/sft] Avoid eager re-sharding of globally distributed arrays
Updates `PeftTrainer` to supply `data_sharding_axis` explicitly in `train_distill.py` to match MaxText's native sharding axis. Additionally adds a check in `sharding_utils.shard_input` that skips re-sharding of fully global, non-addressable arrays, preventing TPU memory addressability errors.

PiperOrigin-RevId: 902282275
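The guard described above can be sketched as follows. This is a minimal, hypothetical illustration (the function name `shard_input` matches the commit message, but its body here is an assumption, not the actual Tunix implementation): in multi-host JAX, a globally sharded array whose shards live on other processes is not fully addressable locally, and calling `jax.device_put` on it to re-shard would fail, so such arrays are returned unchanged.

```python
import jax


def shard_input(x, sharding):
    """Place `x` under `sharding`, skipping fully global non-addressable arrays.

    Sketch only: if `x` is a jax.Array whose shards are not all addressable
    from this process (a globally distributed array in multi-host training),
    re-sharding it eagerly would touch memory this host cannot address, so
    we leave it as-is and let downstream computation handle it.
    """
    if isinstance(x, jax.Array) and not x.is_fully_addressable:
        return x  # already globally sharded; avoid eager re-sharding
    return jax.device_put(x, sharding)
```

On a single host every array is fully addressable, so the function simply performs the `device_put`; the early return only triggers in multi-process runs.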
1 parent 1907615 commit b341212

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

tests/post_training/unit/train_distill_test.py

```diff
@@ -1027,6 +1027,7 @@ def test_main_offline_mode_skips_teacher_loading(
     mock_student_cfg.eval_interval = -1
     mock_student_cfg.gradient_accumulation_steps = 1
     mock_student_cfg.global_batch_size = 8
+    mock_student_cfg.data_sharding = ("fsdp",)

     # Add dummy numbers for strategy math/logic
     mock_student_cfg.distill_temperature = 1.0
@@ -1116,6 +1117,7 @@ def test_main_online_mode_loads_teacher(
     mock_student_cfg.eval_interval = -1
     mock_student_cfg.gradient_accumulation_steps = 1
     mock_student_cfg.global_batch_size = 8
+    mock_student_cfg.data_sharding = ("fsdp",)

     # Add dummy numbers for strategy math/logic
     mock_student_cfg.distill_temperature = 1.0
```
