4 changes: 3 additions & 1 deletion apps/sft/llama3_8b.yaml
@@ -26,12 +26,14 @@ optimizer:
 lr_scheduler:
   warmup_steps: 200
 
+compile:
+  enable: false
+
 training:
   local_batch_size: 8
   seq_len: 2048
   max_norm: 1.0
   steps: 1000
-  compile: false
   datasets:
     - path: "yahma/alpaca-cleaned"
       split: "train[:95%]"
6 changes: 3 additions & 3 deletions apps/sft/main.py
@@ -96,11 +96,11 @@ def record_batch_metrics(self, data_metrics: list):
     @endpoint
     async def setup(self):
         # Validate that compile is only used with flex attention
-        if self.job_config.training.compile:
+        if self.job_config.compile.enable:
             raise ValueError(
-                "training.compile=True is not currently supported. "
+                "compile.enable=True is not currently supported. "
                 "Compile is only supported with flex attention enabled, which requires PyTorch nightly. "
-                "Please set training.compile=false in your config."
+                "Please set compile.enable=false in your config."
             )
 
         # all ranks should record loss, except when PP=True. Then, only the last stage should record loss.
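
Note: the guard in this hunk can be exercised in isolation with a minimal sketch like the one below. The `CompileConfig` and `JobConfig` dataclasses and the `validate_compile` helper are hypothetical stand-ins written for illustration; they are not the project's actual config classes, and only the `compile.enable` field name and the error wording are taken from the diff.

```python
from dataclasses import dataclass, field


@dataclass
class CompileConfig:
    # Hypothetical mirror of the new top-level `compile:` YAML section.
    enable: bool = False


@dataclass
class JobConfig:
    # Hypothetical stand-in for the job config object read in setup().
    compile: CompileConfig = field(default_factory=CompileConfig)


def validate_compile(job_config: JobConfig) -> None:
    """Reject configs that turn compile on, as the guard in setup() does."""
    if job_config.compile.enable:
        raise ValueError(
            "compile.enable=True is not currently supported. "
            "Please set compile.enable=false in your config."
        )


# The default config (compile.enable=false) passes silently.
validate_compile(JobConfig())
```

Under these assumptions, a config built with `CompileConfig(enable=True)` raises `ValueError`, which corresponds to setting `enable: true` under `compile:` in either YAML file.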
4 changes: 3 additions & 1 deletion apps/sft/qwen3_8b.yaml
@@ -25,12 +25,14 @@ optimizer:
 lr_scheduler:
   warmup_steps: 200
 
+compile:
+  enable: false
+
 training:
   local_batch_size: 8
   seq_len: 2048
   max_norm: 1.0
   steps: 1000
-  compile: false
   datasets:
     - path: "yahma/alpaca-cleaned"
       split: "train[:95%]"