Commit 837a0dc

TimDettmers and claude committed
script: Add end-to-end training validation for Qwen3-30B-A3B
Trains with both streaming and non-streaming paths, compares loss curves, tests LoRA save/reload.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 835139f commit 837a0dc

File tree

1 file changed: +213 additions, 0 deletions

scripts/train_qwen3_30b.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
"""End-to-end training validation for Qwen3-30B-A3B.

Trains with both streaming (from_quantized) and non-streaming (standard) paths
and compares loss curves. Success criterion: loss must match within 5% per step.
"""

import json
import os
import time

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from bitsandbytes.checkpoint import save_quantized, save_lora, load_lora
from bitsandbytes.kbit_lora import KbitLoraModel


def prepare_data(tokenizer, n_samples=200, max_len=256):
    """Load Alpaca and tokenize."""
    ds = load_dataset("tatsu-lab/alpaca", split="train")
    ds = ds.select(range(n_samples))

    all_input_ids = []
    all_labels = []
    for example in ds:
        text = example["text"]
        tokens = tokenizer(text, truncation=True, max_length=max_len, return_tensors="pt")
        input_ids = tokens["input_ids"][0]
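        # Skip very short samples (fewer than 10 tokens).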
        if len(input_ids) < 10:
            continue
        all_input_ids.append(input_ids)
        all_labels.append(input_ids.clone())

    return all_input_ids, all_labels


def train_streaming(model, input_ids_list, labels_list, n_steps=100, lr=1e-4):
    """Train with forward_streaming / backward_streaming."""
    optimizer = torch.optim.AdamW(
        [p for p in model._lora_params.parameters() if p.requires_grad],
        lr=lr,
    )
    # Also add norm params
    norm_params = [p for p in model.parameters() if p.requires_grad and p not in set(model._lora_params.parameters())]
    if norm_params:
        optimizer.add_param_group({"params": norm_params, "lr": lr})

    losses = []
    t0 = time.time()
    for step in range(n_steps):
        idx = step % len(input_ids_list)
        input_ids = input_ids_list[idx].unsqueeze(0).cuda()
        labels = labels_list[idx].unsqueeze(0).cuda()

        optimizer.zero_grad()
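        # Streaming contract as used here: forward_streaming returns (loss, ctx),
        # and backward_streaming(ctx) runs the backward pass over the streamed weights.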
        loss, ctx = model.forward_streaming(input_ids, labels)
        model.backward_streaming(ctx)
        optimizer.step()

        loss_val = loss.item()
        losses.append(loss_val)
        if step % 10 == 0:
            elapsed = time.time() - t0
            print(f" Step {step:3d} | loss={loss_val:.4f} | {elapsed:.1f}s")

    elapsed = time.time() - t0
    print(f" Training complete: {n_steps} steps in {elapsed:.1f}s ({elapsed/n_steps:.2f}s/step)")
    return losses


def train_standard(model, input_ids_list, labels_list, n_steps=100, lr=1e-4):
    """Train with standard forward + loss.backward()."""
    optimizer = torch.optim.AdamW(
        [p for p in model._lora_params.parameters() if p.requires_grad],
        lr=lr,
    )
    norm_params = [p for p in model.parameters() if p.requires_grad and p not in set(model._lora_params.parameters())]
    if norm_params:
        optimizer.add_param_group({"params": norm_params, "lr": lr})

    losses = []
    t0 = time.time()
    for step in range(n_steps):
        idx = step % len(input_ids_list)
        input_ids = input_ids_list[idx].unsqueeze(0).cuda()
        labels = labels_list[idx].unsqueeze(0).cuda()

        optimizer.zero_grad()
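        # Standard autograd path: one forward call returns the loss, then loss.backward().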
        loss = model(input_ids, labels)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        losses.append(loss_val)
        if step % 10 == 0:
            elapsed = time.time() - t0
            print(f" Step {step:3d} | loss={loss_val:.4f} | {elapsed:.1f}s")

    elapsed = time.time() - t0
    print(f" Training complete: {n_steps} steps in {elapsed:.1f}s ({elapsed/n_steps:.2f}s/step)")
    return losses


def compare_losses(losses_streaming, losses_standard, tolerance=0.05):
    """Compare two loss curves. Return (matches, max_rel_diff); matches is True only if every step is within tolerance."""
    assert len(losses_streaming) == len(losses_standard)
    max_rel_diff = 0
    mismatches = 0
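    # Per-step relative difference, measured against the standard-path loss;
    # steps where the reference loss is exactly zero are skipped to avoid division by zero.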
    for i, (ls, ln) in enumerate(zip(losses_streaming, losses_standard)):
        if ln == 0:
            continue
        rel_diff = abs(ls - ln) / abs(ln)
        max_rel_diff = max(max_rel_diff, rel_diff)
        if rel_diff > tolerance:
            mismatches += 1
            if mismatches <= 5:
                print(f" Step {i}: streaming={ls:.4f} standard={ln:.4f} diff={rel_diff:.4f}")

    print(f" Max relative difference: {max_rel_diff:.4f}")
    print(f" Steps exceeding {tolerance*100}% tolerance: {mismatches}/{len(losses_streaming)}")
    return mismatches == 0, max_rel_diff


def main():
    quantized_path = os.path.expanduser("~/quantized/qwen3-30b-a3b-4bit.safetensors")
    model_name = "Qwen/Qwen3-30B-A3B"
    n_steps = 100
    lr = 1e-4

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Prepare data
    print("Preparing data...")
    input_ids_list, labels_list = prepare_data(tokenizer, n_samples=200, max_len=256)
    print(f" {len(input_ids_list)} samples prepared")

    # === Path 1: Streaming (from_quantized) ===
    print("\n=== Streaming path (from_quantized) ===")
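    # Seed before each path so the LoRA initialization (and any stochastic ops) matches across runs.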
    torch.manual_seed(42)
    model_stream = KbitLoraModel.from_quantized(
        quantized_path, weight_streaming=True, lora_r=16,
    )
    losses_streaming = train_streaming(model_stream, input_ids_list, labels_list, n_steps=n_steps, lr=lr)

    # Save LoRA
    lora_path = os.path.expanduser("~/quantized/qwen3-30b-lora.pt")
    save_lora(model_stream, lora_path)
    print(f" LoRA saved to {lora_path}")

    # Free memory
    del model_stream
    torch.cuda.empty_cache()

    # === Path 2: Non-streaming (standard forward) ===
    print("\n=== Non-streaming path (standard forward) ===")
    torch.manual_seed(42)
    model_standard = KbitLoraModel.from_quantized(
        quantized_path, weight_streaming=False, lora_r=16,
    )
    losses_standard = train_standard(model_standard, input_ids_list, labels_list, n_steps=n_steps, lr=lr)

    del model_standard
    torch.cuda.empty_cache()

    # === Compare loss curves ===
    print("\n=== Loss curve comparison ===")
    matches, max_diff = compare_losses(losses_streaming, losses_standard)
    if matches:
        print(" PASS: Loss curves match within 5%")
    else:
        print(" FAIL: Loss curves diverge by more than 5%")

    # === Reload LoRA and verify ===
    print("\n=== LoRA reload test ===")
    torch.manual_seed(42)
    model_reload = KbitLoraModel.from_quantized(
        quantized_path, weight_streaming=False, lora_r=16,
        lora_checkpoint=lora_path,
    )
    # Quick inference test
    prompt = "What is machine learning?"
    tokens = tokenizer(prompt, return_tensors="pt")
    input_ids = tokens["input_ids"].cuda()

    with torch.no_grad():
        output = model_reload(input_ids, labels=None)
    # Just verify it runs without error
    print(f" LoRA reload OK, output shape: {output.shape if hasattr(output, 'shape') else type(output)}")

    # Save results
    results = {
        "losses_streaming": losses_streaming,
        "losses_standard": losses_standard,
        "max_rel_diff": max_diff,
        "matches": matches,
        "n_steps": n_steps,
        "lr": lr,
        "model": "Qwen3-30B-A3B",
        "lora_r": 16,
    }
    results_path = os.path.expanduser("~/quantized/training_results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {results_path}")


if __name__ == "__main__":
    main()
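
A minimal sketch (not part of the committed file) for re-checking the success criterion from the saved results offline; the results path, key names, and 5% tolerance are taken from the script above.

import json
import os

# training_results.json is written by main() above with these keys.
results_path = os.path.expanduser("~/quantized/training_results.json")
with open(results_path) as f:
    results = json.load(f)

tolerance = 0.05  # same 5% criterion as compare_losses()
worst = max(
    abs(s - n) / abs(n)
    for s, n in zip(results["losses_streaming"], results["losses_standard"])
    if n != 0
)
print(f"max relative diff: {worst:.4f} -> {'PASS' if worst <= tolerance else 'FAIL'}")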
