@@ -734,3 +734,149 @@ test "simd matmul correctness" {
// Smoke test: runBenchmark must complete without returning an error when it
// is handed the testing allocator.
test "benchmark runs" {
    const gpa = std.testing.allocator;
    try runBenchmark(gpa);
}
737+
// Runs simdTernaryMatmulOpt16 on a 16x32 problem where every packed weight
// byte is 0x55 (each 2-bit field is 0b01, i.e. +1) and every input is 1.0.
// Each output element is therefore expected to equal `cols`.
test "simd16_small_matrix" {
    const allocator = std.testing.allocator;
    const rows: usize = 16;
    const cols: usize = 32;
    // Bytes per packed row, rounded up (presumably four 2-bit weights per byte).
    const cols_packed = (cols + 3) / 4;

    const weights = try allocator.alloc(u8, rows * cols_packed);
    defer allocator.free(weights);
    const input = try allocator.alloc(f32, cols);
    defer allocator.free(input);
    const output = try allocator.alloc(f32, rows);
    defer allocator.free(output);

    @memset(weights, 0x55); // All +1
    @memset(input, 1.0);

    simdTernaryMatmulOpt16(output, weights, input, rows, cols);

    // Each row should sum to cols (all +1 * 1.0). Assert the exact value
    // rather than only the sign, so a kernel that drops or double-counts
    // columns is caught.
    const expected: f32 = @floatFromInt(cols);
    for (output) |v| {
        try std.testing.expectApproxEqAbs(expected, v, 0.001);
    }
}
761+
// Runs simdTernaryMatmulOpt16 on an 8x16 problem whose packed weight bytes
// are all 0x00 and checks that every output element is (approximately) zero.
test "simd16_zero_weights" {
    const allocator = std.testing.allocator;
    const rows: usize = 8;
    const cols: usize = 16;
    // Bytes per packed row, rounded up.
    const cols_packed = (cols + 3) / 4;

    const weights = try allocator.alloc(u8, rows * cols_packed);
    defer allocator.free(weights);
    const input = try allocator.alloc(f32, cols);
    defer allocator.free(input);
    const output = try allocator.alloc(f32, rows);
    defer allocator.free(output);

    @memset(weights, 0x00); // All zeros
    @memset(input, 1.0);

    simdTernaryMatmulOpt16(output, weights, input, rows, cols);

    for (output) |v| {
        // expectApproxEqAbs takes (expected, actual, tolerance); passing the
        // expected value first keeps the failure message accurate.
        try std.testing.expectApproxEqAbs(@as(f32, 0.0), v, 0.001);
    }
}
784+
// Runs simdTernaryMatmulOpt16 on an 8x16 problem where every packed weight
// byte is 0xAA (each 2-bit field is 0b10, i.e. -1) and every input is 1.0.
// Each output element is therefore expected to equal -cols.
test "simd16_negative_weights" {
    const allocator = std.testing.allocator;
    const rows: usize = 8;
    const cols: usize = 16;
    // Bytes per packed row, rounded up.
    const cols_packed = (cols + 3) / 4;

    const weights = try allocator.alloc(u8, rows * cols_packed);
    defer allocator.free(weights);
    const input = try allocator.alloc(f32, cols);
    defer allocator.free(input);
    const output = try allocator.alloc(f32, rows);
    defer allocator.free(output);

    @memset(weights, 0xAA); // All -1
    @memset(input, 1.0);

    simdTernaryMatmulOpt16(output, weights, input, rows, cols);

    // A sign-only check would accept a kernel that skips columns; pin the
    // full magnitude (cols weights of -1 times inputs of 1.0).
    const expected: f32 = -@as(f32, @floatFromInt(cols));
    for (output) |v| {
        try std.testing.expectApproxEqAbs(expected, v, 0.001);
    }
}
807+
// Smoke test on a 256x512 problem with a deterministic, non-uniform byte
// pattern for the weights and inputs cycling through 0.0, 0.1, ..., 0.9.
// Only checks that the call completes; the numeric results are not inspected.
test "simd16_large_matrix" {
    const gpa = std.testing.allocator;
    const rows: usize = 256;
    const cols: usize = 512;
    const packed_row_len = (cols + 3) / 4;

    const weights = try gpa.alloc(u8, rows * packed_row_len);
    defer gpa.free(weights);
    const input = try gpa.alloc(f32, cols);
    defer gpa.free(input);
    const output = try gpa.alloc(f32, rows);
    defer gpa.free(output);

    // Weight byte k holds the low 8 bits of k.
    for (0..weights.len) |k| {
        weights[k] = @truncate(k);
    }
    // Input k is (k mod 10) / 10.
    for (0..input.len) |k| {
        const digit: f32 = @floatFromInt(k % 10);
        input[k] = digit / 10.0;
    }

    simdTernaryMatmulOpt16(output, weights, input, rows, cols);

    // Just verify it runs without crash
    try std.testing.expect(output.len == rows);
}
829+
// Feeds the same 32x64 weights and input to simdTernaryMatmulOpt8 and
// simdTernaryMatmulOpt16 and requires the two outputs to agree element-wise
// within an absolute tolerance of 0.01.
test "simd8_vs_simd16_equivalence" {
    const gpa = std.testing.allocator;
    const rows: usize = 32;
    const cols: usize = 64;
    const packed_row_len = (cols + 3) / 4;

    const weights = try gpa.alloc(u8, rows * packed_row_len);
    defer gpa.free(weights);
    const input = try gpa.alloc(f32, cols);
    defer gpa.free(input);
    const output8 = try gpa.alloc(f32, rows);
    defer gpa.free(output8);
    const output16 = try gpa.alloc(f32, rows);
    defer gpa.free(output16);

    // Deterministic, non-uniform weight bytes: low 8 bits of k * 7 + 13.
    for (0..weights.len) |k| {
        weights[k] = @truncate(k * 7 + 13);
    }
    // Mixed-sign inputs: sin(k).
    for (0..input.len) |k| {
        const angle: f32 = @floatFromInt(k);
        input[k] = @sin(angle);
    }

    simdTernaryMatmulOpt8(output8, weights, input, rows, cols);
    simdTernaryMatmulOpt16(output16, weights, input, rows, cols);

    // Both slices hold exactly `rows` elements, so they can be walked in lockstep.
    for (output8, output16) |from8, from16| {
        try std.testing.expectApproxEqAbs(from8, from16, 0.01);
    }
}
855+
// Checks decodeTrit on all four 2-bit codes: 0 -> 0, 1 -> +1, 2 -> -1, 3 -> 0.
test "decode_trit_all_values" {
    // Each entry is .{ code, expected decoded value }.
    const cases = .{
        .{ 0, 0 },
        .{ 1, 1 },
        .{ 2, -1 },
        .{ 3, 0 },
    };
    inline for (cases) |case| {
        try std.testing.expectEqual(@as(i32, case[1]), decodeTrit(case[0]));
    }
}
862+
// Checks that simdTernaryMatmulOpt16 handles a column count (17) that is not
// a multiple of 16. Weights are all 0x55 (+1 in every 2-bit field) and inputs
// are all 1.0, so each output element is expected to equal `cols`.
test "simd16_alignment" {
    // Test that SIMD-16 handles non-16-aligned cols
    const allocator = std.testing.allocator;
    const rows: usize = 4;
    const cols: usize = 17; // Not aligned to 16
    // Rounds up to 5 bytes per row; the memset below also fills the unused
    // trailing fields of each row's last byte with +1.
    const cols_packed = (cols + 3) / 4;

    const weights = try allocator.alloc(u8, rows * cols_packed);
    defer allocator.free(weights);
    const input = try allocator.alloc(f32, cols);
    defer allocator.free(input);
    const output = try allocator.alloc(f32, rows);
    defer allocator.free(output);

    @memset(weights, 0x55);
    @memset(input, 1.0);

    simdTernaryMatmulOpt16(output, weights, input, rows, cols);

    // Checking only output.len is a tautology (it was allocated with `rows`).
    // Pin the value instead: exactly `cols` columns must contribute, so a
    // kernel that drops the tail column or counts padding fields fails here.
    // NOTE(review): assumes the kernel is expected to ignore padding fields
    // beyond `cols` — confirm against simdTernaryMatmulOpt16.
    const expected: f32 = @floatFromInt(cols);
    for (output) |v| {
        try std.testing.expectApproxEqAbs(expected, v, 0.001);
    }
}
0 commit comments