Commit cfdb38e

Rename to train_sft_native.py
1 parent 3de17a7 commit cfdb38e

13 files changed

Lines changed: 11 additions & 35 deletions

docs/tutorials/posttraining/knowledge_distillation.md

Lines changed: 0 additions & 24 deletions
@@ -165,30 +165,6 @@ python3 -m tools.data_generation.generate_distillation_data_vllm \
 
 You can now fine-tune your smaller student model using supervised fine-tuning technique in MaxText.
 
-#### Fine-tune the student model using the generated dataset
-
-Example command to run fine-tuning on a TPU v6e-8:
-
-```bash
-python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated \
-run_name=${RUN_NAME?} \
-base_output_directory=${BASE_OUTPUT_DIRECTORY?}/distillation/qwen3-32b-distill-llama3.1-8b \
-tokenizer_path=meta-llama/Llama-3.1-8B-Instruct tokenizer_type=huggingface \
-dataset_type=hf \
-hf_path=parquet \
-hf_train_files=${OUTPUT_DATASET?} \
-train_split='train' \
-train_data_columns=['messages'] \
-load_parameters_path=${MAXTEXT_CKPT_PATH?}/0/items \
-model_name=llama3.1-8b \
-per_device_batch_size=2 \
-steps=200 \
-ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
-max_target_length=2048 \
-hf_access_token=${HF_TOKEN?} \
-profiler=xplane
-```
-
 #### **[OPTIONAL]** Fine-tune the student model using the original dataset
 
 The checkpoint from the student model's fine-tuning (on the teacher-generated dataset) can be used for a subsequent fine-tuning stage. In this step, the student model is fine-tuned on the original dataset that was initially provided to the teacher model for generating the dataset.

docs/tutorials/posttraining/multimodal.md

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ Here, we use [ChartQA](https://huggingface.co/datasets/HuggingFaceM4/ChartQA) as
 export MAXTEXT_CKPT_PATH=... # either set to an already available MaxText ckpt or to the one we just converted in the previous step
 export BASE_OUTPUT_DIRECTORY=gs://...
 export STEPS=1000
-python -m maxtext.trainers.post_train.sft.train_sft_deprecated \
+python -m maxtext.trainers.post_train.sft.train_sft_native \
 src/maxtext/configs/post_train/sft-vision-chartqa.yml \
 run_name="chartqa-sft" \
 model_name=gemma3-4b \

src/maxtext/configs/pyconfig.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@
 "maxtext.trainers.post_train.dpo.train_dpo": "post_train/dpo.yml",
 "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
 "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
-"maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
+"maxtext.trainers.post_train.sft.train_sft_native": "post_train/sft.yml",
 "maxtext.inference.decode": "base.yml",
 "maxtext.inference.decode_multi": "base.yml",
 "maxtext.inference.inference_microbenchmark": "base.yml",
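
Reviewer note: the hunk above suggests that `pyconfig.py` keeps a table from module path to default config file, which is why the rename has to touch it. The following is a minimal sketch of how such a lookup could behave. The names `MODULE_TO_DEFAULT_CONFIG` and `default_config_for` are hypothetical and are not taken from MaxText; only the dictionary entries come from the diff.

```python
# Hypothetical sketch: the variable and function names are assumptions,
# only the key/value pairs below appear in the diff.
MODULE_TO_DEFAULT_CONFIG = {
    "maxtext.trainers.post_train.dpo.train_dpo": "post_train/dpo.yml",
    "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
    "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
    "maxtext.trainers.post_train.sft.train_sft_native": "post_train/sft.yml",
    "maxtext.inference.decode": "base.yml",
}


def default_config_for(module_name: str) -> str:
    """Return the default config registered for a module path."""
    try:
        return MODULE_TO_DEFAULT_CONFIG[module_name]
    except KeyError:
        # A stale module path (for example the pre-rename name) ends up here.
        raise KeyError(f"No default config registered for {module_name!r}") from None
```

Under this assumption, a command that still invokes the old `train_sft_deprecated` path would no longer find a default config, so every caller needs the new module name.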

src/maxtext/examples/multimodal_gemma3_demo.ipynb

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@
 "STEPS=10\n",
 "PER_DEVICE_BATCH_SIZE=1\n",
 "\n",
-"!python -m maxtext.trainers.post_train.sft.train_sft_deprecated \\\n",
+"!python -m maxtext.trainers.post_train.sft.train_sft_native \\\n",
 " $MAXTEXT_CONFIGS_DIR/sft-vision-chartqa.yml \\\n",
 " run_name=$WORKLOAD_NAME \\\n",
 " model_name=$MODEL_NAME \\\n",

src/maxtext/trainers/post_train/sft/train_sft_deprecated.py renamed to src/maxtext/trainers/post_train/sft/train_sft_native.py

File renamed without changes.
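
Reviewer note: because the module file is renamed with no content changes, any leftover reference to the old name would fail at import time. The following is a small sketch of a check for stale references. The helper name `find_stale_references` and the set of file suffixes are assumptions chosen for illustration; they are not part of the commit.

```python
from pathlib import Path

OLD_NAME = "train_sft_deprecated"
# Assumed set of file types worth scanning; adjust for the repository.
SUFFIXES = {".py", ".md", ".sh", ".yml", ".ipynb"}


def find_stale_references(root, old_name=OLD_NAME, suffixes=SUFFIXES):
    """Return (path, line_number, line) tuples for lines that mention old_name."""
    hits = []
    for path in sorted(Path(root).rglob("*")):
        # Skip directories and file types outside the assumed suffix set.
        if not path.is_file() or path.suffix not in suffixes:
            continue
        text = path.read_text(encoding="utf-8", errors="ignore")
        for number, line in enumerate(text.splitlines(), start=1):
            if old_name in line:
                hits.append((str(path), number, line.strip()))
    return hits
```

Calling `find_stale_references(".")` from the repository root and printing each tuple gives a list of files that may still need the rename.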

tests/end_to_end/tpu/deepseek/Run_DeepSeek.md

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
 One example command to run supervised finetuning with V3 on v5p-256. Supervised fine-tuning is only working with HuggingFace conversational datasets. And, you can customize the dataset path using the `hf_path` config and provide your access token with `hf_access_token` config.
 
 ```sh
-python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/configs/post_train/sft.yml \
+python3 -m maxtext.trainers.post_train.sft.train_sft_native src/maxtext/configs/post_train/sft.yml \
 base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
 load_parameters_path=${SCANNED_CKPT_PATH?} \
 run_name=matmul_supervised_fine_tuning \

tests/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:
 
 # 3. SFT the MaxText converted checkpoint on ChartQA dataset
 export BASE_OUTPUT_DIRECTORY=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/sft
-python -m maxtext.trainers.post_train.sft.train_sft_deprecated "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft-vision-chartqa.yml \
+python -m maxtext.trainers.post_train.sft.train_sft_native "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft-vision-chartqa.yml \
 run_name=$idx \
 model_name=$MODEL_NAME tokenizer_path="google/gemma-3-4b-pt" \
 per_device_batch_size=1 \

tests/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_RE
 python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32
 
 # Run supervised fine-tuning - megablox implementation
-python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32
+python3 -m maxtext.trainers.post_train.sft.train_sft_native "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32
 
 # Run decoding - megablox implementation
 # Note decode requires the access token for huggingface tokenizer even if the model is not gated

tests/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_RE
 
 # Run supervised fine-tuning - megablox implementation
 # TODO: remove `abort_on_nan_loss=false` after b/497864549
-python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 gcs_metrics=true abort_on_nan_loss=false
+python3 -m maxtext.trainers.post_train.sft.train_sft_native "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 gcs_metrics=true abort_on_nan_loss=false
 
 # Run decoding - megablox implementation
 # Note decode requires the access token for huggingface tokenizer even if the model is not gated

tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
 One example command to run supervised finetuning with gpt-oss-20b on v5p-8. Supervised finetuning is only working with HuggingFace conversational datasets. And, you can customize the dataset path using the `hf_path` config. If using [gated dataset](https://huggingface.co/docs/hub/en/datasets-gated) or [gated model](https://huggingface.co/docs/hub/en/models-gated), you need additionally provide the access token with `hf_access_token` config.
 
 ```sh
-python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/configs/post_train/sft.yml \
+python3 -m maxtext.trainers.post_train.sft.train_sft_native src/maxtext/configs/post_train/sft.yml \
 base_output_directory=${BASE_OUTPUT_PATH?} \
 run_name=megablox_supervised_fine_tuning \
 model_name=gpt-oss-20b \