|
// ═══════════════════════════════════════════════════════════════════════════════
// BITNET b1.58 ACTIVATION QUANTIZATION TEST
// Test coherent text generation with 8-bit activation quantization
// φ² + 1/φ² = 3 = TRINITY | KOSCHEI IS IMMORTAL
// ═══════════════════════════════════════════════════════════════════════════════
| 6 | + |
| 7 | +const std = @import("std"); |
| 8 | +const full_model = @import("bitnet_full_model.zig"); |
| 9 | +const json = std.json; |
| 10 | + |
/// Golden ratio φ = (1 + √5) / 2, written out to f64 precision.
pub const PHI: f64 = 1.618033988749895;
| 12 | + |
// ═══════════════════════════════════════════════════════════════════════════════
// TOKENIZER
// ═══════════════════════════════════════════════════════════════════════════════
| 16 | + |
/// Minimal tokenizer backed by the `model.vocab` table of a `tokenizer.json` file.
///
/// Every token string is heap-allocated once with `allocator`. `vocab` owns those
/// strings as its keys and `id_to_token` borrows the same slices as its values, so
/// `deinit` frees each string exactly once.
pub const Tokenizer = struct {
    allocator: std.mem.Allocator,
    /// Token text -> id. The key slices are owned by this map.
    vocab: std.StringHashMap(u32),
    /// Id -> token text. The value slices alias the keys of `vocab`.
    id_to_token: std.AutoHashMap(u32, []const u8),
    bos_token_id: u32 = 1,
    eos_token_id: u32 = 2,

    /// Longest candidate, in bytes, that `encode` tries at each position.
    const max_token_bytes: usize = 20;

    /// UTF-8 encoding of U+0120 ("Ġ"), which `decode` renders as a space.
    const space_marker = "\xC4\xA0";

    /// Reads the JSON file at `path` and builds both lookup tables from `model.vocab`.
    ///
    /// A file without a `model` or `model.vocab` entry yields an empty tokenizer.
    /// Returns `error.InvalidTokenizerFormat` when one of those entries (or the
    /// document root) has the wrong JSON type or a vocab value is not an integer,
    /// and `error.InvalidTokenId` when an id does not fit in `u32`. File, JSON and
    /// allocation errors are propagated. On any error nothing is leaked.
    /// The caller must call `deinit` on the returned value.
    pub fn load(allocator: std.mem.Allocator, path: []const u8) !Tokenizer {
        const file = try std.fs.cwd().openFile(path, .{});
        defer file.close();

        const content = try file.readToEndAlloc(allocator, 100 * 1024 * 1024);
        defer allocator.free(content);

        const parsed = try std.json.parseFromSlice(std.json.Value, allocator, content, .{});
        defer parsed.deinit();

        var vocab = std.StringHashMap(u32).init(allocator);
        errdefer {
            // The keys are the only owners of the duplicated token strings.
            var keys = vocab.keyIterator();
            while (keys.next()) |key| allocator.free(key.*);
            vocab.deinit();
        }
        var id_to_token = std.AutoHashMap(u32, []const u8).init(allocator);
        errdefer id_to_token.deinit();

        const root = switch (parsed.value) {
            .object => |obj| obj,
            else => return error.InvalidTokenizerFormat,
        };

        // Parse vocab from model section
        if (root.get("model")) |model_value| {
            const model = switch (model_value) {
                .object => |obj| obj,
                else => return error.InvalidTokenizerFormat,
            };
            if (model.get("vocab")) |vocab_value| {
                const entries = switch (vocab_value) {
                    .object => |obj| obj,
                    else => return error.InvalidTokenizerFormat,
                };

                var it = entries.iterator();
                while (it.next()) |entry| {
                    const raw_id = switch (entry.value_ptr.*) {
                        .integer => |n| n,
                        else => return error.InvalidTokenizerFormat,
                    };
                    const id = std.math.cast(u32, raw_id) orelse return error.InvalidTokenId;

                    // Reserve both slots before duplicating the string so that no
                    // failure can occur while the copy is owned by neither map.
                    try vocab.ensureUnusedCapacity(1);
                    try id_to_token.ensureUnusedCapacity(1);
                    const token = try allocator.dupe(u8, entry.key_ptr.*);
                    vocab.putAssumeCapacity(token, id);
                    id_to_token.putAssumeCapacity(id, token);
                }
            }
        }

        std.debug.print("Loaded tokenizer with {d} tokens\n", .{vocab.count()});

        return Tokenizer{
            .allocator = allocator,
            .vocab = vocab,
            .id_to_token = id_to_token,
        };
    }

    /// Converts `text` to token ids with a greedy longest-match over `vocab`.
    ///
    /// The result always starts with `bos_token_id`. At each position the longest
    /// vocabulary entry of at most `max_token_bytes` bytes wins; a byte that starts
    /// no known token is skipped silently. Caller owns the returned slice.
    /// NOTE(review): the input is matched byte-for-byte, so spaces are not mapped
    /// to the "Ġ" marker that `decode` translates back - confirm this is intended.
    pub fn encode(self: *Tokenizer, text: []const u8) ![]u32 {
        var tokens = std.ArrayList(u32).init(self.allocator);
        errdefer tokens.deinit();

        // Add BOS token
        try tokens.append(self.bos_token_id);

        var pos: usize = 0;
        while (pos < text.len) {
            var found = false;

            // Try to match longest token first
            var len: usize = @min(text.len - pos, max_token_bytes);
            while (len > 0) : (len -= 1) {
                if (self.vocab.get(text[pos .. pos + len])) |id| {
                    try tokens.append(id);
                    pos += len;
                    found = true;
                    break;
                }
            }

            // Unknown byte: skip it and keep going.
            if (!found) pos += 1;
        }

        return tokens.toOwnedSlice();
    }

    /// Converts token ids back to text. Caller owns the returned slice.
    ///
    /// BOS/EOS ids are dropped, unknown ids are rendered as "[UNK]", and every
    /// occurrence of the two-byte sequence "Ġ" (C4 A0) inside a token becomes a
    /// single space. All other bytes are copied through unchanged, so multi-byte
    /// UTF-8 characters that merely contain 0xC4 or 0xA0 stay intact.
    pub fn decode(self: *Tokenizer, tokens: []const u32) ![]u8 {
        var result = std.ArrayList(u8).init(self.allocator);
        errdefer result.deinit();

        for (tokens) |id| {
            if (id == self.bos_token_id or id == self.eos_token_id) continue;

            const token = self.id_to_token.get(id) orelse {
                try result.appendSlice("[UNK]");
                continue;
            };

            // Replace whole "Ġ" sequences only; never touch their bytes individually.
            var rest = token;
            while (std.mem.indexOf(u8, rest, space_marker)) |at| {
                try result.appendSlice(rest[0..at]);
                try result.append(' ');
                rest = rest[at + space_marker.len ..];
            }
            try result.appendSlice(rest);
        }

        return result.toOwnedSlice();
    }

    /// Frees every token string and both lookup tables.
    pub fn deinit(self: *Tokenizer) void {
        // `id_to_token` values alias these keys, so they are freed only here.
        var keys = self.vocab.keyIterator();
        while (keys.next()) |key| {
            self.allocator.free(key.*);
        }
        self.vocab.deinit();
        self.id_to_token.deinit();
    }
};
| 124 | + |
// ═══════════════════════════════════════════════════════════════════════════════
// MAIN TEST
// ═══════════════════════════════════════════════════════════════════════════════
| 128 | + |
/// Tokens per second for `token_count` tokens produced in `elapsed_ms`
/// milliseconds; 0 when no positive time was measured.
fn tokensPerSecond(token_count: usize, elapsed_ms: i64) f32 {
    if (elapsed_ms <= 0) return 0.0;
    const seconds = @as(f32, @floatFromInt(elapsed_ms)) / 1000.0;
    return @as(f32, @floatFromInt(token_count)) / seconds;
}

