1- use crate :: components:: weight :: Weight ;
1+ use crate :: { components:: quant :: vec_dot , format :: gguf :: GgufType , tensor :: Tensor } ;
22use rayon:: prelude:: * ;
33
4- // matrix-vector multiplication
5- // x shape: (in_channels,)
6- // w shape: (out_channels, in_channels) stored in row-major order (like safetensors)
7- // out shape: (out_channels,)
8- #[ allow( dead_code) ]
9- pub fn naive_matmul < W : Weight > ( out : & mut [ f32 ] , x : & [ f32 ] , weight : & [ W ] ) {
10- for i in 0 ..out. len ( ) {
11- let mut sum = 0.0_f32 ;
12-
13- for k in 0 ..x. len ( ) {
14- sum += x[ k] * weight[ k + i * x. len ( ) ] . to_f32 ( ) ;
15- }
16- out[ i] = sum;
4+ pub fn bf16_to_f32 ( n : u16 ) -> f32 {
5+ f32:: from_bits ( ( n as u32 ) << 16 )
6+ }
7+
8+ pub trait FloatType : Copy + Sync + Send {
9+ fn to_f32 ( self ) -> f32 ;
10+ }
11+
12+ impl FloatType for f32 {
13+ fn to_f32 ( self ) -> f32 {
14+ self
1715 }
1816}
1917
20- // parallel matmul
21- pub fn matmul < W : Weight > ( out : & mut [ f32 ] , x : & [ f32 ] , weight : & [ W ] ) {
18+ impl FloatType for u16 {
19+ // f32: [1 sign] [8 exponent] [23 mantissa] = 32 bits
20+ // bf16: [1 sign] [8 exponent] [ 7 mantissa] = 16 bits
21+ // To convert BF16 → f32, put these 16 bits in the upper 16 bits of a 32-bit word
22+ // and zero-fill the bottom
23+ fn to_f32 ( self ) -> f32 {
24+ bf16_to_f32 ( self )
25+ }
26+ }
27+
28+ pub fn matmul_gguf ( out : & mut [ f32 ] , x : & [ f32 ] , weight : & [ u8 ] , dtype : GgufType , n_cols : usize ) {
29+ let row_bytes = dtype. row_bytes ( n_cols) ;
30+
31+ out. par_iter_mut ( ) . enumerate ( ) . for_each ( |( i, o) | {
32+ let row = & weight[ i * row_bytes..( i + 1 ) * row_bytes] ;
33+ * o = vec_dot ( row, x, dtype) ;
34+ } ) ;
35+ }
36+
37+ pub fn matmul_float < W : FloatType > ( out : & mut [ f32 ] , x : & [ f32 ] , weight : & [ W ] ) {
2238 out. par_iter_mut ( ) . enumerate ( ) . for_each ( |( i, o) | {
2339 // i = row index, o = &mut f32 (that output element)
2440 let in_channels: usize = x. len ( ) ;
@@ -30,6 +46,15 @@ pub fn matmul<W: Weight>(out: &mut [f32], x: &[f32], weight: &[W]) {
3046 } ) ;
3147}
3248
49+ /// Matrix-vector multiply
50+ pub fn matmul ( out : & mut [ f32 ] , x : & [ f32 ] , weight : Tensor ) {
51+ match weight {
52+ Tensor :: F32 ( w) => matmul_float :: < f32 > ( out, x, w) ,
53+ Tensor :: BF16 ( w) => matmul_float :: < u16 > ( out, x, w) ,
54+ Tensor :: Quantized { data, dtype } => matmul_gguf ( out, x, data, dtype, x. len ( ) ) ,
55+ }
56+ }
57+
3358#[ cfg( test) ]
3459mod tests {
3560 use super :: * ;
@@ -41,40 +66,31 @@ mod tests {
4166
4267 #[ test]
4368 fn test_matmul_2x3 ( ) {
44- // weight: 2 rows × 3 cols (in_channels=3, out has 2 elements)
45- // row 0: [1, 2, 3]
46- // row 1: [4, 5, 6]
4769 let weight = vec ! [ 1.0_f32 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
4870 let x = vec ! [ 1.0_f32 , 2.0 , 3.0 ] ;
4971 let mut out = vec ! [ 0.0_f32 ; 2 ] ;
50- let mut parallel_out = vec ! [ 0.0_f32 ; 2 ] ;
5172
52- naive_matmul ( & mut out, & x, & weight) ;
53- matmul ( & mut parallel_out, & x, & weight) ;
73+ matmul ( & mut out, & x, Tensor :: F32 ( & weight) ) ;
5474
5575 assert ! ( approx( out[ 0 ] , 14.0 ) ) ;
5676 assert ! ( approx( out[ 1 ] , 32.0 ) ) ;
57-
58- // check parallel output
59- assert ! ( approx( parallel_out[ 0 ] , 14.0 ) ) ;
60- assert ! ( approx( parallel_out[ 1 ] , 32.0 ) ) ;
6177 }
6278
6379 #[ test]
6480 fn test_matmul_bf16 ( ) {
65- // Helper: convert f32 to bf16 (top 16 bits)
6681 fn f32_to_bf16 ( x : f32 ) -> u16 {
6782 ( x. to_bits ( ) >> 16 ) as u16
6883 }
6984
70- let weight: Vec < u16 > = vec ! [ 1.0_f32 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]
85+ let weight: Vec < u16 > = [ 1.0_f32 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]
7186 . iter ( )
7287 . map ( |& v| f32_to_bf16 ( v) )
7388 . collect ( ) ;
7489 let x = vec ! [ 1.0_f32 , 2.0 , 3.0 ] ;
7590 let mut out = vec ! [ 0.0_f32 ; 2 ] ;
7691
77- matmul ( & mut out, & x, & weight) ;
92+ matmul ( & mut out, & x, Tensor :: BF16 ( & weight) ) ;
93+
7894 assert ! ( approx( out[ 0 ] , 14.0 ) ) ;
7995 assert ! ( approx( out[ 1 ] , 32.0 ) ) ;
8096 }
0 commit comments