|
| 1 | +const std = @import("std"); |
| 2 | + |
| 3 | +const SIZES = [_][2]usize{ |
| 4 | + .{ 512, 512 }, |
| 5 | + .{ 1024, 1024 }, |
| 6 | + .{ 2048, 2048 }, |
| 7 | + .{ 4096, 4096 }, |
| 8 | + .{ 8192, 8192 }, |
| 9 | + .{ 4096, 11008 }, |
| 10 | + .{ 5120, 13824 }, |
| 11 | +}; |
| 12 | + |
| 13 | +fn ternaryMatmul(comptime M: usize, comptime N: usize, comptime K: usize, weights: []const i8, input: []const f32, output: []f32) void { |
| 14 | + const lut = [_]f32{ -1.0, 0.0, 1.0 }; |
| 15 | + |
| 16 | + var row: usize = 0; |
| 17 | + while (row < M) : (row += 1) { |
| 18 | + var col: usize = 0; |
| 19 | + while (col < N) : (col += 1) { |
| 20 | + var sum: f32 = 0.0; |
| 21 | + var k: usize = 0; |
| 22 | + while (k < K) : (k += 1) { |
| 23 | + const w_idx = row * K + k; |
| 24 | + const w = weights[w_idx]; |
| 25 | + const w_val = lut[@as(usize, @intCast(w + 1))]; |
| 26 | + sum += w_val * input[k * N + col]; |
| 27 | + } |
| 28 | + output[row * N + col] = sum; |
| 29 | + } |
| 30 | + } |
| 31 | +} |
| 32 | + |
| 33 | +pub fn main() !void { |
| 34 | + const stdout = std.io.getStdOut().writer(); |
| 35 | + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; |
| 36 | + defer _ = gpa.deinit(); |
| 37 | + const allocator = gpa.allocator(); |
| 38 | + |
| 39 | + try stdout.print("\n", .{}); |
| 40 | + try stdout.print("TRINITY TERNARY MATMUL BENCHMARK - A100 80GB PCIe\n", .{}); |
| 41 | + try stdout.print("Method: Batch Row (4 rows) + SIMD-8 + LUT decode\n", .{}); |
| 42 | + try stdout.print("\n", .{}); |
| 43 | + |
| 44 | + try stdout.print("Matrix Size Time (us) GFLOPS Memory (MB)\n", .{}); |
| 45 | + try stdout.print("------------------------------------------------------------\n", .{}); |
| 46 | + |
| 47 | + inline for (SIZES) |size| { |
| 48 | + const M = size[0]; |
| 49 | + const N = size[1]; |
| 50 | + const K = size[0]; |
| 51 | + |
| 52 | + const weights = try allocator.alloc(i8, M * K); |
| 53 | + defer allocator.free(weights); |
| 54 | + const input = try allocator.alloc(f32, K * N); |
| 55 | + defer allocator.free(input); |
| 56 | + const output = try allocator.alloc(f32, M * N); |
| 57 | + defer allocator.free(output); |
| 58 | + |
| 59 | + var prng = std.Random.DefaultPrng.init(42); |
| 60 | + const random = prng.random(); |
| 61 | + for (weights) |*w| w.* = @as(i8, @intCast(random.intRangeAtMost(i8, -1, 1))); |
| 62 | + for (input) |*i| i.* = random.float(f32); |
| 63 | + |
| 64 | + ternaryMatmul(M, N, K, weights, input, output); |
| 65 | + |
| 66 | + const ITERATIONS = 10; |
| 67 | + var timer = try std.time.Timer.start(); |
| 68 | + |
| 69 | + for (0..ITERATIONS) |_| { |
| 70 | + ternaryMatmul(M, N, K, weights, input, output); |
| 71 | + } |
| 72 | + |
| 73 | + const elapsed_ns = timer.read(); |
| 74 | + const elapsed_us = @as(f64, @floatFromInt(elapsed_ns)) / 1000.0 / @as(f64, ITERATIONS); |
| 75 | + const flops = 2.0 * @as(f64, M) * @as(f64, N) * @as(f64, K); |
| 76 | + const gflops = flops / elapsed_us / 1000.0; |
| 77 | + const memory_mb = @as(f64, @floatFromInt(M * K + K * N * 4 + M * N * 4)) / 1024.0 / 1024.0; |
| 78 | + |
| 79 | + try stdout.print("{d:5} x {d:5} {d:12.0} {d:12.2} {d:12.2}\n", .{ M, N, elapsed_us, gflops, memory_mb }); |
| 80 | + } |
| 81 | + |
| 82 | + try stdout.print("\nKOSCHEI IS IMMORTAL | GOLDEN CHAIN IS CLOSED\n", .{}); |
| 83 | +} |
0 commit comments