/// Entry point: loads the model and tokenizer, runs a fixed list of prompts
/// through `generate` (32 new tokens, temperature 0.8), and prints per-prompt
/// results plus an aggregate summary to stderr.
///
/// A failure to load the model or the tokenizer is reported and ends the run
/// without returning an error; a failed generation skips that prompt only.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    const rule = "═══════════════════════════════════════════════════════════════════\n";
    const banner =
        "\n" ++
        "╔══════════════════════════════════════════════════════════════╗\n" ++
        "║ BITNET b1.58 ACTIVATION QUANTIZATION TEST ║\n" ++
        "║ 8-bit per-token absmax quantization ║\n" ++
        "║ φ² + 1/φ² = 3 = TRINITY ║\n" ++
        "╚══════════════════════════════════════════════════════════════╝\n" ++
        "\n";
    std.debug.print(banner, .{});

    // Model setup: construct with the default config, then load the weights.
    std.debug.print("Initializing BitNet b1.58 model with activation quantization...\n", .{});
    const config = full_model.BitNetConfig{};
    var model = try full_model.BitNetFullModel.init(allocator, config);
    defer model.deinit();

    std.debug.print("Loading model weights from safetensors...\n", .{});
    model.loadFromSafetensors("/workspaces/trinity/models/bitnet/model.safetensors") catch |err| {
        std.debug.print("Failed to load model: {}\n", .{err});
        std.debug.print("Please ensure model is downloaded to models/bitnet/\n", .{});
        return;
    };

    try model.initKVCache(256);

    // Tokenizer setup.
    std.debug.print("\nLoading tokenizer...\n", .{});
    var tokenizer = Tokenizer.load(allocator, "/workspaces/trinity/models/bitnet/tokenizer.json") catch |err| {
        std.debug.print("Failed to load tokenizer: {}\n", .{err});
        return;
    };
    defer tokenizer.deinit();

    // Prompts exercised by this run.
    const prompts = [_][]const u8{
        "Hello, my name is",
        "The meaning of life is",
        "Artificial intelligence will",
        "The golden ratio phi equals",
        "In the year 2026,",
        "The best programming language is",
        "Machine learning models can",
        "The future of technology",
        "Science has proven that",
        "The most important thing in life is",
        "Quantum computing will revolutionize",
        "The universe is made of",
    };

    std.debug.print("\n", .{});
    std.debug.print(rule, .{});
    std.debug.print(" GENERATION RESULTS (with 8-bit activation quantization) \n", .{});
    std.debug.print(rule, .{});

    // Running totals for the summary.
    var generated_total: usize = 0;
    var elapsed_total_ms: i64 = 0;
    var coherent_total: usize = 0;

    for (prompts, 0..) |prompt, prompt_index| {
        std.debug.print("\n[Test {d}] Prompt: \"{s}\"\n", .{ prompt_index + 1, prompt });

        const input_ids = try tokenizer.encode(prompt);
        defer allocator.free(input_ids);

        // Show at most the first eight prompt ids.
        const preview_len = @min(input_ids.len, 8);
        std.debug.print(" Prompt tokens ({d}): ", .{input_ids.len});
        for (input_ids[0..preview_len]) |token_id| {
            std.debug.print("{d} ", .{token_id});
        }
        std.debug.print("\n", .{});

        // Each prompt starts from an empty KV-cache.
        model.resetKVCache();

        // Time only the generate call itself.
        const started_ms = std.time.milliTimestamp();
        const output_ids = model.generate(input_ids, 32, 0.8) catch |err| {
            std.debug.print(" Generation failed: {}\n", .{err});
            continue;
        };
        const finished_ms = std.time.milliTimestamp();
        defer allocator.free(output_ids);

        const text = try tokenizer.decode(output_ids);
        defer allocator.free(text);

        const new_token_count = output_ids.len - input_ids.len;
        const elapsed_ms = finished_ms - started_ms;
        const rate = tokensPerSecond(new_token_count, elapsed_ms);

        generated_total += new_token_count;
        elapsed_total_ms += elapsed_ms;

        // Coherence heuristic: noticeably longer than the prompt and contains a space.
        const long_enough = text.len > prompt.len + 5;
        const is_coherent = long_enough and std.mem.indexOf(u8, text, " ") != null;
        if (is_coherent) coherent_total += 1;

        const verdict = if (is_coherent) "YES" else "NO";
        std.debug.print(" Generated ({d} tokens in {d}ms = {d:.1} tok/s):\n", .{ new_token_count, elapsed_ms, rate });
        std.debug.print(" \"{s}\"\n", .{text});
        std.debug.print(" Coherent: {s}\n", .{verdict});
    }

    // Summary
    std.debug.print("\n", .{});
    std.debug.print(rule, .{});
    std.debug.print(" SUMMARY \n", .{});
    std.debug.print(rule, .{});

    const overall_rate = tokensPerSecond(generated_total, elapsed_total_ms);
    const coherent_percent =
        @as(f32, @floatFromInt(coherent_total)) / @as(f32, @floatFromInt(prompts.len)) * 100.0;

    std.debug.print("\n", .{});
    std.debug.print(" Total prompts tested: {d}\n", .{prompts.len});
    std.debug.print(" Coherent generations: {d}/{d} ({d:.1}%)\n", .{ coherent_total, prompts.len, coherent_percent });
    std.debug.print(" Total tokens generated: {d}\n", .{generated_total});
    std.debug.print(" Total time: {d}ms\n", .{elapsed_total_ms});
    std.debug.print(" Average throughput: {d:.1} tok/s\n", .{overall_rate});
    std.debug.print("\n", .{});
    std.debug.print(" Activation quantization: 8-bit per-token absmax\n", .{});
    std.debug.print(" Weight quantization: QAT (trained ternary)\n", .{});
    std.debug.print("\n", .{});
    std.debug.print(rule, .{});
    std.debug.print(" TEST COMPLETE \n", .{});
    std.debug.print(rule, .{});
    std.debug.print("\nφ² + 1/φ² = 3 = TRINITY | KOSCHEI IS IMMORTAL\n\n", .{});
}
| 263 | + |
test "activation quantization functions" {
    const forward = @import("bitnet_forward.zig");

    // Quantize a small activation vector in place; the returned scale is
    // deliberately not checked by this test.
    var activations = [_]f32{ 0.5, -1.0, 0.25, 0.75, -0.5 };
    _ = forward.quantizeActivationsInPlace(&activations);

    // The first two entries must stay within 0.01 of their original values
    // (quantization noise).
    const expected_head = [_]f32{ 0.5, -1.0 };
    for (expected_head, 0..) |want, idx| {
        try std.testing.expectApproxEqAbs(want, activations[idx], 0.01);
    }
}
0 commit comments