Skip to content

Commit 39f6e77

Browse files
author
SCAO Authors
committed
feat: add QLoRA 3B benchmark script against AdamW
1 parent 7590db4 commit 39f6e77

3 files changed

Lines changed: 296 additions & 0 deletions

File tree

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, BitsAndBytesConfig, DataCollatorForLanguageModeling
3+
from datasets import load_dataset
4+
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
5+
from scao import SCAO # Your 2nd-order optimizer implementation
6+
7+
def main():
8+
print("🚀 Starting 4B-Scale Benchmark for SCAO...")
9+
10+
# Using Qwen 2.5 3B model (optimal for the 4B category tests)
11+
model_id = "Qwen/Qwen2.5-3B"
12+
13+
tokenizer = AutoTokenizer.from_pretrained(model_id)
14+
if tokenizer.pad_token is None:
15+
tokenizer.pad_token = tokenizer.eos_token
16+
17+
print("📦 Loading base model in 4-bit (QLoRA) to optimize GPU memory usage...")
18+
bnb_config = BitsAndBytesConfig(
19+
load_in_4bit=True,
20+
bnb_4bit_use_double_quant=True,
21+
bnb_4bit_quant_type="nf4",
22+
bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
23+
)
24+
25+
model = AutoModelForCausalLM.from_pretrained(
26+
model_id,
27+
quantization_config=bnb_config,
28+
device_map="auto"
29+
)
30+
31+
print("🧠 Initializing LoRA adapters...")
32+
model = prepare_model_for_kbit_training(model)
33+
34+
lora_config = LoraConfig(
35+
r=16,
36+
lora_alpha=32,
37+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], # Focusing on Attention layers
38+
lora_dropout=0.05,
39+
bias="none",
40+
task_type="CAUSAL_LM"
41+
)
42+
model = get_peft_model(model, lora_config)
43+
44+
# Filter only parameters that require gradients
45+
trainable_params = [p for p in model.parameters() if p.requires_grad]
46+
print(f"🔥 Trainable parameters (LoRA): {sum(p.numel() for p in trainable_params):,}")
47+
48+
print("📚 Loading dataset...")
49+
# Using wikitext-2 for consistency across benchmarks
50+
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:5%]")
51+
52+
def tokenize(example):
53+
return tokenizer(example["text"], padding="max_length", truncation=True, max_length=256)
54+
55+
tokenized_dataset = dataset.map(tokenize, batched=True)
56+
57+
print("⚙️ Injecting SCAO Optimizer...")
58+
# SCAO uses 2nd-order information for faster convergence
59+
optimizer = SCAO(trainable_params, lr=2e-4) # Standard QLoRA learning rate
60+
61+
# Data collator for causal language modeling (automatically creates 'labels')
62+
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
63+
64+
args = TrainingArguments(
65+
output_dir="./scao_benchmark_4b_results",
66+
per_device_train_batch_size=2, # Small batch size to manage VRAM constraints
67+
gradient_accumulation_steps=4, # Effective batch size of 8
68+
max_steps=100, # 100 steps to evaluate performance and loss decay
69+
logging_steps=10,
70+
report_to="none",
71+
gradient_checkpointing=True, # Essential to avoid Out-Of-Memory errors
72+
optim="adamw_torch" # Placeholder; SCAO will override this if passed to Trainer
73+
)
74+
75+
trainer = Trainer(
76+
model=model,
77+
args=args,
78+
train_dataset=tokenized_dataset,
79+
data_collator=data_collator, # Pass the collator to handle sequence labeling
80+
# To fully utilize SCAO, pass it to the Trainer's optimizers argument:
81+
optimizers=(optimizer, None)
82+
)
83+
84+
print("⚡ Training active! Watch for the loss reduction curve...")
85+
trainer.train()
86+
87+
if __name__ == "__main__":
88+
main()
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
{
2+
"best_global_step": null,
3+
"best_metric": null,
4+
"best_model_checkpoint": null,
5+
"epoch": 0.4357298474945534,
6+
"eval_steps": 500,
7+
"global_step": 100,
8+
"is_hyper_param_search": false,
9+
"is_local_process_zero": true,
10+
"is_world_process_zero": true,
11+
"log_history": [
12+
{
13+
"epoch": 0.04357298474945534,
14+
"grad_norm": 0.8557016849517822,
15+
"learning_rate": 4.55e-05,
16+
"loss": 2.7660524368286135,
17+
"step": 10
18+
},
19+
{
20+
"epoch": 0.08714596949891068,
21+
"grad_norm": 1.0481969118118286,
22+
"learning_rate": 4.05e-05,
23+
"loss": 3.1643226623535154,
24+
"step": 20
25+
},
26+
{
27+
"epoch": 0.13071895424836602,
28+
"grad_norm": 0.599257230758667,
29+
"learning_rate": 3.55e-05,
30+
"loss": 2.9573230743408203,
31+
"step": 30
32+
},
33+
{
34+
"epoch": 0.17429193899782136,
35+
"grad_norm": 0.5881643295288086,
36+
"learning_rate": 3.05e-05,
37+
"loss": 2.8170230865478514,
38+
"step": 40
39+
},
40+
{
41+
"epoch": 0.2178649237472767,
42+
"grad_norm": 0.5822768211364746,
43+
"learning_rate": 2.5500000000000003e-05,
44+
"loss": 2.7181997299194336,
45+
"step": 50
46+
},
47+
{
48+
"epoch": 0.26143790849673204,
49+
"grad_norm": 0.8144065141677856,
50+
"learning_rate": 2.05e-05,
51+
"loss": 2.6438121795654297,
52+
"step": 60
53+
},
54+
{
55+
"epoch": 0.30501089324618735,
56+
"grad_norm": 0.6474636793136597,
57+
"learning_rate": 1.55e-05,
58+
"loss": 2.623094940185547,
59+
"step": 70
60+
},
61+
{
62+
"epoch": 0.3485838779956427,
63+
"grad_norm": 0.7236846089363098,
64+
"learning_rate": 1.05e-05,
65+
"loss": 2.6345943450927733,
66+
"step": 80
67+
},
68+
{
69+
"epoch": 0.39215686274509803,
70+
"grad_norm": 0.5381503105163574,
71+
"learning_rate": 5.500000000000001e-06,
72+
"loss": 2.6301015853881835,
73+
"step": 90
74+
},
75+
{
76+
"epoch": 0.4357298474945534,
77+
"grad_norm": 1.4788507223129272,
78+
"learning_rate": 5.000000000000001e-07,
79+
"loss": 2.453901672363281,
80+
"step": 100
81+
}
82+
],
83+
"logging_steps": 10,
84+
"max_steps": 100,
85+
"num_input_tokens_seen": 0,
86+
"num_train_epochs": 1,
87+
"save_steps": 500,
88+
"stateful_callbacks": {
89+
"TrainerControl": {
90+
"args": {
91+
"should_epoch_stop": false,
92+
"should_evaluate": false,
93+
"should_log": false,
94+
"should_save": true,
95+
"should_training_stop": true
96+
},
97+
"attributes": {}
98+
}
99+
},
100+
"total_flos": 3418701692928000.0,
101+
"train_batch_size": 2,
102+
"trial_name": null,
103+
"trial_params": null
104+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
{
2+
"best_global_step": null,
3+
"best_metric": null,
4+
"best_model_checkpoint": null,
5+
"epoch": 0.4357298474945534,
6+
"eval_steps": 500,
7+
"global_step": 100,
8+
"is_hyper_param_search": false,
9+
"is_local_process_zero": true,
10+
"is_world_process_zero": true,
11+
"log_history": [
12+
{
13+
"epoch": 0.04357298474945534,
14+
"grad_norm": 1.0032742023468018,
15+
"learning_rate": 0.000182,
16+
"loss": 2.718400001525879,
17+
"step": 10
18+
},
19+
{
20+
"epoch": 0.08714596949891068,
21+
"grad_norm": 1.1835830211639404,
22+
"learning_rate": 0.000162,
23+
"loss": 2.7992538452148437,
24+
"step": 20
25+
},
26+
{
27+
"epoch": 0.13071895424836602,
28+
"grad_norm": 0.7115559577941895,
29+
"learning_rate": 0.000142,
30+
"loss": 2.5605207443237306,
31+
"step": 30
32+
},
33+
{
34+
"epoch": 0.17429193899782136,
35+
"grad_norm": 0.8470057249069214,
36+
"learning_rate": 0.000122,
37+
"loss": 2.559454345703125,
38+
"step": 40
39+
},
40+
{
41+
"epoch": 0.2178649237472767,
42+
"grad_norm": 0.7798479199409485,
43+
"learning_rate": 0.00010200000000000001,
44+
"loss": 2.545095443725586,
45+
"step": 50
46+
},
47+
{
48+
"epoch": 0.26143790849673204,
49+
"grad_norm": 1.027970314025879,
50+
"learning_rate": 8.2e-05,
51+
"loss": 2.415896987915039,
52+
"step": 60
53+
},
54+
{
55+
"epoch": 0.30501089324618735,
56+
"grad_norm": 0.9205061197280884,
57+
"learning_rate": 6.2e-05,
58+
"loss": 2.448424530029297,
59+
"step": 70
60+
},
61+
{
62+
"epoch": 0.3485838779956427,
63+
"grad_norm": 0.8652524948120117,
64+
"learning_rate": 4.2e-05,
65+
"loss": 2.450721549987793,
66+
"step": 80
67+
},
68+
{
69+
"epoch": 0.39215686274509803,
70+
"grad_norm": 0.7723934054374695,
71+
"learning_rate": 2.2000000000000003e-05,
72+
"loss": 2.462067222595215,
73+
"step": 90
74+
},
75+
{
76+
"epoch": 0.4357298474945534,
77+
"grad_norm": 1.4185497760772705,
78+
"learning_rate": 2.0000000000000003e-06,
79+
"loss": 2.2699026107788085,
80+
"step": 100
81+
}
82+
],
83+
"logging_steps": 10,
84+
"max_steps": 100,
85+
"num_input_tokens_seen": 0,
86+
"num_train_epochs": 1,
87+
"save_steps": 500,
88+
"stateful_callbacks": {
89+
"TrainerControl": {
90+
"args": {
91+
"should_epoch_stop": false,
92+
"should_evaluate": false,
93+
"should_log": false,
94+
"should_save": true,
95+
"should_training_stop": true
96+
},
97+
"attributes": {}
98+
}
99+
},
100+
"total_flos": 3418701692928000.0,
101+
"train_batch_size": 2,
102+
"trial_name": null,
103+
"trial_params": null
104+
}

0 commit comments

Comments
 (0)