@@ -81,19 +81,54 @@ float16to32 (bfloat16_bits f16)
8181 return f32 .v ;
8282}
8383
84+ #define SBGEMM_LARGEST 256
85+
8486int
8587main (int argc , char * argv [])
8688{
8789 blasint m , n , k ;
8890 int i , j , l ;
8991 blasint x , y ;
9092 int ret = 0 ;
91- int loop = 100 ;
93+ int loop = SBGEMM_LARGEST ;
9294 char transA = 'N' , transB = 'N' ;
9395 float alpha = 1.0 , beta = 0.0 ;
9496
9597 for (x = 0 ; x <= loop ; x ++ )
9698 {
99+ if ((x > 100 ) && (x != SBGEMM_LARGEST )) continue ;
100+ m = k = n = x ;
101+ float * A = (float * )malloc (m * k * sizeof (FLOAT ));
102+ float * B = (float * )malloc (k * n * sizeof (FLOAT ));
103+ float * C = (float * )malloc (m * n * sizeof (FLOAT ));
104+ bfloat16_bits * AA = (bfloat16_bits * )malloc (m * k * sizeof (bfloat16_bits ));
105+ bfloat16_bits * BB = (bfloat16_bits * )malloc (k * n * sizeof (bfloat16_bits ));
106+ float * DD = (float * )malloc (m * n * sizeof (FLOAT ));
107+ float * CC = (float * )malloc (m * n * sizeof (FLOAT ));
108+ if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
109+ (DD == NULL ) || (CC == NULL ))
110+ return 1 ;
111+ bfloat16 atmp ,btmp ;
112+ blasint one = 1 ;
113+
114+ for (j = 0 ; j < m ; j ++ )
115+ {
116+ for (i = 0 ; i < k ; i ++ )
117+ {
118+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
119+ sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
120+ AA [j * k + i ].v = atmp ;
121+ }
122+ }
123+ for (j = 0 ; j < n ; j ++ )
124+ {
125+ for (i = 0 ; i < k ; i ++ )
126+ {
127+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
128+ sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
129+ BB [j * k + i ].v = btmp ;
130+ }
131+ }
97132 for (y = 0 ; y < 4 ; y ++ )
98133 {
99134 if ((y == 0 ) || (y == 2 )) {
@@ -106,40 +141,19 @@ main (int argc, char *argv[])
106141 } else {
107142 transB = 'T' ;
108143 }
109- m = k = n = x ;
110- float A [m * k ];
111- float B [k * n ];
112- float C [m * n ];
113- bfloat16_bits AA [m * k ], BB [k * n ];
114- float DD [m * n ], CC [m * n ];
115- bfloat16 atmp ,btmp ;
116- blasint one = 1 ;
117144
118- for (j = 0 ; j < m ; j ++ )
119- {
120- for (i = 0 ; i < m ; i ++ )
121- {
122- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
123- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
124- C [j * k + i ] = 0 ;
125- sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
126- sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
127- AA [j * k + i ].v = atmp ;
128- BB [j * k + i ].v = btmp ;
129- CC [j * k + i ] = 0 ;
130- DD [j * k + i ] = 0 ;
131- }
132- }
145+ memset (CC , 0 , m * n * sizeof (FLOAT ));
146+ memset (DD , 0 , m * n * sizeof (FLOAT ));
147+ memset (C , 0 , m * n * sizeof (FLOAT ));
148+
133149 SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
134150 & m , B , & k , & beta , C , & m );
135151 SBGEMM (& transA , & transB , & m , & n , & k , & alpha , (bfloat16 * ) AA ,
136152 & m , (bfloat16 * )BB , & k , & beta , CC , & m );
153+
137154 for (i = 0 ; i < n ; i ++ )
138155 for (j = 0 ; j < m ; j ++ )
139- if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
140- ret ++ ;
141- for (i = 0 ; i < n ; i ++ )
142- for (j = 0 ; j < m ; j ++ )
156+ {
143157 for (l = 0 ; l < k ; l ++ )
144158 if (transA == 'N' && transB == 'N' )
145159 {
@@ -158,11 +172,19 @@ main (int argc, char *argv[])
158172 DD [i * m + j ] +=
159173 float16to32 (AA [k * j + l ]) * float16to32 (BB [i + l * n ]);
160174 }
161- for ( i = 0 ; i < n ; i ++ )
162- for ( j = 0 ; j < m ; j ++ )
163- if (CC [i * m + j ] != DD [i * m + j ])
175+ if ( fabs ( CC [ i * m + j ] - C [ i * m + j ]) > 1.0 )
176+ ret ++ ;
177+ if (fabs ( CC [i * m + j ] - DD [i * m + j ]) > 1.0 )
164178 ret ++ ;
179+ }
165180 }
181+ free (A );
182+ free (B );
183+ free (C );
184+ free (AA );
185+ free (BB );
186+ free (DD );
187+ free (CC );
166188 }
167189
168190 if (ret != 0 )
0 commit comments