Skip to content

Commit f3a75ca

Browse files
committed
fixed dtypes and syntax divergence from codebase
1 parent 629b569 commit f3a75ca

3 files changed

Lines changed: 162 additions & 121 deletions

File tree

ops/brgemm-inl.h

Lines changed: 115 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,36 @@
1313
// See the License for the specific language governing permissions and
1414
// limitations under the License.
1515

16-
// BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE.
16+
// BRGeMM dispatch for BF16 MatMul on Intel AMX/AVX-512.
17+
18+
#include <stddef.h>
19+
#include <stdint.h>
20+
21+
#include <algorithm>
22+
#include <utility>
23+
#include <vector>
24+
25+
#include "ops/brgemm.h"
26+
#include "ops/matmul.h"
27+
#include "util/mat.h"
28+
#include "util/threading_context.h"
29+
#include "util/zones.h"
30+
#include "hwy/base.h"
31+
32+
// Include guard for (potentially) SIMD code.
33+
#if defined(THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE) == defined(HWY_TARGET_TOGGLE)
34+
#ifdef THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE
35+
#undef THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE
36+
#else
37+
#define THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE
38+
#endif
39+
40+
#include "hwy/highway.h"
41+
42+
HWY_BEFORE_NAMESPACE();
43+
namespace gcpp {
44+
namespace HWY_NAMESPACE {
45+
namespace hn = hwy::HWY_NAMESPACE;
1746

1847
#if GEMMA_ONEDNN_BRGEMM
1948

@@ -55,8 +84,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
5584

5685
ke.K_blk = cfg.K_blk;
5786
ke.N_blk = cfg.N_blk;
58-
ke.M_blk =
59-
static_cast<int64_t>(std::min(static_cast<size_t>(cfg.M_blk), M));
87+
ke.M_blk = std::min(cfg.M_blk, M);
6088

6189
ke.M_tail = M % ke.M_blk;
6290
ke.N_tail = N % ke.N_blk;
@@ -97,10 +125,13 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
97125
ke.m_sizes[1] = ke.M_tail ? ke.M_tail : ke.M_blk;
98126
ke.n_sizes[0] = ke.N_blk;
99127
ke.n_sizes[1] = ke.N_tail ? ke.N_tail : ke.N_blk;
100-
const int64_t ldb_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk};
101-
const int64_t ldc_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk};
128+
const int64_t ldb_for[2] = {static_cast<int64_t>(ke.N_blk),
129+
static_cast<int64_t>(ke.N_tail ? ke.N_tail : ke.N_blk)};
130+
const int64_t ldc_for[2] = {static_cast<int64_t>(ke.N_blk),
131+
static_cast<int64_t>(ke.N_tail ? ke.N_tail : ke.N_blk)};
102132

