Commit b341212
[tunix/sft] Avoid eager re-sharding of globally distributed arrays
Updates `PeftTrainer` to supply `data_sharding_axis` explicitly in `train_distill.py` to match MaxText's native sharding axis.
Also adds a check in `sharding_utils.shard_input` that skips re-sharding of arrays that are already global and not fully addressable, which prevents TPU memory addressability errors.
PiperOrigin-RevId: 9022822751
parent 1907615 commit b341212
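The skip check described in the commit message can be sketched roughly as follows. This is a minimal illustration, not the code from this commit: `shard_input_sketch`, `is_global_unaddressable`, and `reshard_fn` are hypothetical names, and the batch is assumed to be a flat dict. The only JAX fact relied on is that `jax.Array` exposes an `is_fully_addressable` property. It is read through duck typing here, so the sketch runs without JAX installed.

```python
from typing import Any, Callable


def is_global_unaddressable(x: Any) -> bool:
    """Return True if x reports that some of its shards live on other hosts.

    jax.Array exposes `is_fully_addressable`. Objects without that attribute
    (NumPy arrays, Python scalars) are treated as fully addressable.
    """
    return not getattr(x, "is_fully_addressable", True)


def shard_input_sketch(batch: dict, reshard_fn: Callable[[Any], Any]) -> dict:
    """Hypothetical stand-in for `sharding_utils.shard_input`.

    Values that are already global and not fully addressable pass through
    unchanged. Every other value is handed to `reshard_fn`.
    """
    return {
        name: value if is_global_unaddressable(value) else reshard_fn(value)
        for name, value in batch.items()
    }
```

In a JAX program, `reshard_fn` might be something like `lambda x: jax.device_put(x, sharding)` for a chosen `NamedSharding`. That choice is likewise an assumption, as is the flat-dict batch; real code would more likely map over an arbitrary pytree.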
1 file changed
Lines changed: 2 additions & 0 deletions
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| 1027 | 1027 | |
| 1028 | 1028 | |
| 1029 | 1029 | |
| | 1030 | + |
| 1030 | 1031 | |
| 1031 | 1032 | |
| 1032 | 1033 | |
| | | |
| 1116 | 1117 | |
| 1117 | 1118 | |
| 1118 | 1119 | |
| | 1120 | + |
| 1119 | 1121 | |
| 1120 | 1122 | |
| 1121 | 1123 | |