1313// See the License for the specific language governing permissions and
1414// limitations under the License.
1515
16- // BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE.
16+ // BRGeMM dispatch for BF16 MatMul on Intel AMX/AVX-512.
17+
18+ #include < stddef.h>
19+ #include < stdint.h>
20+
21+ #include < algorithm>
22+ #include < utility>
23+ #include < vector>
24+
25+ #include " ops/brgemm.h"
26+ #include " ops/matmul.h"
27+ #include " util/mat.h"
28+ #include " util/threading_context.h"
29+ #include " util/zones.h"
30+ #include " hwy/base.h"
31+
32+ // Include guard for (potentially) SIMD code.
33+ #if defined(THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE) == defined(HWY_TARGET_TOGGLE)
34+ #ifdef THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE
35+ #undef THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE
36+ #else
37+ #define THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE
38+ #endif
39+
40+ #include " hwy/highway.h"
41+
42+ HWY_BEFORE_NAMESPACE ();
43+ namespace gcpp {
44+ namespace HWY_NAMESPACE {
45+ namespace hn = hwy::HWY_NAMESPACE;
1746
1847#if GEMMA_ONEDNN_BRGEMM
1948
@@ -55,8 +84,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
5584
5685 ke.K_blk = cfg.K_blk ;
5786 ke.N_blk = cfg.N_blk ;
58- ke.M_blk =
59- static_cast <int64_t >(std::min (static_cast <size_t >(cfg.M_blk ), M));
87+ ke.M_blk = std::min (cfg.M_blk , M);
6088
6189 ke.M_tail = M % ke.M_blk ;
6290 ke.N_tail = N % ke.N_blk ;
@@ -97,10 +125,13 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
97125 ke.m_sizes [1 ] = ke.M_tail ? ke.M_tail : ke.M_blk ;
98126 ke.n_sizes [0 ] = ke.N_blk ;
99127 ke.n_sizes [1 ] = ke.N_tail ? ke.N_tail : ke.N_blk ;
100- const int64_t ldb_for[2 ] = {ke.N_blk , ke.N_tail ? ke.N_tail : ke.N_blk };
101- const int64_t ldc_for[2 ] = {ke.N_blk , ke.N_tail ? ke.N_tail : ke.N_blk };
128+ const int64_t ldb_for[2 ] = {static_cast <int64_t >(ke.N_blk ),
129+ static_cast <int64_t >(ke.N_tail ? ke.N_tail : ke.N_blk )};
130+ const int64_t ldc_for[2 ] = {static_cast <int64_t >(ke.N_blk ),
131+ static_cast <int64_t >(ke.N_tail ? ke.N_tail : ke.N_blk )};
102132
103- // Create brgemm kernels for each (M-tile, N-tile) variant.
133+ // Create brgemm kernels for full/tail M and N tile sizes.
134+ // mi=0 is the full M tile, mi=1 is the M-tail; likewise for ni and N.
104135 size_t max_sp = 0 ;
105136 for (int mi = 0 ; mi < 2 ; ++mi) {
106137 for (int ni = 0 ; ni < 2 ; ++ni) {
@@ -109,22 +140,25 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
109140 if (mi == 0 && ke.M_full_tiles == 0 ) continue ;
110141 if (ni == 0 && ke.N_full_tiles == 0 ) continue ;
111142
112- const int64_t ms = ke.m_sizes [mi];
113- const int64_t ns = ke.n_sizes [ni];
143+ const int64_t ms = static_cast < int64_t >( ke.m_sizes [mi]) ;
144+ const int64_t ns = static_cast < int64_t >( ke.n_sizes [ni]) ;
114145
115146 if (ke.K_chunks > 0 ) {
116- if (!MakeBrgemm (ke.brg_first_all [mi][ni], ms, ns, ke.K_blk ,
117- ke.K_super_size , ke.lda , ldb_for[ni], ldc_for[ni],
118- a_dt, b_dt, c_dt, false )) {
147+ if (!MakeBrgemm (ke.brg_first_all [mi][ni], ms, ns,
148+ static_cast <int64_t >(ke.K_blk ),
149+ static_cast <int64_t >(ke.K_super_size ), ke.lda ,
150+ ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt,
151+ false )) {
119152 return ;
120153 }
121154 max_sp = std::max (max_sp,
122155 ke.brg_first_all [mi][ni].get_scratchpad_size ());
123156 }
124157 if (ke.K_super_blocks > 1 ) {
125- if (!MakeBrgemm (ke.brg_full [mi][ni], ms, ns, ke.K_blk ,
126- ke.batch_full , ke.lda , ldb_for[ni], ldc_for[ni],
127- a_dt, b_dt, c_dt, true )) {
158+ if (!MakeBrgemm (ke.brg_full [mi][ni], ms, ns,
159+ static_cast <int64_t >(ke.K_blk ),
160+ static_cast <int64_t >(ke.batch_full ), ke.lda ,
161+ ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, true )) {
128162 return ;
129163 }
130164 max_sp =
@@ -134,7 +168,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
134168 const bool rem_is_first = (ke.K_super_blocks == 0 );
135169 auto & target = rem_is_first ? ke.brg_first_rem [mi][ni]
136170 : ke.brg_rem [mi][ni];
137- if (!MakeBrgemm (target, ms, ns, ke.K_blk , ke.batch_rem , ke.lda ,
171+ if (!MakeBrgemm (target, ms, ns, static_cast <int64_t >(ke.K_blk ),
172+ static_cast <int64_t >(ke.batch_rem ), ke.lda ,
138173 ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt,
139174 !rem_is_first)) {
140175 return ;
@@ -143,7 +178,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
143178 }
144179 if (ke.K_tail > 0 ) {
145180 const bool add_c = (ke.K_chunks > 0 );
146- if (!MakeBrgemm (ke.brg_ktail [mi][ni], ms, ns, ke.K_tail , 1 , ke.lda ,
181+ if (!MakeBrgemm (ke.brg_ktail [mi][ni], ms, ns,
182+ static_cast <int64_t >(ke.K_tail ), 1 , ke.lda ,
147183 ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt,
148184 add_c)) {
149185 return ;
@@ -161,28 +197,30 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
161197 if (ni == 1 && ke.N_tail == 0 ) continue ;
162198 if (ni == 0 && ke.N_full_tiles == 0 ) continue ;
163199
164- const int64_t ns = ke.n_sizes [ni];
200+ const int64_t ns = static_cast < int64_t >( ke.n_sizes [ni]) ;
165201 if (ke.K_chunks > 0 ) {
166- const int64_t K_full = ke.K_chunks * ke.K_blk ;
202+ const int64_t K_full =
203+ static_cast <int64_t >(ke.K_chunks * ke.K_blk );
167204 try {
168205 ke.pack_B [ni] = transform (K_full, ns, pack_type::trans,
169206 ke.ldb_orig , ldb_for[ni], b_dt, b_dt);
170207 if (!ke.pack_B [ni]) return ;
171208 ke.pack_B [ni].generate ();
172- ke.blocked_B_size [ni] = ldb_for[ni] * ke.K_blk * ke.b_dt_size ;
209+ ke.blocked_B_size [ni] = static_cast <size_t >(ldb_for[ni]) *
210+ ke.K_blk * ke.b_dt_size ;
173211 } catch (...) {
174212 return ;
175213 }
176214 }
177215 if (ke.K_tail > 0 ) {
178216 try {
179217 ke.pack_B_ktail [ni] = transform (
180- ke.K_tail , ns, pack_type::trans, ke. ldb_orig , ldb_for[ni] ,
181- b_dt, b_dt);
218+ static_cast < int64_t >( ke.K_tail ) , ns, pack_type::trans,
219+ ke. ldb_orig , ldb_for[ni], b_dt, b_dt);
182220 if (!ke.pack_B_ktail [ni]) return ;
183221 ke.pack_B_ktail [ni].generate ();
184222 ke.blocked_B_ktail_size [ni] =
185- ldb_for[ni] * ke.K_tail * ke.b_dt_size ;
223+ static_cast < size_t >( ldb_for[ni]) * ke.K_tail * ke.b_dt_size ;
186224 } catch (...) {
187225 return ;
188226 }
@@ -194,55 +232,55 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
194232 for (int ni = 0 ; ni < 2 ; ++ni) {
195233 if (ni == 1 && ke.N_tail == 0 ) continue ;
196234 if (ni == 0 && ke.N_full_tiles == 0 ) continue ;
197- const int64_t cur_n = ke.n_sizes [ni];
235+ const size_t cur_n = ke.n_sizes [ni];
198236
199237 if (ke.K_chunks > 0 ) {
200238 ke.offsets_first_all [ni].resize (ke.K_super_size );
201- for (int64_t i = 0 ; i < ke.K_super_size ; ++i) {
239+ for (size_t i = 0 ; i < ke.K_super_size ; ++i) {
202240 const int64_t a_off =
203- i * ke.K_blk * static_cast < int64_t >( ke.a_dt_size );
241+ static_cast < int64_t >( i * ke.K_blk * ke.a_dt_size );
204242 const int64_t b_off =
205243 ke.need_pack
206- ? i * static_cast <int64_t >(ke.blocked_B_size [ni])
207- : i * cur_n * ke.K_blk * static_cast < int64_t >( ke.b_dt_size );
244+ ? static_cast <int64_t >(i * ke.blocked_B_size [ni])
245+ : static_cast < int64_t >( i * cur_n * ke.K_blk * ke.b_dt_size );
208246 ke.offsets_first_all [ni][i] = {a_off, b_off};
209247 }
210248 }
211249
212250 if (ke.K_super_blocks > 1 ) {
213251 ke.offsets_full [ni].resize (ke.K_super_blocks - 1 );
214- for (int64_t ks = 1 ; ks < ke.K_super_blocks ; ++ks) {
252+ for (size_t ks = 1 ; ks < ke.K_super_blocks ; ++ks) {
215253 auto & tbl = ke.offsets_full [ni][ks - 1 ];
216254 tbl.resize (ke.batch_full );
217- const int64_t k_start = ks * ke.K_super_size ;
218- for (int64_t i = 0 ; i < ke.batch_full ; ++i) {
219- const int64_t k_idx = k_start + i;
255+ const size_t k_start = ks * ke.K_super_size ;
256+ for (size_t i = 0 ; i < ke.batch_full ; ++i) {
257+ const size_t k_idx = k_start + i;
220258 const int64_t a_off =
221- k_idx * ke.K_blk * static_cast < int64_t >( ke.a_dt_size );
259+ static_cast < int64_t >( k_idx * ke.K_blk * ke.a_dt_size );
222260 const int64_t b_off =
223261 ke.need_pack
224- ? k_idx * static_cast <int64_t >(ke.blocked_B_size [ni])
225- : k_idx * cur_n * ke.K_blk *
226- static_cast < int64_t >( ke.b_dt_size );
262+ ? static_cast <int64_t >(k_idx * ke.blocked_B_size [ni])
263+ : static_cast < int64_t >( k_idx * cur_n * ke.K_blk *
264+ ke.b_dt_size );
227265 tbl[i] = {a_off, b_off};
228266 }
229267 }
230268 }
231269
232270 if (ke.K_super_rem > 0 ) {
233- const int64_t k_base = ke.K_super_blocks * ke.K_super_size ;
271+ const size_t k_base = ke.K_super_blocks * ke.K_super_size ;
234272 auto & rem_tbl = (ke.K_super_blocks == 0 ) ? ke.offsets_first_rem [ni]
235273 : ke.offsets_rem [ni];
236274 rem_tbl.resize (ke.K_super_rem );
237- for (int64_t i = 0 ; i < ke.K_super_rem ; ++i) {
238- const int64_t k_idx = k_base + i;
275+ for (size_t i = 0 ; i < ke.K_super_rem ; ++i) {
276+ const size_t k_idx = k_base + i;
239277 const int64_t a_off =
240- k_idx * ke.K_blk * static_cast < int64_t >( ke.a_dt_size );
278+ static_cast < int64_t >( k_idx * ke.K_blk * ke.a_dt_size );
241279 const int64_t b_off =
242280 ke.need_pack
243- ? k_idx * static_cast <int64_t >(ke.blocked_B_size [ni])
244- : k_idx * cur_n * ke.K_blk *
245- static_cast < int64_t >( ke.b_dt_size );
281+ ? static_cast <int64_t >(k_idx * ke.blocked_B_size [ni])
282+ : static_cast < int64_t >( k_idx * cur_n * ke.K_blk *
283+ ke.b_dt_size );
246284 rem_tbl[i] = {a_off, b_off};
247285 }
248286 }
@@ -270,7 +308,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
270308
271309 if (ke.need_pack ) {
272310 size_t total_packed = 0 ;
273- for (int64_t nt = 0 ; nt < ke.N_total_tiles ; ++nt) {
311+ for (size_t nt = 0 ; nt < ke.N_total_tiles ; ++nt) {
274312 const int ni = (nt < ke.N_full_tiles ) ? 0 : 1 ;
275313 pe.B_tile_offset [nt] = total_packed;
276314 if (ke.K_chunks > 0 )
@@ -283,13 +321,13 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
283321 uint8_t * B_packed = pe.B_packed_buf .data ();
284322 if (!B_packed) return ;
285323
286- for (int64_t nt = 0 ; nt < ke.N_total_tiles ; ++nt) {
324+ for (size_t nt = 0 ; nt < ke.N_total_tiles ; ++nt) {
287325 const int ni = (nt < ke.N_full_tiles ) ? 0 : 1 ;
288- const int64_t b_row = (nt < ke.N_full_tiles )
289- ? nt * ke.N_blk
290- : ke.N_full_tiles * ke.N_blk ;
326+ const size_t b_row = (nt < ke.N_full_tiles )
327+ ? nt * ke.N_blk
328+ : ke.N_full_tiles * ke.N_blk ;
291329 const uint8_t * B_in =
292- B_base + b_row * ke.ldb_orig * ke.b_dt_size ;
330+ B_base + b_row * static_cast < size_t >( ke.ldb_orig ) * ke.b_dt_size ;
293331
294332 try {
295333 if (ke.K_chunks > 0 ) {
@@ -320,14 +358,14 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
320358
321359 // Execute one (m, n) tile for a given K-super-block.
322360 const auto execute_tile = [&](size_t m_start, size_t n_start,
323- int64_t k_super, float * temp_C,
361+ size_t k_super, float * temp_C,
324362 uint8_t * scratch) HWY_ATTR {
325- const int64_t m_tile_idx = m_start / ke.M_blk ;
326- const int64_t n_tile_idx = n_start / ke.N_blk ;
363+ const size_t m_tile_idx = m_start / ke.M_blk ;
364+ const size_t n_tile_idx = n_start / ke.N_blk ;
327365 const int mi = (m_tile_idx < ke.M_full_tiles ) ? 0 : 1 ;
328366 const int ni = (n_tile_idx < ke.N_full_tiles ) ? 0 : 1 ;
329- const int64_t cur_m = ke.m_sizes [mi];
330- const int64_t cur_n = ke.n_sizes [ni];
367+ const size_t cur_m = ke.m_sizes [mi];
368+ const size_t cur_n = ke.n_sizes [ni];
331369
332370 const size_t real_m = (m_tile_idx < ke.M_full_tiles )
333371 ? m_tile_idx * ke.M_blk
@@ -336,16 +374,18 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
336374 ? n_tile_idx * ke.N_blk
337375 : ke.N_full_tiles * ke.N_blk ;
338376
339- const uint8_t * A_tile = A_base + real_m * ke.lda * ke.a_dt_size ;
377+ const uint8_t * A_tile =
378+ A_base + real_m * static_cast <size_t >(ke.lda ) * ke.a_dt_size ;
340379 const void * B_tile =
341380 ke.need_pack
342381 ? static_cast <const void *>(B_packed +
343382 pe.B_tile_offset [n_tile_idx])
344- : static_cast <const void *>(B_base +
345- real_n * ke.ldb_orig * ke.b_dt_size );
383+ : static_cast <const void *>(
384+ B_base +
385+ real_n * static_cast <size_t >(ke.ldb_orig ) * ke.b_dt_size );
346386
347387 float * C_tile_ptr = temp_C;
348- const int64_t k_total =
388+ const size_t k_total =
349389 ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0 );
350390
351391 if (k_super < ke.K_super_blocks ) {
@@ -379,7 +419,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
379419 ? static_cast <const void *>(B_packed +
380420 pe.B_ktail_offset [n_tile_idx])
381421 : static_cast <const void *>(
382- B_base + (real_n * ke.ldb_orig +
422+ B_base + (real_n * static_cast < size_t >( ke.ldb_orig ) +
383423 ke.K_chunks * ke.K_blk ) *
384424 ke.b_dt_size );
385425 ke.brg_ktail [mi][ni].execute (A_ktail, const_cast <void *>(B_ktail),
@@ -390,19 +430,18 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
390430 const hn::ScalableTag<float > df;
391431 const auto vscale = hn::Set (df, scale);
392432 const size_t lanes = hn::Lanes (df);
393- for (int64_t m = 0 ; m < cur_m; ++m) {
433+ for (size_t m = 0 ; m < cur_m; ++m) {
394434 TC* C_row = C.Row (real_m + m) + real_n;
395435 const float * t_row = C_tile_ptr + m * cur_n;
396436 const float * add_row = add ? add + real_n : nullptr ;
397- int64_t n = 0 ;
437+ size_t n = 0 ;
398438 if (add_row) {
399- for (; n + static_cast <int64_t >(lanes) <= cur_n;
400- n += static_cast <int64_t >(lanes)) {
439+ for (; n + lanes <= cur_n; n += lanes) {
401440 const auto v = hn::Load (df, t_row + n);
402441 const auto va = hn::Load (df, add_row + n);
403442 const auto result = hn::MulAdd (v, vscale, va);
404443 if constexpr (hwy::IsSame<TC, float >()) {
405- hn::Store (result, df, reinterpret_cast < float *>( C_row) + n);
444+ hn::Store (result, df, HWY_RCAST_ALIGNED ( float *, C_row) + n);
406445 } else {
407446 const hn::Rebind<TC, decltype (df)> dc;
408447 hn::Store (hn::DemoteTo (dc, result), dc, C_row + n);
@@ -413,12 +452,11 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
413452 C_row[n] = hwy::ConvertScalarTo<TC>(val);
414453 }
415454 } else {
416- for (; n + static_cast <int64_t >(lanes) <= cur_n;
417- n += static_cast <int64_t >(lanes)) {
455+ for (; n + lanes <= cur_n; n += lanes) {
418456 const auto v = hn::Load (df, t_row + n);
419457 const auto result = hn::Mul (v, vscale);
420458 if constexpr (hwy::IsSame<TC, float >()) {
421- hn::Store (result, df, reinterpret_cast < float *>( C_row) + n);
459+ hn::Store (result, df, HWY_RCAST_ALIGNED ( float *, C_row) + n);
422460 } else {
423461 const hn::Rebind<TC, decltype (df)> dc;
424462 hn::Store (hn::DemoteTo (dc, result), dc, C_row + n);
@@ -434,9 +472,9 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
434472 };
435473
436474 // Parallel dispatch: K-super outer, N middle, M inner (keeps B in L2).
437- const int64_t k_total_supers =
475+ const size_t k_total_supers =
438476 ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0 );
439- const int64_t k_iters = (k_total_supers > 0 ) ? k_total_supers : 1 ;
477+ const size_t k_iters = (k_total_supers > 0 ) ? k_total_supers : size_t { 1 } ;
440478
441479 const size_t num_threads = ctx.pools .MaxWorkersPerCluster ();
442480 const size_t total_n_tiles = ke.N_total_tiles ;
@@ -466,12 +504,11 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
466504 const size_t total_tc = total_m_tiles * n_tiles_in_range;
467505 float * tc_base = tbufs.EnsureTempC (total_tc);
468506
469- for (int64_t ks = 0 ; ks < k_iters; ++ks) {
507+ for (size_t ks = 0 ; ks < k_iters; ++ks) {
470508 size_t n_idx = 0 ;
471509 for (size_t nt = n_begin; nt < n_end; ++nt) {
472510 const size_t n = nt * ke.N_blk ;
473- for (int64_t mt = 0 ; mt < static_cast <int64_t >(total_m_tiles);
474- ++mt) {
511+ for (size_t mt = 0 ; mt < total_m_tiles; ++mt) {
475512 const size_t m = mt * ke.M_blk ;
476513 float * temp_C =
477514 tc_base + (mt * n_tiles_in_range + n_idx) *
@@ -485,8 +522,14 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
485522
486523 dnnl::ukernel::brgemm::release_hw_context ();
487524 auto & main_bufs = GetBRGeMMThreadBufs ();
488- main_bufs.hw_ctx_set = false ;
489525 main_bufs.hw_ctx_kernel = nullptr ;
490526}
491527
492528#endif // GEMMA_ONEDNN_BRGEMM
529+
530+ // NOLINTNEXTLINE(google-readability-namespace-comments)
531+ } // namespace HWY_NAMESPACE
532+ } // namespace gcpp
533+ HWY_AFTER_NAMESPACE ();
534+
535+ #endif // NOLINT
0 commit comments