103-
// Create brgemm kernels for each (M-tile, N-tile) variant.
133+
// Create brgemm kernels for full/tail M and N tile sizes.
134+
// mi=0 is the full M tile, mi=1 is the M-tail; likewise for ni and N.
104135
size_t max_sp = 0;
105136
for (int mi = 0; mi < 2; ++mi) {
106137
for (int ni = 0; ni < 2; ++ni) {
@@ -109,22 +140,25 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
109140
if (mi == 0 && ke.M_full_tiles == 0) continue;
110141
if (ni == 0 && ke.N_full_tiles == 0) continue;
111142

112-
const int64_t ms = ke.m_sizes[mi];
113-
const int64_t ns = ke.n_sizes[ni];
143+
const int64_t ms = static_cast<int64_t>(ke.m_sizes[mi]);
144+
const int64_t ns = static_cast<int64_t>(ke.n_sizes[ni]);
114145

115146
if (ke.K_chunks > 0) {
116-
if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, ke.K_blk,
117-
ke.K_super_size, ke.lda, ldb_for[ni], ldc_for[ni],
118-
a_dt, b_dt, c_dt, false)) {
147+
if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns,
148+
static_cast<int64_t>(ke.K_blk),
149+
static_cast<int64_t>(ke.K_super_size), ke.lda,
150+
ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt,
151+
false)) {
119152
return;
120153
}
121154
max_sp = std::max(max_sp,
122155
ke.brg_first_all[mi][ni].get_scratchpad_size());
123156
}
124157
if (ke.K_super_blocks > 1) {
125-
if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, ke.K_blk,
126-
ke.batch_full, ke.lda, ldb_for[ni], ldc_for[ni],
127-
a_dt, b_dt, c_dt, true)) {
158+
if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns,
159+
static_cast<int64_t>(ke.K_blk),
160+
static_cast<int64_t>(ke.batch_full), ke.lda,
161+
ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, true)) {
128162
return;
129163
}
130164
max_sp =
@@ -134,7 +168,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
134168
const bool rem_is_first = (ke.K_super_blocks == 0);
135169
auto& target = rem_is_first ? ke.brg_first_rem[mi][ni]
136170
: ke.brg_rem[mi][ni];
137-
if (!MakeBrgemm(target, ms, ns, ke.K_blk, ke.batch_rem, ke.lda,
171+
if (!MakeBrgemm(target, ms, ns, static_cast<int64_t>(ke.K_blk),
172+
static_cast<int64_t>(ke.batch_rem), ke.lda,
138173
ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt,
139174
!rem_is_first)) {
140175
return;
@@ -143,7 +178,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
143178
}
144179
if (ke.K_tail > 0) {
145180
const bool add_c = (ke.K_chunks > 0);
146-
if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, ke.K_tail, 1, ke.lda,
181+
if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns,
182+
static_cast<int64_t>(ke.K_tail), 1, ke.lda,
147183
ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt,
148184
add_c)) {
149185
return;
@@ -161,28 +197,30 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
161197
if (ni == 1 && ke.N_tail == 0) continue;
162198
if (ni == 0 && ke.N_full_tiles == 0) continue;
163199

164-
const int64_t ns = ke.n_sizes[ni];
200+
const int64_t ns = static_cast<int64_t>(ke.n_sizes[ni]);
165201
if (ke.K_chunks > 0) {
166-
const int64_t K_full = ke.K_chunks * ke.K_blk;
202+
const int64_t K_full =
203+
static_cast<int64_t>(ke.K_chunks * ke.K_blk);
167204
try {
168205
ke.pack_B[ni] = transform(K_full, ns, pack_type::trans,
169206
ke.ldb_orig, ldb_for[ni], b_dt, b_dt);
170207
if (!ke.pack_B[ni]) return;
171208
ke.pack_B[ni].generate();
172-
ke.blocked_B_size[ni] = ldb_for[ni] * ke.K_blk * ke.b_dt_size;
209+
ke.blocked_B_size[ni] = static_cast<size_t>(ldb_for[ni]) *
210+
ke.K_blk * ke.b_dt_size;
173211
} catch (...) {
174212
return;
175213
}
176214
}
177215
if (ke.K_tail > 0) {
178216
try {
179217
ke.pack_B_ktail[ni] = transform(
180-
ke.K_tail, ns, pack_type::trans, ke.ldb_orig, ldb_for[ni],
181-
b_dt, b_dt);
218+
static_cast<int64_t>(ke.K_tail), ns, pack_type::trans,
219+
ke.ldb_orig, ldb_for[ni], b_dt, b_dt);
182220
if (!ke.pack_B_ktail[ni]) return;
183221
ke.pack_B_ktail[ni].generate();
184222
ke.blocked_B_ktail_size[ni] =
185-
ldb_for[ni] * ke.K_tail * ke.b_dt_size;
223+
static_cast<size_t>(ldb_for[ni]) * ke.K_tail * ke.b_dt_size;
186224
} catch (...) {
187225
return;
188226
}
@@ -194,55 +232,55 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
194232
for (int ni = 0; ni < 2; ++ni) {
195233
if (ni == 1 && ke.N_tail == 0) continue;
196234
if (ni == 0 && ke.N_full_tiles == 0) continue;
197-
const int64_t cur_n = ke.n_sizes[ni];
235+
const size_t cur_n = ke.n_sizes[ni];
198236

199237
if (ke.K_chunks > 0) {
200238
ke.offsets_first_all[ni].resize(ke.K_super_size);
201-
for (int64_t i = 0; i < ke.K_super_size; ++i) {
239+
for (size_t i = 0; i < ke.K_super_size; ++i) {
202240
const int64_t a_off =
203-
i * ke.K_blk * static_cast<int64_t>(ke.a_dt_size);
241+
static_cast<int64_t>(i * ke.K_blk * ke.a_dt_size);
204242
const int64_t b_off =
205243
ke.need_pack
206-
? i * static_cast<int64_t>(ke.blocked_B_size[ni])
207-
: i * cur_n * ke.K_blk * static_cast<int64_t>(ke.b_dt_size);
244+
? static_cast<int64_t>(i * ke.blocked_B_size[ni])
245+
: static_cast<int64_t>(i * cur_n * ke.K_blk * ke.b_dt_size);
208246
ke.offsets_first_all[ni][i] = {a_off, b_off};
209247
}
210248
}
211249

212250
if (ke.K_super_blocks > 1) {
213251
ke.offsets_full[ni].resize(ke.K_super_blocks - 1);
214-
for (int64_t ks = 1; ks < ke.K_super_blocks; ++ks) {
252+
for (size_t ks = 1; ks < ke.K_super_blocks; ++ks) {
215253
auto& tbl = ke.offsets_full[ni][ks - 1];
216254
tbl.resize(ke.batch_full);
217-
const int64_t k_start = ks * ke.K_super_size;
218-
for (int64_t i = 0; i < ke.batch_full; ++i) {
219-
const int64_t k_idx = k_start + i;
255+
const size_t k_start = ks * ke.K_super_size;
256+
for (size_t i = 0; i < ke.batch_full; ++i) {
257+
const size_t k_idx = k_start + i;
220258
const int64_t a_off =
221-
k_idx * ke.K_blk * static_cast<int64_t>(ke.a_dt_size);
259+
static_cast<int64_t>(k_idx * ke.K_blk * ke.a_dt_size);
222260
const int64_t b_off =
223261
ke.need_pack
224-
? k_idx * static_cast<int64_t>(ke.blocked_B_size[ni])
225-
: k_idx * cur_n * ke.K_blk *
226-
static_cast<int64_t>(ke.b_dt_size);
262+
? static_cast<int64_t>(k_idx * ke.blocked_B_size[ni])
263+
: static_cast<int64_t>(k_idx * cur_n * ke.K_blk *
264+
ke.b_dt_size);
227265
tbl[i] = {a_off, b_off};
228266
}
229267
}
230268
}
231269

232270
if (ke.K_super_rem > 0) {
233-
const int64_t k_base = ke.K_super_blocks * ke.K_super_size;
271+
const size_t k_base = ke.K_super_blocks * ke.K_super_size;
234272
auto& rem_tbl = (ke.K_super_blocks == 0) ? ke.offsets_first_rem[ni]
235273
: ke.offsets_rem[ni];
236274
rem_tbl.resize(ke.K_super_rem);
237-
for (int64_t i = 0; i < ke.K_super_rem; ++i) {
238-
const int64_t k_idx = k_base + i;
275+
for (size_t i = 0; i < ke.K_super_rem; ++i) {
276+
const size_t k_idx = k_base + i;
239277
const int64_t a_off =
240-
k_idx * ke.K_blk * static_cast<int64_t>(ke.a_dt_size);
278+
static_cast<int64_t>(k_idx * ke.K_blk * ke.a_dt_size);
241279
const int64_t b_off =
242280
ke.need_pack
243-
? k_idx * static_cast<int64_t>(ke.blocked_B_size[ni])
244-
: k_idx * cur_n * ke.K_blk *
245-
static_cast<int64_t>(ke.b_dt_size);
281+
? static_cast<int64_t>(k_idx * ke.blocked_B_size[ni])
282+
: static_cast<int64_t>(k_idx * cur_n * ke.K_blk *
283+
ke.b_dt_size);
246284
rem_tbl[i] = {a_off, b_off};
247285
}
248286
}
@@ -270,7 +308,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
270308

271309
if (ke.need_pack) {
272310
size_t total_packed = 0;
273-
for (int64_t nt = 0; nt < ke.N_total_tiles; ++nt) {
311+
for (size_t nt = 0; nt < ke.N_total_tiles; ++nt) {
274312
const int ni = (nt < ke.N_full_tiles) ? 0 : 1;
275313
pe.B_tile_offset[nt] = total_packed;
276314
if (ke.K_chunks > 0)
@@ -283,13 +321,13 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
283321
uint8_t* B_packed = pe.B_packed_buf.data();
284322
if (!B_packed) return;
285323

286-
for (int64_t nt = 0; nt < ke.N_total_tiles; ++nt) {
324+
for (size_t nt = 0; nt < ke.N_total_tiles; ++nt) {
287325
const int ni = (nt < ke.N_full_tiles) ? 0 : 1;
288-
const int64_t b_row = (nt < ke.N_full_tiles)
289-
? nt * ke.N_blk
290-
: ke.N_full_tiles * ke.N_blk;
326+
const size_t b_row = (nt < ke.N_full_tiles)
327+
? nt * ke.N_blk
328+
: ke.N_full_tiles * ke.N_blk;
291329
const uint8_t* B_in =
292-
B_base + b_row * ke.ldb_orig * ke.b_dt_size;
330+
B_base + b_row * static_cast<size_t>(ke.ldb_orig) * ke.b_dt_size;
293331

294332
try {
295333
if (ke.K_chunks > 0) {
@@ -320,14 +358,14 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
320358

321359
// Execute one (m, n) tile for a given K-super-block.
322360
const auto execute_tile = [&](size_t m_start, size_t n_start,
323-
int64_t k_super, float* temp_C,
361+
size_t k_super, float* temp_C,
324362
uint8_t* scratch) HWY_ATTR {
325-
const int64_t m_tile_idx = m_start / ke.M_blk;
326-
const int64_t n_tile_idx = n_start / ke.N_blk;
363+
const size_t m_tile_idx = m_start / ke.M_blk;
364+
const size_t n_tile_idx = n_start / ke.N_blk;
327365
const int mi = (m_tile_idx < ke.M_full_tiles) ? 0 : 1;
328366
const int ni = (n_tile_idx < ke.N_full_tiles) ? 0 : 1;
329-
const int64_t cur_m = ke.m_sizes[mi];
330-
const int64_t cur_n = ke.n_sizes[ni];
367+
const size_t cur_m = ke.m_sizes[mi];
368+
const size_t cur_n = ke.n_sizes[ni];
331369

332370
const size_t real_m = (m_tile_idx < ke.M_full_tiles)
333371
? m_tile_idx * ke.M_blk
@@ -336,16 +374,18 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
336374
? n_tile_idx * ke.N_blk
337375
: ke.N_full_tiles * ke.N_blk;
338376

339-
const uint8_t* A_tile = A_base + real_m * ke.lda * ke.a_dt_size;
377+
const uint8_t* A_tile =
378+
A_base + real_m * static_cast<size_t>(ke.lda) * ke.a_dt_size;
340379
const void* B_tile =
341380
ke.need_pack
342381
? static_cast<const void*>(B_packed +
343382
pe.B_tile_offset[n_tile_idx])
344-
: static_cast<const void*>(B_base +
345-
real_n * ke.ldb_orig * ke.b_dt_size);
383+
: static_cast<const void*>(
384+
B_base +
385+
real_n * static_cast<size_t>(ke.ldb_orig) * ke.b_dt_size);
346386

347387
float* C_tile_ptr = temp_C;
348-
const int64_t k_total =
388+
const size_t k_total =
349389
ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0);
350390

351391
if (k_super < ke.K_super_blocks) {
@@ -379,7 +419,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
379419
? static_cast<const void*>(B_packed +
380420
pe.B_ktail_offset[n_tile_idx])
381421
: static_cast<const void*>(
382-
B_base + (real_n * ke.ldb_orig +
422+
B_base + (real_n * static_cast<size_t>(ke.ldb_orig) +
383423
ke.K_chunks * ke.K_blk) *
384424
ke.b_dt_size);
385425
ke.brg_ktail[mi][ni].execute(A_ktail, const_cast<void*>(B_ktail),
@@ -390,19 +430,18 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
390430
const hn::ScalableTag<float> df;
391431
const auto vscale = hn::Set(df, scale);
392432
const size_t lanes = hn::Lanes(df);
393-
for (int64_t m = 0; m < cur_m; ++m) {
433+
for (size_t m = 0; m < cur_m; ++m) {
394434
TC* C_row = C.Row(real_m + m) + real_n;
395435
const float* t_row = C_tile_ptr + m * cur_n;
396436
const float* add_row = add ? add + real_n : nullptr;
397-
int64_t n = 0;
437+
size_t n = 0;
398438
if (add_row) {
399-
for (; n + static_cast<int64_t>(lanes) <= cur_n;
400-
n += static_cast<int64_t>(lanes)) {
439+
for (; n + lanes <= cur_n; n += lanes) {
401440
const auto v = hn::Load(df, t_row + n);
402441
const auto va = hn::Load(df, add_row + n);
403442
const auto result = hn::MulAdd(v, vscale, va);
404443
if constexpr (hwy::IsSame<TC, float>()) {
405-
hn::Store(result, df, reinterpret_cast<float*>(C_row) + n);
444+
hn::Store(result, df, HWY_RCAST_ALIGNED(float*, C_row) + n);
406445
} else {
407446
const hn::Rebind<TC, decltype(df)> dc;
408447
hn::Store(hn::DemoteTo(dc, result), dc, C_row + n);
@@ -413,12 +452,11 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
413452
C_row[n] = hwy::ConvertScalarTo<TC>(val);
414453
}
415454
} else {
416-
for (; n + static_cast<int64_t>(lanes) <= cur_n;
417-
n += static_cast<int64_t>(lanes)) {
455+
for (; n + lanes <= cur_n; n += lanes) {
418456
const auto v = hn::Load(df, t_row + n);
419457
const auto result = hn::Mul(v, vscale);
420458
if constexpr (hwy::IsSame<TC, float>()) {
421-
hn::Store(result, df, reinterpret_cast<float*>(C_row) + n);
459+
hn::Store(result, df, HWY_RCAST_ALIGNED(float*, C_row) + n);
422460
} else {
423461
const hn::Rebind<TC, decltype(df)> dc;
424462
hn::Store(hn::DemoteTo(dc, result), dc, C_row + n);
@@ -434,9 +472,9 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
434472
};
435473

436474
// Parallel dispatch: K-super outer, N middle, M inner (keeps B in L2).
437-
const int64_t k_total_supers =
475+
const size_t k_total_supers =
438476
ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0);
439-
const int64_t k_iters = (k_total_supers > 0) ? k_total_supers : 1;
477+
const size_t k_iters = (k_total_supers > 0) ? k_total_supers : size_t{1};
440478

441479
const size_t num_threads = ctx.pools.MaxWorkersPerCluster();
442480
const size_t total_n_tiles = ke.N_total_tiles;
@@ -466,12 +504,11 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
466504
const size_t total_tc = total_m_tiles * n_tiles_in_range;
467505
float* tc_base = tbufs.EnsureTempC(total_tc);
468506

469-
for (int64_t ks = 0; ks < k_iters; ++ks) {
507+
for (size_t ks = 0; ks < k_iters; ++ks) {
470508
size_t n_idx = 0;
471509
for (size_t nt = n_begin; nt < n_end; ++nt) {
472510
const size_t n = nt * ke.N_blk;
473-
for (int64_t mt = 0; mt < static_cast<int64_t>(total_m_tiles);
474-
++mt) {
511+
for (size_t mt = 0; mt < total_m_tiles; ++mt) {
475512
const size_t m = mt * ke.M_blk;
476513
float* temp_C =
477514
tc_base + (mt * n_tiles_in_range + n_idx) *
@@ -485,8 +522,14 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
485522

486523
dnnl::ukernel::brgemm::release_hw_context();
487524
auto& main_bufs = GetBRGeMMThreadBufs();
488-
main_bufs.hw_ctx_set = false;
489525
main_bufs.hw_ctx_kernel = nullptr;
490526
}
491527

492528
#endif // GEMMA_ONEDNN_BRGEMM
529+
530+
// NOLINTNEXTLINE(google-readability-namespace-comments)
531+
} // namespace HWY_NAMESPACE
532+
} // namespace gcpp
533+
HWY_AFTER_NAMESPACE();
534+
535+
#endif // NOLINT

0 commit comments

Comments
 (0)