"""GDS validation script for dettmers-desktop.

Tests GDS (GPUDirect Storage) vs CPU pinned weight streaming on a real model.
Forces zero residency to exercise the full streaming path even when VRAM is plentiful.

Run on dettmers-desktop:
    BNB_CUDA_VERSION=131 PYTHONPATH=. python scripts/validate_gds.py
"""

import time
from unittest.mock import patch

import torch
from transformers import AutoTokenizer

from bitsandbytes.kbit_lora import KbitLoraModel
| 18 | + |
def train_steps(model, input_ids_list, labels_list, n_steps=20, label=""):
    """Run one untimed warmup step plus ``n_steps`` timed training steps.

    Args:
        model: Streaming model exposing ``forward_streaming(input_ids, labels)
            -> (loss, ctx)``, ``backward_streaming(ctx)``, ``parameters()``,
            and a ``_lora_params`` module holding the LoRA parameters.
        input_ids_list: Per-sample 1-D token-id tensors (batched one at a time).
        labels_list: Matching 1-D label tensors.
        n_steps: Number of timed steps (the warmup step is excluded).
        label: Tag prefixed to progress printouts.

    Returns:
        Tuple ``(step_times, losses)``: per-step wall-clock seconds and
        per-step loss values, each of length ``n_steps``.
    """
    lora_params = [p for p in model._lora_params.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

    # Trainable non-LoRA params (e.g. norms) go into a second param group.
    # Build the membership set ONCE -- the original rebuilt it inside the
    # comprehension for every parameter tested, which is O(n*m).
    lora_param_set = set(model._lora_params.parameters())
    norm_params = [
        p for p in model.parameters()
        if p.requires_grad and p not in lora_param_set
    ]
    if norm_params:
        optimizer.add_param_group({"params": norm_params, "lr": 1e-4})

    step_times = []
    losses = []

    # Warmup step (not counted): absorbs one-time CUDA/allocator costs so the
    # timed steps measure steady-state streaming throughput.
    input_ids = input_ids_list[0].unsqueeze(0).cuda()
    labels = labels_list[0].unsqueeze(0).cuda()
    optimizer.zero_grad()
    loss, ctx = model.forward_streaming(input_ids, labels)
    model.backward_streaming(ctx)
    optimizer.step()
    torch.cuda.synchronize()
    print(f" [{label}] Warmup done, loss={loss.item():.4f}")

    for step in range(n_steps):
        # Start at sample 1 since the warmup consumed sample 0.
        idx = (step + 1) % len(input_ids_list)
        input_ids = input_ids_list[idx].unsqueeze(0).cuda()
        labels = labels_list[idx].unsqueeze(0).cuda()

        # Synchronize on both sides of the step so the wall-clock window
        # covers all queued GPU work and nothing from the previous step.
        torch.cuda.synchronize()
        t0 = time.perf_counter()

        optimizer.zero_grad()
        loss, ctx = model.forward_streaming(input_ids, labels)
        model.backward_streaming(ctx)
        optimizer.step()

        torch.cuda.synchronize()
        t1 = time.perf_counter()

        step_times.append(t1 - t0)
        losses.append(loss.item())
        if step % 5 == 0:
            print(f" [{label}] Step {step:2d} | loss={loss.item():.4f} | {t1-t0:.3f}s")

    return step_times, losses
| 68 | + |
| 69 | + |
def main():
    """Compare GDS vs CPU-pinned weight streaming on a quantized checkpoint.

    Loads the same quantized model twice -- once with ``use_gds=True`` and
    once with ``use_gds=False`` -- while forcing zero resident layers so every
    layer is streamed each step. Trains ``n_steps`` on each path, then reports
    step times, estimated streaming bandwidth, and whether the two paths
    produce matching losses (they decode the same bytes, so they should).
    """
    import json
    import os
    import struct

    from datasets import load_dataset

    quantized_path = os.path.expanduser("~/quantized/qwen3-30b-a3b-4bit.safetensors")
    model_name = "Qwen/Qwen3-30B-A3B"
    n_steps = 20

    # Load tokenizer and prepare data
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset("tatsu-lab/alpaca", split="train").select(range(50))
    input_ids_list = []
    labels_list = []
    for ex in ds:
        tokens = tokenizer(ex["text"], truncation=True, max_length=256, return_tensors="pt")
        ids = tokens["input_ids"][0]
        if len(ids) >= 10:  # skip near-empty samples
            input_ids_list.append(ids)
            labels_list.append(ids.clone())
    print(f" {len(input_ids_list)} samples prepared")

    # Read the safetensors header for the layer count, and the actual on-disk
    # size for bandwidth math (previously a hard-coded ~19.53 GB estimate,
    # which silently drifts whenever the checkpoint changes).
    with open(quantized_path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header_json = json.loads(f.read(header_size))
    metadata = header_json.get("__metadata__", {})
    n_layers = int(metadata.get("num_layers", 48))
    model_size_gb = os.path.getsize(quantized_path) / 1e9
    print(f" Model: {n_layers} layers")

    # === Test 1: GDS path ===
    print("\n=== GDS path (use_gds=True, forced 0 resident) ===")
    torch.manual_seed(42)
    # Force zero residency so the full streaming path is exercised even when
    # VRAM would otherwise hold some layers resident.
    with patch.object(KbitLoraModel, "_compute_residency", return_value=0):
        model_gds = KbitLoraModel.from_quantized(
            quantized_path, weight_streaming=True, use_gds=True, lora_r=16,
        )
    print(f" RAM strategy: {model_gds._ram_strategy}")
    print(f" GDS enabled: {model_gds._use_gds}")
    print(f" Resident layers: {model_gds._n_resident}")

    gds_times, gds_losses = train_steps(
        model_gds, input_ids_list, labels_list, n_steps=n_steps, label="GDS"
    )

    del model_gds
    torch.cuda.empty_cache()

    # === Test 2: CPU pinned path ===
    print("\n=== CPU pinned path (forced 0 resident) ===")
    torch.manual_seed(42)  # same seed so LoRA init matches the GDS run
    with patch.object(KbitLoraModel, "_compute_residency", return_value=0):
        model_pinned = KbitLoraModel.from_quantized(
            quantized_path, weight_streaming=True, use_gds=False, lora_r=16,
        )
    print(f" RAM strategy: {model_pinned._ram_strategy}")
    print(f" Resident layers: {model_pinned._n_resident}")

    pinned_times, pinned_losses = train_steps(
        model_pinned, input_ids_list, labels_list, n_steps=n_steps, label="Pinned"
    )

    del model_pinned
    torch.cuda.empty_cache()

    # === Results ===
    print("\n=== Results ===")
    gds_avg = sum(gds_times) / len(gds_times)
    pinned_avg = sum(pinned_times) / len(pinned_times)
    print(f" GDS avg step time: {gds_avg:.3f}s")
    print(f" Pinned avg step time: {pinned_avg:.3f}s")
    print(f" Speedup (pinned/GDS): {pinned_avg/gds_avg:.2f}x")

    # Each streamed step loads every layer twice (forward + backward). With
    # double-buffering the prefetch overlaps compute, so this is a lower
    # bound on the achieved read bandwidth.
    layer_size_gb = model_size_gb / n_layers
    layers_per_step = n_layers * 2  # forward + backward
    data_per_step_gb = layers_per_step * layer_size_gb

    gds_bw = data_per_step_gb / gds_avg
    pinned_bw = data_per_step_gb / pinned_avg
    print("\n Estimated streaming bandwidth:")
    print(f" GDS: {gds_bw:.1f} GB/s ({data_per_step_gb:.1f} GB/step)")
    print(f" Pinned: {pinned_bw:.1f} GB/s")

    # Loss comparison: max relative difference between the two paths.
    max_diff = 0.0
    for lg, lp in zip(gds_losses, pinned_losses):
        if lp != 0:  # guard the relative-difference division
            max_diff = max(max_diff, abs(lg - lp) / abs(lp))
    print(f"\n Loss comparison: max relative diff = {max_diff:.4f}")
    if max_diff < 0.05:
        print(" PASS: GDS and pinned produce matching losses")
        print("\n All GDS validation checks passed!")
    else:
        print(" WARNING: Loss difference exceeds 5%")
        print("\n GDS validation had issues.")
| 177 | + |
# Script entry point: run the GDS-vs-pinned validation when executed directly.
if __name__ == "__main__":
    main()