@@ -6,17 +6,24 @@ const gguf = @import("gguf_reader.zig");
66const inference = @import ("gguf_inference.zig" );
77const transformer = @import ("gguf_transformer.zig" );
88const simd = @import ("simd_matmul.zig" );
9+ const ternary = @import ("ternary_weights.zig" );
910
1011pub const FullModel = struct {
1112 allocator : std.mem.Allocator ,
1213 reader : gguf.GGUFReader ,
1314 config : inference.ModelConfig ,
1415
16+ // Ternary mode flag
17+ use_ternary : bool = false ,
18+
1519 // Core weights
1620 token_embedding : []f32 ,
1721 output_weight : []f32 ,
1822 output_norm : []f32 ,
1923
24+ // Ternary weights (optional - for BitNet models)
25+ ternary_output_weight : ? []u8 = null ,
26+
2027 // Per-layer weights
2128 layers : []LayerWeights ,
2229
@@ -48,6 +55,15 @@ pub const FullModel = struct {
4855 w_gate : []f32 ,
4956 w_up : []f32 ,
5057 w_down : []f32 ,
58+
59+ // Ternary versions (optional)
60+ ternary_wq : ? []u8 = null ,
61+ ternary_wk : ? []u8 = null ,
62+ ternary_wv : ? []u8 = null ,
63+ ternary_wo : ? []u8 = null ,
64+ ternary_w_gate : ? []u8 = null ,
65+ ternary_w_up : ? []u8 = null ,
66+ ternary_w_down : ? []u8 = null ,
5167 };
5268
5369 pub fn init (allocator : std.mem.Allocator , path : []const u8 ) ! FullModel {
@@ -263,6 +279,64 @@ pub const FullModel = struct {
263279 }
264280 }
265281
282+ /// Enable ternary mode - quantize all weights to {-1, 0, +1}
283+ /// This provides 16x memory savings and faster inference on CPU
284+ pub fn enableTernaryMode (self : * FullModel ) ! void {
285+ if (self .use_ternary ) return ; // Already enabled
286+
287+ std .debug .print ("\n Converting to ternary weights...\n " , .{});
288+ const stats = ternary .MemoryStats .calculate (self .countParameters ());
289+ stats .print ();
290+
291+ // Convert output weights
292+ const threshold = ternary .calculateThreshold (self .output_weight );
293+ self .ternary_output_weight = try ternary .quantizeToTernary (self .allocator , self .output_weight , threshold );
294+
295+ // Convert layer weights
296+ for (self .layers ) | * layer | {
297+ const t_wq = ternary .calculateThreshold (layer .wq );
298+ const t_wk = ternary .calculateThreshold (layer .wk );
299+ const t_wv = ternary .calculateThreshold (layer .wv );
300+ const t_wo = ternary .calculateThreshold (layer .wo );
301+ const t_gate = ternary .calculateThreshold (layer .w_gate );
302+ const t_up = ternary .calculateThreshold (layer .w_up );
303+ const t_down = ternary .calculateThreshold (layer .w_down );
304+
305+ layer .ternary_wq = try ternary .quantizeToTernary (self .allocator , layer .wq , t_wq );
306+ layer .ternary_wk = try ternary .quantizeToTernary (self .allocator , layer .wk , t_wk );
307+ layer .ternary_wv = try ternary .quantizeToTernary (self .allocator , layer .wv , t_wv );
308+ layer .ternary_wo = try ternary .quantizeToTernary (self .allocator , layer .wo , t_wo );
309+ layer .ternary_w_gate = try ternary .quantizeToTernary (self .allocator , layer .w_gate , t_gate );
310+ layer .ternary_w_up = try ternary .quantizeToTernary (self .allocator , layer .w_up , t_up );
311+ layer .ternary_w_down = try ternary .quantizeToTernary (self .allocator , layer .w_down , t_down );
312+ }
313+
314+ self .use_ternary = true ;
315+ std .debug .print ("Ternary mode enabled!\n " , .{});
316+ }
317+
318+ /// Count total parameters
319+ fn countParameters (self : * const FullModel ) usize {
320+ var count : usize = self .token_embedding .len + self .output_weight .len + self .output_norm .len ;
321+ for (self .layers ) | layer | {
322+ count += layer .wq .len + layer .wk .len + layer .wv .len + layer .wo .len ;
323+ count += layer .w_gate .len + layer .w_up .len + layer .w_down .len ;
324+ count += layer .attn_norm .len + layer .ffn_norm .len ;
325+ }
326+ return count ;
327+ }
328+
329+ /// Matrix-vector multiply with automatic ternary/float selection
330+ fn matVecAuto (self : * const FullModel , output : []f32 , weights_f32 : []const f32 , weights_ternary : ? []const u8 , input : []const f32 , rows : usize , cols : usize ) void {
331+ if (self .use_ternary ) {
332+ if (weights_ternary ) | tw | {
333+ ternary .ternaryMatVec (output , tw , input , rows , cols );
334+ return ;
335+ }
336+ }
337+ inference .matVec (output , weights_f32 , input , rows , cols );
338+ }
339+
266340 // Forward pass for single token - OPTIMIZED with pre-allocated buffers
267341 pub fn forward (self : * FullModel , token : u32 , pos : usize ) ! []f32 {
268342 const hidden_size = self .config .hidden_size ;
@@ -280,9 +354,9 @@ pub const FullModel = struct {
280354 // Final RMS norm
281355 inference .rmsNorm (self .buf_temp , self .buf_hidden , self .output_norm , self .config .rms_norm_eps );
282356
283- // Output projection (only allocation is for return value)
357+ // Output projection (only allocation is for return value) - with ternary support
284358 const logits = try self .allocator .alloc (f32 , self .config .vocab_size );
285- inference . matVec (logits , self .output_weight , self .buf_temp , self .config .vocab_size , hidden_size );
359+ self . matVecAuto (logits , self .output_weight , self . ternary_output_weight , self .buf_temp , self .config .vocab_size , hidden_size );
286360
287361 return logits ;
288362 }
@@ -420,10 +494,10 @@ pub const FullModel = struct {
420494 // Pre-attention norm (use buf_normed)
421495 inference .rmsNorm (self .buf_normed , input , layer .attn_norm , rms_eps );
422496
423- // Compute Q, K, V (use buf_q, buf_k, buf_v)
424- inference . matVec (self .buf_q , layer .wq , self .buf_normed , num_heads * head_dim , hidden_size );
425- inference . matVec (self .buf_k , layer .wk , self .buf_normed , num_kv_heads * head_dim , hidden_size );
426- inference . matVec (self .buf_v , layer .wv , self .buf_normed , num_kv_heads * head_dim , hidden_size );
497+ // Compute Q, K, V (use buf_q, buf_k, buf_v) - with ternary support
498+ self . matVecAuto (self .buf_q , layer .wq , layer . ternary_wq , self .buf_normed , num_heads * head_dim , hidden_size );
499+ self . matVecAuto (self .buf_k , layer .wk , layer . ternary_wk , self .buf_normed , num_kv_heads * head_dim , hidden_size );
500+ self . matVecAuto (self .buf_v , layer .wv , layer . ternary_wv , self .buf_normed , num_kv_heads * head_dim , hidden_size );
427501
428502 // Apply RoPE
429503 for (0.. num_heads ) | h | {
@@ -471,8 +545,8 @@ pub const FullModel = struct {
471545 }
472546 }
473547
474- // Output projection (use buf_attn_proj)
475- inference . matVec (self .buf_attn_proj , layer .wo , self .buf_attn_out , hidden_size , num_heads * head_dim );
548+ // Output projection (use buf_attn_proj) - with ternary support
549+ self . matVecAuto (self .buf_attn_proj , layer .wo , layer . ternary_wo , self .buf_attn_out , hidden_size , num_heads * head_dim );
476550
477551 // Residual
478552 for (0.. hidden_size ) | i | {
@@ -482,17 +556,17 @@ pub const FullModel = struct {
482556 // Pre-FFN norm
483557 inference .rmsNorm (self .buf_normed , output , layer .ffn_norm , rms_eps );
484558
485- // FFN with SwiGLU (use buf_ffn_gate, buf_ffn_up)
486- inference . matVec (self .buf_ffn_gate , layer .w_gate , self .buf_normed , intermediate_size , hidden_size );
487- inference . matVec (self .buf_ffn_up , layer .w_up , self .buf_normed , intermediate_size , hidden_size );
559+ // FFN with SwiGLU (use buf_ffn_gate, buf_ffn_up) - with ternary support
560+ self . matVecAuto (self .buf_ffn_gate , layer .w_gate , layer . ternary_w_gate , self .buf_normed , intermediate_size , hidden_size );
561+ self . matVecAuto (self .buf_ffn_up , layer .w_up , layer . ternary_w_up , self .buf_normed , intermediate_size , hidden_size );
488562
489563 // SwiGLU
490564 for (0.. intermediate_size ) | i | {
491565 self .buf_ffn_gate [i ] = inference .silu (self .buf_ffn_gate [i ]) * self .buf_ffn_up [i ];
492566 }
493567
494- // Down projection (use buf_ffn_out)
495- inference . matVec (self .buf_ffn_out , layer .w_down , self .buf_ffn_gate , hidden_size , intermediate_size );
568+ // Down projection (use buf_ffn_out) - with ternary support
569+ self . matVecAuto (self .buf_ffn_out , layer .w_down , layer . ternary_w_down , self .buf_ffn_gate , hidden_size , intermediate_size );
496570
497571 // Residual
498572 for (0.. hidden_size ) | i | {
0 commit comments