Skip to content

Commit bc2f020

Browse files
committed
Add unified DeepSpeed finetune demo with MMLU/GSM8K benchmarks
This PR adds a comprehensive finetuning example that demonstrates DeepSpeed's philosophy: use different training features via config files with no code change needed. Features: - Dataset registry with automatic format detection (Alpaca, Magicoder, MMLU, etc.) - sample_rate support for downsampling large datasets - MMLU and GSM8K benchmark evaluation scripts (vLLM-based generation + scoring) - Auto-detection of flash_attn availability - DistributedSampler for multi-GPU training - Checkpoint conversion (DeepSpeed -> HuggingFace) with AutoEP support - RunPod environment setup guide (LAB.md) - 9 DeepSpeed config variants (ZeRO-2/3, Offload, ZenFlow, SuperOffload, Muon, AutoTP) Tested on Qwen2.5-0.5B with 2x RTX 4090: - GSM8K baseline: 28.43% | MMLU baseline: 33.12% - Full pipeline: train -> convert checkpoint -> vLLM eval
1 parent 45b4b71 commit bc2f020

21 files changed

Lines changed: 1649 additions & 0 deletions
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# DeepSpeed finetune examples
2+
This finetune example is extracted and modified from [ZenFlow Llama-2 Fine-Tuning Example](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/DeepSpeed-ZenFlow/finetuning) in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples). The purpose is to demostrate how to use different DeepSpeed training features and compare their performance in a single place.
3+
4+
Currently in DeepSpeedExamples, each technology has a dedicated directory to show how to use it. However, DeepSpeed's philosophy is to allow users to use different features with different configuration file with no code change needed. This project put this claim to the test.
5+
6+
# How to use
7+
8+
To run the example, simply run:
9+
```
10+
./finetune.sh <NUM_GPUS> <MODEL_NAME> <DS_CONFIG>
11+
```
12+
13+
For example, if we want to run Qwen2.5-3B model with ZeRO offload on 2 GPUs, we can run:
14+
```
15+
./finetune.sh 2 Qwen2.5-3B configs/zo_config.json
16+
```
17+
18+
## Key arguments
19+
20+
| Argument | Description | Default |
21+
|----------|-------------|---------|
22+
| `--batch_size` | Training batch size per GPU | required |
23+
| `--eval_batch_size` | Eval batch size per rank | 1 |
24+
| `--eval_steps` | Run evaluation every N steps (0 disables) | 0 |
25+
| `--max_steps` | Stop after N steps (-1 = full epoch) | -1 |
26+
| `--checkpoint_steps` | Save a checkpoint every N steps (0 disables); keeps last 2 | 0 |
27+
| `--wandb_name` | Wandb run name (optional) | None |
28+
| `--num_train_epochs` | Number of training epochs | 1 |
29+
| `--weight_decay` | Weight decay | 0.01 |
30+
| `--warmup` | Warmup steps | 0 |
31+
32+
Note: Learning rate is controlled entirely by the DeepSpeed config JSON, not by command-line arguments.
33+
34+
## Batch size
35+
In DeepSpeed, batch size is decided by configuration file. However, to avoid modify the config file, this python script takes `--batch_size` parameter and use it to decide train batch size. Keep this in mind if you need to try different batch size.
36+
37+
## Wandb support
38+
An optional `--wandb_name` can be supplied to finetune_llama.py to generate wandb graph. But you need to modify `finetune.sh` manually to supply this argument.
39+
40+
## Dataset support
41+
42+
The training script auto-detects the dataset format:
43+
44+
- **Alpaca format** (default): datasets with `instruction`/`input`/`output` fields (e.g., `sahil2801/CodeAlpaca-20k`, `tatsu-lab/alpaca`)
45+
- **Magicoder format**: datasets with `problem`/`solution` fields (e.g., `ise-uiuc/Magicoder-OSS-Instruct-75K`)
46+
47+
Both formats use instruction-masked loss (only the response part contributes to loss).
48+
49+
# Moonlight-16B-A3B with AutoEP + Muon
50+
51+
This project supports fine-tuning [Moonlight-16B-A3B](https://huggingface.co/moonshotai/Moonlight-16B-A3B) (a 16B-parameter MoE model with 3B active parameters) using DeepSpeed AutoEP (automatic expert parallelism) and the Muon optimizer.
52+
53+
## Quick start (8x A100 40GB)
54+
55+
```bash
56+
# 1. Train
57+
deepspeed --num_gpus=8 finetune_llama.py \
58+
--model_name moonshotai/Moonlight-16B-A3B \
59+
--output_dir output_moonlight_muon \
60+
--batch_size 16 --max_length 512 \
61+
--deepspeed_config configs/z2_moonlight_autoep_muon.json \
62+
--dataset_name sahil2801/CodeAlpaca-20k \
63+
--num_train_epochs 1
64+
65+
# 2. Convert DeepSpeed checkpoint to HuggingFace format
66+
python convert_ds_to_hf.py \
67+
--ds_checkpoint output_moonlight_muon/step_<LAST_STEP> \
68+
--original_model moonshotai/Moonlight-16B-A3B \
69+
--output_dir hf_model_muon \
70+
--ep_size 8
71+
72+
# 3. Generate HumanEval completions
73+
python evaluate/humaneval/gen_humaneval.py \
74+
--model hf_model_muon \
75+
--output evalplus_results/muon \
76+
--instruction
77+
78+
# 4. Evaluate
79+
python -m evalplus.evaluate \
80+
--dataset humaneval \
81+
--samples evalplus_results/muon/samples.jsonl
82+
```
83+
84+
## Checkpoint format
85+
86+
With AutoEP, each rank holds a different expert shard. The training script saves checkpoints to `<output_dir>/step_<N>/`:
87+
- `0/model_weights.pt`: full state dict (non-expert params + local experts for rank 0)
88+
- `1/model_weights.pt` ... `7/model_weights.pt`: expert shard params only
89+
90+
Use `convert_ds_to_hf.py` to merge all shards back into a standard HuggingFace model.
91+
92+
## HumanEval results
93+
94+
| Model | HumanEval (base) | HumanEval+ |
95+
|-------|-----------------|------------|
96+
| Moonlight-16B-A3B (baseline) | 46.3% | 40.2% |
97+
| + Muon fine-tune on CodeAlpaca-20k (1 epoch) | 54.9% | 47.0% |
98+
99+
## AutoEP config
100+
101+
AutoEP config goes inside the DeepSpeed JSON under `expert_parallel`:
102+
103+
```json
104+
{
105+
"expert_parallel": {
106+
"enabled": true,
107+
"autoep_size": 8,
108+
"expert_w1": "gate_proj",
109+
"expert_w2": "down_proj",
110+
"expert_w3": "up_proj",
111+
"route_scale": 2.446,
112+
"load_balance_coeff": null
113+
}
114+
}
115+
```
116+
117+
| Parameter | Description |
118+
|-----------|-------------|
119+
| `autoep_size` | Number of expert-parallel ranks (typically = num_gpus) |
120+
| `expert_w1/w2/w3` | Names of the expert weight projections in the HF model |
121+
| `route_scale` | Router output scaling factor (should match `routed_scaling_factor` in model config) |
122+
| `load_balance_coeff` | Auxiliary load-balancing loss coefficient (`null` to disable) |
123+
124+
Note: `route_scale` and expert group settings can be auto-filled from the HF model config if using DeepSpeed branch `gma/autoep-muon-fixes`.
125+
126+
# Benchmarking
127+
128+
To run benchmark, run:
129+
```
130+
./benchmark.sh <NUM_GPUS> <MODEL_NAME> <DS_CONFIG>
131+
```
132+
133+
# Profiling
134+
135+
To run profiling, run:
136+
```
137+
./profile.sh <NUM_GPUS> <MODEL_NAME> <DS_CONFIG>
138+
```
139+
140+
# Config files
141+
142+
For quick start, some config files are added, you may also modify the config to fit your need.
143+
144+
| Config File | Description |
145+
|-------------|-------------|
146+
| z2_config.json | ZeRO Stage 2 with AdamW |
147+
| z3_config.json | ZeRO Stage 3 with AdamW |
148+
| zo_config.json | ZeRO Offload, stage 2 |
149+
| z3o_config.json | ZeRO Offload, stage 3 |
150+
| zf_config.json | ZeRO Offload with ZenFlow |
151+
| so_config.json | ZeRO Offload with SuperOffload |
152+
| z2_muon.json | ZeRO 2 with Muon optimizer |
153+
| z3_muon.json | ZeRO 3 with Muon optimizer |
154+
| tp_config.json | ZeRO 2 with AutoTP |
155+
| z2_moonlight_autoep_adam.json | Moonlight-16B-A3B with AutoEP + AdamW |
156+
| z2_moonlight_autoep_muon.json | Moonlight-16B-A3B with AutoEP + Muon |
157+
158+
## Muon optimizer config
159+
160+
Muon is a hybrid optimizer: it applies Muon updates to 2D hidden weights and Adam to everything else. The config supports separate learning rates:
161+
162+
```json
163+
{
164+
"optimizer": {
165+
"type": "Muon",
166+
"params": {
167+
"muon_lr": 1e-3,
168+
"adam_lr": 2e-5,
169+
"momentum": 0.95,
170+
"betas": [0.9, 0.999],
171+
"eps": 1e-8,
172+
"weight_decay": 0.01
173+
}
174+
}
175+
}
176+
```
177+
178+
| Parameter | Description |
179+
|-----------|-------------|
180+
| `muon_lr` | Learning rate for Muon (2D hidden weights) |
181+
| `adam_lr` | Learning rate for Adam (embeddings, layer norms, lm_head, etc.) |
182+
| `momentum` | Muon momentum factor |
183+
| `betas` | Adam betas (for non-Muon parameters) |
184+
| `eps` | Adam epsilon |
185+
| `weight_decay` | Weight decay for both Muon and Adam parameters |
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
NUM="${1:-2}"
2+
MODEL="${2:-Qwen/Qwen2.5-0.5B}"
3+
CONFIG="${3:-configs/z2_config.json}"
4+
deepspeed --num_gpus=$NUM --bind_cores_to_rank finetune_llama.py --model_name $MODEL --output_dir output --batch_size 8 --deepspeed_config $CONFIG --num_train_epochs 1 --bench_start 4
5+
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
{
2+
"train_batch_size": 8,
3+
"bf16": { "enabled": true },
4+
"zero_optimization": {
5+
"stage": 3,
6+
"sub_group_size": 100000000,
7+
"offload_optimizer": {
8+
"super_offload": true,
9+
"convert_grad_on_cpu": true,
10+
"device": "cpu",
11+
"pin_memory": true
12+
}
13+
},
14+
"optimizer": {
15+
"type": "AdamW",
16+
"params": {
17+
"lr": 2e-5,
18+
"betas": [0.9, 0.999],
19+
"eps": 1e-8,
20+
"weight_decay": 0.01
21+
}
22+
},
23+
"gradient_accumulation_steps": 1,
24+
"gradient_clipping": 1.0,
25+
"zero_allow_untested_optimizer": true
26+
}
27+
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"train_batch_size": 8,
3+
"bf16": { "enabled": true },
4+
"zero_optimization": {
5+
"stage": 2
6+
},
7+
8+
"tensor_parallel":{
9+
"autotp_size": 4
10+
},
11+
12+
"optimizer": {
13+
"type": "AdamW",
14+
"params": {
15+
"lr": 2e-5,
16+
"betas": [0.9, 0.999],
17+
"eps": 1e-8,
18+
"weight_decay": 0.01
19+
}
20+
},
21+
"gradient_accumulation_steps": 1,
22+
"gradient_clipping": 1.0,
23+
"zero_allow_untested_optimizer": true
24+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"train_batch_size": 8,
3+
"bf16": { "enabled": true },
4+
"zero_optimization": {
5+
"stage": 2
6+
},
7+
"optimizer": {
8+
"type": "AdamW",
9+
"params": {
10+
"lr": 2e-5,
11+
"betas": [0.9, 0.999],
12+
"eps": 1e-8,
13+
"weight_decay": 0.01
14+
}
15+
},
16+
"gradient_accumulation_steps": 1,
17+
"gradient_clipping": 1.0,
18+
"zero_allow_untested_optimizer": true
19+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"train_batch_size": 8,
3+
"bf16": { "enabled": true },
4+
"zero_optimization": {
5+
"stage": 2
6+
},
7+
"optimizer": {
8+
"type": "Muon",
9+
"params": {
10+
"muon_lr": 1e-3,
11+
"adam_lr": 2e-5,
12+
"momentum": 0.95,
13+
"betas": [0.9, 0.999],
14+
"eps": 1e-8,
15+
"weight_decay": 0.01
16+
}
17+
},
18+
"gradient_accumulation_steps": 1,
19+
"gradient_clipping": 1.0,
20+
"zero_allow_untested_optimizer": true
21+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"train_batch_size": 8,
3+
"bf16": { "enabled": true },
4+
"zero_optimization": {
5+
"stage": 3
6+
},
7+
"optimizer": {
8+
"type": "AdamW",
9+
"params": {
10+
"lr": 2e-5,
11+
"betas": [0.9, 0.999],
12+
"eps": 1e-8,
13+
"weight_decay": 0.01
14+
}
15+
},
16+
"gradient_accumulation_steps": 1,
17+
"gradient_clipping": 1.0,
18+
"zero_allow_untested_optimizer": true
19+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"train_batch_size": 8,
3+
"bf16": { "enabled": true },
4+
"zero_optimization": {
5+
"stage": 3,
6+
"reduce_scatter": false
7+
},
8+
"optimizer": {
9+
"type": "Muon",
10+
"params": {
11+
"muon_lr": 1e-3,
12+
"adam_lr": 2e-5,
13+
"momentum": 0.95,
14+
"betas": [0.9, 0.999],
15+
"eps": 1e-8,
16+
"weight_decay": 0.01
17+
}
18+
},
19+
"gradient_accumulation_steps": 1,
20+
"gradient_clipping": 1.0,
21+
"zero_allow_untested_optimizer": true
22+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"train_batch_size": 8,
3+
"bf16": { "enabled": true },
4+
"zero_optimization": {
5+
"stage": 3,
6+
"offload_optimizer": {
7+
"device": "cpu",
8+
"pin_memory": true
9+
}
10+
},
11+
"optimizer": {
12+
"type": "AdamW",
13+
"params": {
14+
"lr": 2e-5,
15+
"betas": [0.9, 0.999],
16+
"eps": 1e-8,
17+
"weight_decay": 0.01
18+
}
19+
},
20+
"gradient_accumulation_steps": 1,
21+
"gradient_clipping": 1.0,
22+
"zero_allow_untested_optimizer": true
23+
}
24+
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"train_batch_size": 8,
3+
"bf16": { "enabled": true },
4+
"zero_optimization": {
5+
"stage": 2,
6+
"offload_optimizer": {
7+
"device": "cpu",
8+
"pin_memory": true
9+
},
10+
"zenflow": {
11+
"topk_ratio": 0.1,
12+
"update_interval": 4,
13+
"full_warm_up_rounds": 0,
14+
"overlap_step": true
15+
}
16+
},
17+
"optimizer": {
18+
"type": "AdamW",
19+
"params": {
20+
"lr": 2e-5,
21+
"betas": [0.9, 0.999],
22+
"eps": 1e-8,
23+
"weight_decay": 0.01
24+
}
25+
},
26+
"gradient_accumulation_steps": 1,
27+
"gradient_clipping": 1.0,
28+
"zero_allow_untested_optimizer": true
29+
}
30+

0 commit comments

Comments
 (0)