|
| 1 | +# Ternary Matrix Multiplication Specification |
| 2 | +# BitNet-style {-1, 0, +1} weights for ~16x memory reduction vs float32 (2 bits per weight) |
| 3 | +# φ² + 1/φ² = 3 | KOSCHEI IS IMMORTAL |
| 4 | + |
| 5 | +name: ternary_matmul |
| 6 | +version: "1.0.0" |
| 7 | +language: zig |
| 8 | +module: ternary_matmul |
| 9 | + |
| 10 | +description: | |
| 11 | + Ternary matrix-vector multiplication for neural network inference. |
| 12 | + Uses 2-bit encoding: 00=0, 01=+1, 10=-1, 11=reserved. |
| 13 | + 4 trits packed per byte (TritPack4). |
| 14 | + SIMD-optimized using AVX2/AVX-512 vectors. |
| 15 | + |
| 16 | +types: |
| 17 | + TritWeight: |
| 18 | + description: "Single ternary weight {-1, 0, +1}" |
| 19 | + fields: |
| 20 | + value: Int |
| 21 | + encoding: |
| 22 | + ZERO: 0b00 |
| 23 | + PLUS_ONE: 0b01 |
| 24 | + MINUS_ONE: 0b10 |
| 25 | + RESERVED: 0b11 |
| 26 | + |
| 27 | + TritPack4: |
| 28 | + description: "4 ternary weights packed in 1 byte" |
| 29 | + fields: |
| 30 | + packed: Int |
| 31 | + width: 8 |
| 32 | + |
| 33 | + TernaryMatrix: |
| 34 | + description: "Packed ternary weight matrix" |
| 35 | + fields: |
| 36 | + data: List<Int> |
| 37 | + rows: Int |
| 38 | + cols: Int |
| 39 | + cols_packed: Int |
| 40 | + |
| 41 | + MemoryStats: |
| 42 | + description: "Memory usage statistics" |
| 43 | + fields: |
| 44 | + float32_bytes: Int |
| 45 | + ternary_bytes: Int |
| 46 | + compression_ratio: Float |
| 47 | + |
| 48 | +behaviors: |
| 49 | + - name: trit_to_float |
| 50 | + given: TritWeight with 2-bit encoding |
| 51 | + when: Converting to float for computation |
| 52 | + then: Returns -1.0, 0.0, or +1.0 |
| 53 | + |
| 54 | + - name: float_to_trit |
| 55 | + given: Float value |
| 56 | + when: Quantizing to ternary |
| 57 | + then: Returns nearest trit (+1 above 0.5, -1 below -0.5, otherwise 0) |
| 58 | + |
| 59 | + - name: pack_trits |
| 60 | + given: 4 TritWeight values |
| 61 | + when: Packing for storage |
| 62 | + then: Returns single byte with 4 trits |
| 63 | + |
| 64 | + - name: unpack_trits |
| 65 | + given: Packed byte |
| 66 | + when: Extracting for computation |
| 67 | + then: Returns 4 TritWeight values |
| 68 | + |
| 69 | + - name: ternary_matvec |
| 70 | + given: Packed weight matrix and input vector |
| 71 | + when: Computing matrix-vector product |
| 72 | + then: Output vector with dot products (no multiplications, only add/sub) |
| 73 | + |
| 74 | + - name: simd_ternary_matvec |
| 75 | + given: Packed weights, input vector, SIMD width 8 |
| 76 | + when: Computing with AVX2 vectors |
| 77 | + then: 8x speedup via vectorized sign lookup |
| 78 | + |
| 79 | + - name: simd_ternary_matvec_16 |
| 80 | + given: Packed weights, input vector, SIMD width 16 |
| 81 | + when: Computing with AVX-512 vectors |
| 82 | + then: 16x speedup via wider vectors |
| 83 | + |
| 84 | + - name: batch_ternary_matvec |
| 85 | + given: Packed weights, input vector, batch of 4 rows |
| 86 | + when: Processing multiple output rows |
| 87 | + then: 4 rows computed in parallel |
| 88 | + |
| 89 | + - name: compute_memory_stats |
| 90 | + given: Matrix dimensions (rows, cols) |
| 91 | + when: Analyzing memory savings |
| 92 | + then: Returns compression ratio (~16x vs float32, since each weight takes 2 bits instead of 32) |
| 93 | + |
| 94 | +optimizations: |
| 95 | + - name: sign_lookup_table |
| 96 | + description: "LUT for trit→sign: [0.0, 1.0, -1.0, 0.0]" |
| 97 | + |
| 98 | + - name: no_multiplication |
| 99 | + description: "y += sign * x becomes y += x or y -= x based on sign (term skipped when sign is 0)" |
| 100 | + |
| 101 | + - name: cache_friendly |
| 102 | + description: "Row-major layout, sequential memory access" |
| 103 | + |
| 104 | + - name: simd_reduction |
| 105 | + description: "@reduce(.Add, vec) for horizontal sum" |
| 106 | + |
| 107 | +benchmarks: |
| 108 | + - name: throughput |
| 109 | + metric: "GFLOPS equivalent" |
| 110 | + target: ">100 GFLOPS on AVX2" |
| 111 | + |
| 112 | + - name: memory_bandwidth |
| 113 | + metric: "GB/s" |
| 114 | + target: "Near memory bandwidth limit" |
| 115 | + |
| 116 | + - name: latency |
| 117 | + metric: "ns per row" |
| 118 | + target: "<100ns for 4096-dim row" |
| 119 | + |
| 120 | +integration: |
| 121 | + - target: bytecode_vm |
| 122 | + description: "OP_TERNARY_MATVEC opcode" |
| 123 | + |
| 124 | + - target: model_loader |
| 125 | + description: "Load .tri model files" |
| 126 | + |
| 127 | + - target: inference_pipeline |
| 128 | + description: "Replace float matmul in forward pass" |
0 commit comments