Commit 1171f89

Add DeepSpeed finetune demo with MMLU/GSM8K benchmarks and Moonlight AutoEP configs
- Add unified finetune script (finetune_llama.py) with DATASET_REGISTRY supporting Alpaca, CodeAlpaca, Magicoder, MetaMathQA, MMLU, MBPP datasets
- Add sample_rate mechanism for dataset downsampling (MetaMathQA: 0.1)
- Add MMLU and GSM8K evaluation pipelines (vllm-based generation + scoring)
- Add Moonlight AutoEP ZeRO-2 configs (AdamW and Muon)
- Add end-to-end run_and_evaluate.sh supporting MBPP/MMLU/GSM8K benchmarks
- Add DeepSpeed checkpoint to HF model conversion with AutoEP/MoE support
- Update README with dataset registry details, benchmark usage, and configs

Signed-off-by: Guokai Ma <guokai.ma@gmail.com>
1 parent 45b4b71 commit 1171f89

23 files changed

Lines changed: 1737 additions & 0 deletions
Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
# DeepSpeed finetune examples
This finetune example is extracted and modified from the [ZenFlow Llama-2 Fine-Tuning Example](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/DeepSpeed-ZenFlow/finetuning) in [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples). Its purpose is to demonstrate how to use different DeepSpeed training features and compare their performance in a single place.

Currently in DeepSpeedExamples, each technology has a dedicated directory showing how to use it. However, DeepSpeed's philosophy is to let users switch between features by changing only the configuration file, with no code changes needed. This project puts that claim to the test.

# How to use

To run the example, simply run:
```
./finetune.sh <NUM_GPUS> <MODEL_NAME> <DS_CONFIG>
```

For example, to run the Qwen2.5-3B model with ZeRO offload on 2 GPUs:
```
./finetune.sh 2 Qwen2.5-3B configs/zo_config.json
```

## Key arguments

| Argument | Description | Default |
|----------|-------------|---------|
| `--batch_size` | Training batch size per GPU | required |
| `--eval_batch_size` | Eval batch size per rank | 4 |
| `--eval_steps` | Run evaluation every N steps (0 disables) | 0 |
| `--max_steps` | Stop after N steps (-1 = full epoch) | -1 |
| `--checkpoint_steps` | Save a checkpoint every N steps (0 disables); keeps last 2 | 0 |
| `--wandb_name` | Wandb run name (optional) | None |
| `--num_train_epochs` | Number of training epochs | 3 |
| `--weight_decay` | Weight decay | 0.01 |
| `--warmup` | Warmup ratio | 0.01 |

Note: Learning rate is controlled entirely by the DeepSpeed config JSON, not by command-line arguments.
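
As a rough illustration of the key arguments listed above, the following sketch builds an `argparse` parser with the same flags and defaults. The actual parser in `finetune_llama.py` is not shown in this README, so the argument types and the `build_parser` helper are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser whose flags and defaults mirror the key-arguments table."""
    parser = argparse.ArgumentParser(description="Fine-tuning arguments (sketch)")
    parser.add_argument("--batch_size", type=int, required=True,
                        help="Training batch size per GPU")
    parser.add_argument("--eval_batch_size", type=int, default=4,
                        help="Eval batch size per rank")
    parser.add_argument("--eval_steps", type=int, default=0,
                        help="Run evaluation every N steps (0 disables)")
    parser.add_argument("--max_steps", type=int, default=-1,
                        help="Stop after N steps (-1 = full epoch)")
    parser.add_argument("--checkpoint_steps", type=int, default=0,
                        help="Save a checkpoint every N steps (0 disables)")
    parser.add_argument("--wandb_name", type=str, default=None,
                        help="Wandb run name (optional)")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")
    parser.add_argument("--warmup", type=float, default=0.01,
                        help="Warmup ratio")
    return parser


# Only --batch_size is required; every other flag falls back to its default.
args = build_parser().parse_args(["--batch_size", "8"])
print(args.eval_batch_size, args.max_steps, args.wandb_name)
```

There is deliberately no learning-rate flag here, matching the note that the learning rate comes from the DeepSpeed config JSON.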

## Batch size
In DeepSpeed, the batch size is normally set in the configuration file. To avoid having to modify the config file, the Python script instead takes a `--batch_size` parameter and uses it to set the training batch size. Keep this in mind when you want to try different batch sizes.
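
One way such an override could work is sketched below. DeepSpeed expects `train_batch_size` to equal `train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size`; the `override_batch_size` helper is hypothetical, and exactly how `finetune_llama.py` maps `--batch_size` onto these keys is not shown here.

```python
def override_batch_size(ds_config: dict, batch_size_per_gpu: int, world_size: int) -> dict:
    """Return a copy of the config with batch-size keys derived from a per-GPU value.

    Assumption: --batch_size is treated as the micro batch size per GPU.
    """
    if batch_size_per_gpu <= 0 or world_size <= 0:
        raise ValueError("batch_size_per_gpu and world_size must be positive")
    cfg = dict(ds_config)  # shallow copy so the caller's dict is untouched
    grad_accum = cfg.get("gradient_accumulation_steps", 1)
    cfg["train_micro_batch_size_per_gpu"] = batch_size_per_gpu
    # DeepSpeed invariant: global batch = micro batch * grad accumulation * ranks.
    cfg["train_batch_size"] = batch_size_per_gpu * grad_accum * world_size
    return cfg


base = {"train_batch_size": 8, "gradient_accumulation_steps": 2}
patched = override_batch_size(base, batch_size_per_gpu=4, world_size=2)
print(patched["train_batch_size"], patched["train_micro_batch_size_per_gpu"])
```

Deriving the global value from the per-GPU value keeps the three keys consistent, so DeepSpeed's batch-size check should not reject the patched config.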

## Wandb support
An optional `--wandb_name` argument can be passed to `finetune_llama.py` to log the run to wandb. You need to modify `finetune.sh` manually to supply this argument.

## Dataset support

The training script uses a `DATASET_REGISTRY` to configure datasets. Registered datasets are loaded with proper field mapping and preprocessing automatically.

| Dataset | Format | Use Case | Notes |
|---------|--------|----------|-------|
| `sahil2801/CodeAlpaca-20k` | Alpaca | Code instruction tuning | |
| `meta-math/MetaMathQA` | Alpaca | Math reasoning | `sample_rate=0.1` (39.5k of 395k) |
| `cais/mmlu` | MMLU MCQ | Knowledge tasks | Uses `auxiliary_train` split (~95k) |
| `tatsu-lab/alpaca` | Alpaca | General instruction tuning | Fallback default |
| `ise-uiuc/Magicoder-OSS-Instruct-75K` | Magicoder | Code instruction tuning | Auto-detected via `problem` column |

**Registered datasets** are specified by `--dataset_name` directly. **Unregistered datasets** are auto-detected: if the dataset has a `problem` column, Magicoder format is used; otherwise Alpaca format is assumed.
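
The selection rule described above can be sketched as a small function. This is a minimal illustration, not the code in `finetune_llama.py`: the `detect_format` name, its signature, and the toy registry are assumptions.

```python
from typing import Iterable, Mapping


def detect_format(dataset_name: str,
                  column_names: Iterable[str],
                  registry: Mapping[str, Mapping]) -> str:
    """Pick a preprocessing format for a dataset (hypothetical helper)."""
    # Registered datasets use whatever their registry entry declares.
    if dataset_name in registry:
        return registry[dataset_name].get("preprocessor", "alpaca")
    # Unregistered datasets: a `problem` column signals Magicoder format.
    if "problem" in set(column_names):
        return "magicoder"
    # Everything else is assumed to follow the Alpaca layout.
    return "alpaca"


toy_registry = {"cais/mmlu": {"preprocessor": "mmlu"}}
print(detect_format("cais/mmlu", ["question", "choices", "answer"], toy_registry))
print(detect_format("some/unregistered", ["problem", "solution"], toy_registry))
print(detect_format("another/unregistered", ["instruction", "output"], toy_registry))
```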

All formats use instruction-masked loss (only the response part contributes to loss).
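
Instruction masking is commonly implemented by setting the label of every prompt token to `-100`, the default `ignore_index` of PyTorch's `CrossEntropyLoss`. The sketch below shows that idea on plain token-id lists; the `build_labels` helper is hypothetical and the script's own masking code may differ.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(prompt_ids, response_ids):
    """Concatenate prompt and response, masking prompt positions in the labels."""
    input_ids = list(prompt_ids) + list(response_ids)
    # Prompt tokens get IGNORE_INDEX, so only response tokens contribute to loss.
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


input_ids, labels = build_labels([11, 12, 13], [21, 22])
print(input_ids)
print(labels)
```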

### Adding a new dataset

Add an entry to `DATASET_REGISTRY` in `finetune_llama.py`:

```python
"your-dataset/name": {
    "split": "train",
    "preprocessor": "alpaca",       # or a custom preprocessor name
    "field_map": {                  # maps source fields to Alpaca format
        "instruction": "source_inst_field",
        "input": None,              # set to None if not present
        "output": "source_output_field",
    },
    "sample_rate": 0.1,             # optional: downsample large datasets
},
```
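
To make the `field_map` and `sample_rate` semantics concrete, here is a minimal sketch of how such an entry might be applied to raw records. Both helpers (`apply_field_map`, `downsample`) and the fixed seed are assumptions for illustration; the real preprocessing in `finetune_llama.py` is not shown here.

```python
import random


def apply_field_map(record: dict, field_map: dict) -> dict:
    """Rename source fields to Alpaca fields; a None source becomes an empty string."""
    return {
        target: "" if source is None else record.get(source, "")
        for target, source in field_map.items()
    }


def downsample(records: list, sample_rate: float, seed: int = 0) -> list:
    """Keep roughly sample_rate of the records, chosen reproducibly."""
    if not 0.0 < sample_rate <= 1.0:
        raise ValueError("sample_rate must be in (0, 1]")
    if not records:
        return []
    keep = max(1, round(len(records) * sample_rate))
    return random.Random(seed).sample(records, keep)


field_map = {"instruction": "source_inst_field", "input": None, "output": "source_output_field"}
raw = [{"source_inst_field": f"q{i}", "source_output_field": f"a{i}"} for i in range(1000)]
mapped = [apply_field_map(r, field_map) for r in downsample(raw, 0.1)]
print(len(mapped), mapped[0].keys())
```

With `sample_rate=0.1`, a 395k-example dataset would shrink to about 39.5k examples, which matches the MetaMathQA figure in the dataset table.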

# Moonlight-16B-A3B with AutoEP + Muon

This project supports fine-tuning [Moonlight-16B-A3B](https://huggingface.co/moonshotai/Moonlight-16B-A3B) (a 16B-parameter MoE model with 3B active parameters) using DeepSpeed AutoEP (automatic expert parallelism) and the Muon optimizer.

## Quick start (8x A100 40GB)

```bash
# 1. Train
deepspeed --num_gpus=8 finetune_llama.py \
    --model_name moonshotai/Moonlight-16B-A3B \
    --output_dir output_moonlight_muon \
    --batch_size 16 --max_length 512 \
    --deepspeed_config configs/z2_moonlight_autoep_muon.json \
    --dataset_name sahil2801/CodeAlpaca-20k \
    --num_train_epochs 1

# 2. Convert DeepSpeed checkpoint to HuggingFace format
python convert_ds_to_hf.py \
    --ds_checkpoint output_moonlight_muon/step_<LAST_STEP> \
    --original_model moonshotai/Moonlight-16B-A3B \
    --output_dir hf_model_muon \
    --ep_size 8

# 3. Generate HumanEval completions
python evaluate/humaneval/gen_humaneval.py \
    --model hf_model_muon \
    --output evalplus_results/muon \
    --instruction

# 4. Evaluate
python -m evalplus.evaluate \
    --dataset humaneval \
    --samples evalplus_results/muon/samples.jsonl
```

## Checkpoint format

With AutoEP, each rank holds a different expert shard. The training script saves checkpoints to `<output_dir>/step_<N>/`:
- `0/model_weights.pt`: full state dict (non-expert params + local experts for rank 0)
- `1/model_weights.pt` ... `7/model_weights.pt`: expert shard params only

Use `convert_ds_to_hf.py` to merge all shards back into a standard HuggingFace model.
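
Before converting, it can help to confirm that every rank's shard is present. The sketch below only enumerates the paths implied by the layout above; `shard_paths` and `missing_shards` are hypothetical helpers and are not part of `convert_ds_to_hf.py`.

```python
from pathlib import Path


def shard_paths(output_dir: str, step: int, ep_size: int) -> list[Path]:
    """List the per-rank weight files expected under <output_dir>/step_<N>/."""
    step_dir = Path(output_dir) / f"step_{step}"
    return [step_dir / str(rank) / "model_weights.pt" for rank in range(ep_size)]


def missing_shards(output_dir: str, step: int, ep_size: int) -> list[Path]:
    """Return the expected shard files that do not exist on disk."""
    return [p for p in shard_paths(output_dir, step, ep_size) if not p.is_file()]


# Example with a made-up step number; rank 0 holds the full state dict.
for path in shard_paths("output_moonlight_muon", step=100, ep_size=8):
    print(path)
```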

## HumanEval results

| Model | HumanEval (base) | HumanEval+ |
|-------|-----------------|------------|
| Moonlight-16B-A3B (baseline) | 46.3% | 40.2% |
| + Muon fine-tune on CodeAlpaca-20k (1 epoch) | 54.9% | 47.0% |
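
To put the percentages above in concrete terms, the following sketch converts them into absolute gains and approximate solved-problem counts. It assumes the scores are single-sample pass rates over the 164 problems in HumanEval; the `summarize` helper is illustrative only.

```python
HUMANEVAL_PROBLEMS = 164  # number of problems in HumanEval (and HumanEval+)

# (baseline %, fine-tuned %) taken from the results table.
scores = {
    "HumanEval (base)": (46.3, 54.9),
    "HumanEval+": (40.2, 47.0),
}


def summarize(baseline_pct: float, tuned_pct: float, total: int = HUMANEVAL_PROBLEMS):
    """Return (gain in points, baseline solved count, fine-tuned solved count)."""
    gain = round(tuned_pct - baseline_pct, 1)
    baseline_solved = round(baseline_pct / 100 * total)
    tuned_solved = round(tuned_pct / 100 * total)
    return gain, baseline_solved, tuned_solved


for name, (baseline, tuned) in scores.items():
    gain, before, after = summarize(baseline, tuned)
    print(f"{name}: +{gain} points, about {before} -> {after} of {HUMANEVAL_PROBLEMS} problems")
```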

## AutoEP config

AutoEP config goes inside the DeepSpeed JSON under `expert_parallel`:

```json
{
  "expert_parallel": {
    "enabled": true,
    "autoep_size": 8,
    "expert_w1": "gate_proj",
    "expert_w2": "down_proj",
    "expert_w3": "up_proj",
    "route_scale": 2.446,
    "load_balance_coeff": null
  }
}
```

| Parameter | Description |
|-----------|-------------|
| `autoep_size` | Number of expert-parallel ranks (typically = num_gpus) |
| `expert_w1/w2/w3` | Names of the expert weight projections in the HF model |
| `route_scale` | Router output scaling factor (should match `routed_scaling_factor` in model config) |
| `load_balance_coeff` | Auxiliary load-balancing loss coefficient (`null` to disable) |
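
The constraints in the table can be turned into a quick sanity check before launching a run. This is a sketch only: `check_autoep` is a hypothetical helper, its warnings simply restate the table, and DeepSpeed itself may validate the section differently.

```python
def check_autoep(ds_config: dict, num_gpus: int, routed_scaling_factor: float) -> list[str]:
    """Return human-readable warnings for an `expert_parallel` section."""
    ep = ds_config.get("expert_parallel") or {}
    if not ep.get("enabled", False):
        return ["expert_parallel is missing or disabled"]

    warnings = []
    # The README says autoep_size typically equals the number of GPUs.
    if ep.get("autoep_size") != num_gpus:
        warnings.append(f"autoep_size={ep.get('autoep_size')} differs from num_gpus={num_gpus}")
    # All three expert projection names must be provided.
    for key in ("expert_w1", "expert_w2", "expert_w3"):
        if not ep.get(key):
            warnings.append(f"{key} is not set")
    # route_scale should match routed_scaling_factor from the model config.
    scale = ep.get("route_scale")
    if scale is None or abs(scale - routed_scaling_factor) > 1e-6:
        warnings.append(f"route_scale={scale} does not match routed_scaling_factor={routed_scaling_factor}")
    return warnings


config = {
    "expert_parallel": {
        "enabled": True,
        "autoep_size": 8,
        "expert_w1": "gate_proj",
        "expert_w2": "down_proj",
        "expert_w3": "up_proj",
        "route_scale": 2.446,
        "load_balance_coeff": None,
    }
}
print(check_autoep(config, num_gpus=8, routed_scaling_factor=2.446))
```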

# Benchmarking

To run the benchmark, run:
```
./benchmark.sh <NUM_GPUS> <MODEL_NAME> <DS_CONFIG>
```

# Profiling

To run profiling, run:
```
./profile.sh <NUM_GPUS> <MODEL_NAME> <DS_CONFIG>
```

# Config files

Several config files are provided for a quick start; you can also modify them to fit your needs.

| Config File | Description |
|-------------|-------------|
| z2_config.json | ZeRO Stage 2 with AdamW |
| z3_config.json | ZeRO Stage 3 with AdamW |
| zo_config.json | ZeRO Offload, stage 2 |
| z3o_config.json | ZeRO Offload, stage 3 |
| zf_config.json | ZeRO Offload with ZenFlow |
| so_config.json | ZeRO Offload with SuperOffload |
| z2_muon.json | ZeRO 2 with Muon optimizer |
| z3_muon.json | ZeRO 3 with Muon optimizer |
| tp_config.json | ZeRO 2 with AutoTP |
| z2_moonlight_autoep_adam.json | Moonlight-16B-A3B with AutoEP + AdamW |
| z2_moonlight_autoep_muon.json | Moonlight-16B-A3B with AutoEP + Muon |
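
When comparing several of these configs, a one-line summary of each can be handy. The sketch below reads only keys that appear in the config files in this commit (`zero_optimization`, `offload_optimizer`, `tensor_parallel`, `expert_parallel`, `optimizer`); the `describe_config` helper itself is hypothetical.

```python
def describe_config(cfg: dict) -> str:
    """Summarize the main features enabled by a DeepSpeed config dict."""
    zero = cfg.get("zero_optimization", {})
    parts = [f"ZeRO stage {zero.get('stage', 0)}"]

    offload = zero.get("offload_optimizer")
    if offload:
        parts.append(f"optimizer offload to {offload.get('device', 'unknown')}")
    if "tensor_parallel" in cfg:
        parts.append(f"AutoTP size {cfg['tensor_parallel'].get('autotp_size')}")
    expert = cfg.get("expert_parallel", {})
    if expert.get("enabled"):
        parts.append(f"AutoEP size {expert.get('autoep_size')}")

    parts.append(f"{cfg.get('optimizer', {}).get('type', 'unspecified')} optimizer")
    return ", ".join(parts)


z2 = {"zero_optimization": {"stage": 2}, "optimizer": {"type": "AdamW"}}
print(describe_config(z2))
```

In practice the dict would come from `json.load` on one of the files under `configs/`.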

## Muon optimizer config

Muon is a hybrid optimizer: it applies Muon updates to 2D hidden weights and Adam to everything else. The config supports separate learning rates:

```json
{
  "optimizer": {
    "type": "Muon",
    "params": {
      "muon_lr": 1e-3,
      "adam_lr": 2e-5,
      "momentum": 0.95,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  }
}
```

| Parameter | Description |
|-----------|-------------|
| `muon_lr` | Learning rate for Muon (2D hidden weights) |
| `adam_lr` | Learning rate for Adam (embeddings, layer norms, lm_head, etc.) |
| `momentum` | Muon momentum factor |
| `betas` | Adam betas (for non-Muon parameters) |
| `eps` | Adam epsilon |
| `weight_decay` | Weight decay for both Muon and Adam parameters |
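
The split between Muon and Adam parameters can be pictured with a small sketch. The rule below (2D tensors go to Muon unless the name looks like an embedding or `lm_head`) follows the description in the table, but the name patterns and the `uses_muon` and `split_params` helpers are assumptions; DeepSpeed's actual selection logic is not shown here.

```python
def uses_muon(name: str, ndim: int) -> bool:
    """Route 2D hidden weights to Muon; embeddings and lm_head stay on Adam."""
    adam_only_tags = ("embed", "lm_head")  # hypothetical name patterns
    return ndim == 2 and not any(tag in name for tag in adam_only_tags)


def split_params(named_shapes):
    """Split (name, shape) pairs into Muon and Adam name lists."""
    muon, adam = [], []
    for name, shape in named_shapes:
        (muon if uses_muon(name, len(shape)) else adam).append(name)
    return muon, adam


example = [
    ("model.embed_tokens.weight", (32000, 2048)),
    ("model.layers.0.self_attn.q_proj.weight", (2048, 2048)),
    ("model.layers.0.input_layernorm.weight", (2048,)),
    ("lm_head.weight", (32000, 2048)),
]
muon_names, adam_names = split_params(example)
print("Muon (muon_lr):", muon_names)
print("Adam (adam_lr):", adam_names)
```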
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
NUM="${1:-2}"
MODEL="${2:-Qwen/Qwen2.5-0.5B}"
CONFIG="${3:-configs/z2_config.json}"
deepspeed --num_gpus=$NUM --bind_cores_to_rank finetune_llama.py --model_name $MODEL --output_dir output --batch_size 8 --deepspeed_config $CONFIG --num_train_epochs 1 --bench_start 4
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
{
  "train_batch_size": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "sub_group_size": 100000000,
    "offload_optimizer": {
      "super_offload": true,
      "convert_grad_on_cpu": true,
      "device": "cpu",
      "pin_memory": true
    }
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-5,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "zero_allow_untested_optimizer": true
}
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
{
  "train_batch_size": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 2
  },

  "tensor_parallel":{
    "autotp_size": 4
  },

  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-5,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "zero_allow_untested_optimizer": true
}
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
{
  "train_batch_size": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 2
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-5,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "zero_allow_untested_optimizer": true
}
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
{
  "train_batch_size": 16,
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-06,
      "betas": [0.9, 0.999],
      "eps": 1e-08,
      "weight_decay": 0.01
    }
  },
  "expert_parallel": {
    "enabled": true,
    "autoep_size": 8,
    "expert_w1": "gate_proj",
    "expert_w2": "down_proj",
    "expert_w3": "up_proj",
    "route_scale": 2.446,
    "load_balance_coeff": null
  },
  "gradient_accumulation_steps": 2,
  "gradient_clipping": 1.0,
  "zero_allow_untested_optimizer": true
}
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
{
  "train_batch_size": 16,
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2
  },
  "optimizer": {
    "type": "Muon",
    "params": {
      "muon_lr": 0.002,
      "adam_lr": 2e-06,
      "momentum": 0.95,
      "ns_method": "gram",
      "betas": [
        0.9,
        0.999
      ],
      "eps": 1e-08,
      "weight_decay": 0.01
    }
  },
  "expert_parallel": {
    "enabled": true,
    "autoep_size": 8,
    "expert_w1": "gate_proj",
    "expert_w2": "down_proj",
    "expert_w3": "up_proj",
    "route_scale": 2.446,
    "load_balance_coeff": null
  },
  "gradient_accumulation_steps": 2,
  "gradient_clipping": 1.0,
  "zero_allow_untested_optimizer": true
}
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
{
  "train_batch_size": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 2
  },
  "optimizer": {
    "type": "Muon",
    "params": {
      "muon_lr": 1e-3,
      "adam_lr": 2e-5,
      "momentum": 0.95,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "zero_allow_untested_optimizer": true
}
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
{
  "train_batch_size": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-5,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "zero_allow_untested_optimizer": true
}
