@@ -568,6 +568,16 @@ pub const TernaryKVCache = struct {
568568 // Quantization threshold (fraction of max)
569569 threshold_ratio : f32 ,
570570
571+ // Quantization mode
572+ quant_mode : QuantMode ,
573+
/// Strategy that `quantizeVector` uses to derive the ternary threshold and
/// the scale factor it returns for dequantization.
pub const QuantMode = enum {
    /// Original behaviour: threshold is `max(|v|) * threshold_ratio`.
    fixed_threshold,
    /// Adaptive: threshold is `mean(|v|) * threshold_ratio`.
    adaptive_mean,
    /// Threshold is zero, so every non-zero value is quantized to +1 or -1
    /// (intended as the best-accuracy option).
    no_threshold,
    /// Scale is derived from the RMS of the vector instead of its maximum
    /// (intended to work better for attention).
    rms_scale,
};
580+
571581 pub fn init (
572582 allocator : std.mem.Allocator ,
573583 num_kv_heads : usize ,
@@ -588,7 +598,8 @@ pub const TernaryKVCache = struct {
588598 .k_scales = try allocator .alloc (f32 , max_seq_len ),
589599 .v_scales = try allocator .alloc (f32 , max_seq_len ),
590600 .seq_len = 0 ,
591- .threshold_ratio = 0.3 , // Values < 30% of max become 0
601+ .threshold_ratio = 0.0 ,
602+ .quant_mode = .rms_scale , // RMS-based scaling for better accuracy
592603 };
593604 }
594605
@@ -625,32 +636,50 @@ pub const TernaryKVCache = struct {
625636 }
626637
627638 /// Quantize f32 vector to ternary packed bytes
639+ /// Returns scale factor for dequantization
628640 fn quantizeVector (self : * const TernaryKVCache , dst : []u8 , src : []const f32 ) f32 {
629- // Find max absolute value for scale
641+ // Calculate statistics
630642 var max_abs : f32 = 0.0 ;
643+ var sum_abs : f32 = 0.0 ;
644+ var sum_sq : f32 = 0.0 ;
631645 for (src ) | v | {
632646 const abs_v = @abs (v );
633647 if (abs_v > max_abs ) max_abs = abs_v ;
648+ sum_abs += abs_v ;
649+ sum_sq += v * v ;
634650 }
635651
636652 if (max_abs == 0.0 ) {
637653 @memset (dst , 0 );
638654 return 1.0 ;
639655 }
640656
641- const threshold = max_abs * self .threshold_ratio ;
642- const inv_scale = 1.0 / max_abs ;
657+ const n = @as (f32 , @floatFromInt (src .len ));
658+ const mean_abs = sum_abs / n ;
659+ const rms = @sqrt (sum_sq / n );
660+
661+ // Calculate scale and threshold based on mode
662+ const scale : f32 = switch (self .quant_mode ) {
663+ .fixed_threshold , .no_threshold , .adaptive_mean = > max_abs ,
664+ .rms_scale = > rms * 1.5 , // Heuristic scale: 1.5 * RMS (slightly above sqrt(2) * RMS ~= 1.414 * RMS)
665+ };
666+
667+ const threshold : f32 = switch (self .quant_mode ) {
668+ .fixed_threshold = > max_abs * self .threshold_ratio ,
669+ .adaptive_mean = > mean_abs * self .threshold_ratio ,
670+ .no_threshold = > 0.0 ,
671+ .rms_scale = > rms * 0.5 , // Half RMS as threshold
672+ };
643673
644674 // Pack 4 values per byte
645675 var byte_idx : usize = 0 ;
646676 var bit_pos : u3 = 0 ;
647677 var current_byte : u8 = 0 ;
648678
649679 for (src ) | v | {
650- const normalized = v * inv_scale ;
651- const trit : u2 = if (normalized > threshold * inv_scale )
680+ const trit : u2 = if (v > threshold )
652681 0b01 // +1
653- else if (normalized < - threshold * inv_scale )
682+ else if (v < - threshold )
654683 0b10 // -1
655684 else
656685 0b00 ; // 0
@@ -670,7 +699,7 @@ pub const TernaryKVCache = struct {
670699 dst [byte_idx ] = current_byte ;
671700 }
672701
673- return max_abs ;
702+ return scale ;
674703 }
675704
676705 /// Compute dot product between f32 query and ternary key (no full dequantization)
@@ -798,6 +827,30 @@ pub const TernaryKVCache = struct {
798827 self .seq_len = 0 ;
799828 }
800829
/// Select the quantization strategy together with its threshold ratio.
///
/// `mode` is stored in `quant_mode` and `threshold` in `threshold_ratio`.
/// `quantizeVector` multiplies the ratio by max(|v|) for `.fixed_threshold`
/// and by mean(|v|) for `.adaptive_mean`; the thresholds it computes for
/// `.no_threshold` and `.rms_scale` do not read the ratio.
pub fn setQuantMode(self: *TernaryKVCache, mode: QuantMode, threshold: f32) void {
    self.threshold_ratio = threshold;
    self.quant_mode = mode;
}
835+
/// Preset: high-accuracy mode.
///
/// Switches to `.no_threshold` with a ratio of 0.0, so every non-zero
/// input value is quantized to +1 or -1.
pub fn setHighAccuracy(self: *TernaryKVCache) void {
    self.setQuantMode(.no_threshold, 0.0);
}
841+
/// Preset: balanced mode.
///
/// Switches to `.adaptive_mean` with a ratio of 0.1, i.e. a small threshold
/// of 10% of the mean absolute value, meant to suppress noise.
pub fn setBalanced(self: *TernaryKVCache) void {
    self.setQuantMode(.adaptive_mean, 0.1);
}
847+
/// Preset: high-compression mode.
///
/// Switches to `.fixed_threshold` with a ratio of 0.3, an aggressive
/// threshold of 30% of the maximum absolute value.
pub fn setHighCompression(self: *TernaryKVCache) void {
    self.setQuantMode(.fixed_threshold, 0.3);
}
853+
801854 /// Memory usage in bytes
802855 pub fn memoryUsage (self : * const TernaryKVCache ) usize {
803856 return self .k_cache .len + self .v_cache .len +
0 commit comments