Skip to content

Commit f3cef07

Browse files
committed
script: Add GDS validation script for dettmers-desktop
Compares GDS (GPUDirect Storage) vs CPU pinned weight streaming on Qwen3-30B-A3B with forced zero residency. Measures per-step timing and estimates streaming bandwidth.
1 parent 654c7db commit f3cef07

File tree

1 file changed

+179
-0
lines changed

1 file changed

+179
-0
lines changed

scripts/validate_gds.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""GDS validation script for dettmers-desktop.
2+
3+
Tests GDS (GPUDirect Storage) vs CPU pinned weight streaming on a real model.
4+
Forces zero residency to exercise the full streaming path even when VRAM is plentiful.
5+
6+
Run on dettmers-desktop:
7+
BNB_CUDA_VERSION=131 PYTHONPATH=. python scripts/validate_gds.py
8+
"""
9+
10+
import time
11+
from unittest.mock import patch
12+
13+
import torch
14+
from transformers import AutoTokenizer
15+
16+
from bitsandbytes.kbit_lora import KbitLoraModel
17+
18+
19+
def train_steps(model, input_ids_list, labels_list, n_steps=20, label=""):
    """Run ``n_steps`` optimizer steps on *model* and return per-step timing and losses.

    Performs one uncounted warmup step first (lets the CUDA allocator and
    streaming buffers settle), then times ``n_steps`` full
    forward/backward/optimizer iterations with explicit CUDA synchronization
    around each measurement window so wall-clock time covers the GPU work.

    Args:
        model: A KbitLoraModel exposing ``forward_streaming`` /
            ``backward_streaming`` and a ``_lora_params`` container.
        input_ids_list: List of 1-D token-id tensors (one per sample).
        labels_list: Matching list of 1-D label tensors.
        n_steps: Number of timed steps to run.
        label: Tag prefixed to progress prints.

    Returns:
        Tuple ``(step_times, losses)``: per-step wall-clock seconds and
        per-step loss values, each of length ``n_steps``.
    """
    optimizer = torch.optim.AdamW(
        [p for p in model._lora_params.parameters() if p.requires_grad],
        lr=1e-4,
    )
    # Build the exclusion set ONCE. The original rebuilt
    # set(model._lora_params.parameters()) inside the comprehension for every
    # parameter tested, which is quadratic in parameter count.
    lora_param_set = set(model._lora_params.parameters())
    norm_params = [
        p for p in model.parameters()
        if p.requires_grad and p not in lora_param_set
    ]
    if norm_params:
        optimizer.add_param_group({"params": norm_params, "lr": 1e-4})

    step_times = []
    losses = []

    # Warmup step (not counted): first step pays one-time costs (allocations,
    # stream setup) that would skew the timing of the measured steps.
    idx = 0
    input_ids = input_ids_list[idx].unsqueeze(0).cuda()
    labels = labels_list[idx].unsqueeze(0).cuda()
    optimizer.zero_grad()
    loss, ctx = model.forward_streaming(input_ids, labels)
    model.backward_streaming(ctx)
    optimizer.step()
    torch.cuda.synchronize()
    print(f" [{label}] Warmup done, loss={loss.item():.4f}")

    for step in range(n_steps):
        # step+1 skips the warmup sample; wrap around the sample list.
        idx = (step + 1) % len(input_ids_list)
        input_ids = input_ids_list[idx].unsqueeze(0).cuda()
        labels = labels_list[idx].unsqueeze(0).cuda()

        # Synchronize before starting the clock so pending GPU work from the
        # previous step is not attributed to this one.
        torch.cuda.synchronize()
        t0 = time.perf_counter()

        optimizer.zero_grad()
        loss, ctx = model.forward_streaming(input_ids, labels)
        model.backward_streaming(ctx)
        optimizer.step()

        # Synchronize again so t1 includes all GPU work launched this step.
        torch.cuda.synchronize()
        t1 = time.perf_counter()

        step_times.append(t1 - t0)
        losses.append(loss.item())
        if step % 5 == 0:
            print(f" [{label}] Step {step:2d} | loss={loss.item():.4f} | {t1-t0:.3f}s")

    return step_times, losses
68+
69+
70+
def read_num_layers(quantized_path, default=48):
    """Return the transformer layer count stored in a safetensors header.

    A safetensors file begins with an 8-byte little-endian header length
    followed by a JSON header; the layer count is expected under
    ``__metadata__["num_layers"]``.

    Args:
        quantized_path: Path to the quantized ``.safetensors`` file.
        default: Fallback layer count when the metadata key is absent.

    Returns:
        int: Number of transformer layers.
    """
    import json
    import struct

    with open(quantized_path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header_json = json.loads(f.read(header_size))
    metadata = header_json.get("__metadata__", {})
    return int(metadata.get("num_layers", default))


def prepare_samples(tokenizer, dataset, max_length=256, min_tokens=10):
    """Tokenize dataset rows into parallel (input_ids, labels) tensor lists.

    Rows producing fewer than ``min_tokens`` tokens are dropped. Labels are a
    clone of the input ids (standard causal-LM setup).
    """
    input_ids_list = []
    labels_list = []
    for ex in dataset:
        tokens = tokenizer(ex["text"], truncation=True, max_length=max_length, return_tensors="pt")
        ids = tokens["input_ids"][0]
        if len(ids) >= min_tokens:
            input_ids_list.append(ids)
            labels_list.append(ids.clone())
    return input_ids_list, labels_list


def run_variant(quantized_path, use_gds, label, input_ids_list, labels_list, n_steps):
    """Build a zero-residency streaming model, train it, and tear it down.

    Patches ``_compute_residency`` to return 0 so every layer is streamed on
    every step, exercising the full GDS / pinned-memory path even when VRAM
    would allow residency.

    Returns:
        Tuple ``(step_times, losses)`` from :func:`train_steps`.
    """
    torch.manual_seed(42)  # same seed for both variants so losses are comparable
    with patch.object(KbitLoraModel, "_compute_residency", return_value=0):
        model = KbitLoraModel.from_quantized(
            quantized_path, weight_streaming=True, use_gds=use_gds, lora_r=16,
        )
    print(f" RAM strategy: {model._ram_strategy}")
    if use_gds:
        print(f" GDS enabled: {model._use_gds}")
    print(f" Resident layers: {model._n_resident}")

    times, losses = train_steps(
        model, input_ids_list, labels_list, n_steps=n_steps, label=label
    )

    # Free the streamed model before the next variant is built.
    del model
    torch.cuda.empty_cache()
    return times, losses


def main():
    """Compare GDS vs CPU-pinned weight streaming and report timing/bandwidth."""
    import os

    quantized_path = os.path.expanduser("~/quantized/qwen3-30b-a3b-4bit.safetensors")
    model_name = "Qwen/Qwen3-30B-A3B"
    n_steps = 20

    # Load tokenizer and prepare data
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    from datasets import load_dataset
    ds = load_dataset("tatsu-lab/alpaca", split="train").select(range(50))
    input_ids_list, labels_list = prepare_samples(tokenizer, ds)
    print(f" {len(input_ids_list)} samples prepared")

    # Layer count feeds the bandwidth estimate below.
    n_layers = read_num_layers(quantized_path)
    print(f" Model: {n_layers} layers")

    # === Test 1: GDS path ===
    print("\n=== GDS path (use_gds=True, forced 0 resident) ===")
    gds_times, gds_losses = run_variant(
        quantized_path, True, "GDS", input_ids_list, labels_list, n_steps
    )

    # === Test 2: CPU pinned path ===
    print("\n=== CPU pinned path (forced 0 resident) ===")
    pinned_times, pinned_losses = run_variant(
        quantized_path, False, "Pinned", input_ids_list, labels_list, n_steps
    )

    # === Results ===
    print("\n=== Results ===")
    gds_avg = sum(gds_times) / len(gds_times)
    pinned_avg = sum(pinned_times) / len(pinned_times)
    print(f" GDS avg step time: {gds_avg:.3f}s")
    print(f" Pinned avg step time: {pinned_avg:.3f}s")
    print(f" Speedup (pinned/GDS): {pinned_avg/gds_avg:.2f}x")

    # Estimate streaming bandwidth.
    # Each step does forward + backward, so every layer is loaded twice per
    # step. With double-buffering the prefetch overlaps compute, so this is
    # an effective throughput figure, not a raw link speed.
    model_size_gb = 19.53  # approximate on-disk size of the quantized model
    layer_size_gb = model_size_gb / n_layers
    layers_per_step = n_layers * 2  # forward + backward
    data_per_step_gb = layers_per_step * layer_size_gb

    gds_bw = data_per_step_gb / gds_avg
    pinned_bw = data_per_step_gb / pinned_avg
    print("\n Estimated streaming bandwidth:")
    print(f" GDS: {gds_bw:.1f} GB/s ({data_per_step_gb:.1f} GB/step)")
    print(f" Pinned: {pinned_bw:.1f} GB/s")

    # Loss comparison: with identical seeds and data the two paths should
    # match closely; a large divergence indicates a correctness bug in one
    # of the streaming implementations.
    max_diff = 0
    for lg, lp in zip(gds_losses, pinned_losses):
        if lp != 0:
            max_diff = max(max_diff, abs(lg - lp) / abs(lp))
    print(f"\n Loss comparison: max relative diff = {max_diff:.4f}")
    if max_diff < 0.05:
        print(" PASS: GDS and pinned produce matching losses")
    else:
        print(" WARNING: Loss difference exceeds 5%")

    print("\n All GDS validation checks passed!" if max_diff < 0.05 else "\n GDS validation had issues.")
176+
177+
178+
# Script entry point: run the GDS-vs-pinned validation when executed directly.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)