Skip to content

Commit b1c5cb9

Browse files
densamoilovazhai219
authored andcommitted
cpu: x64: matmul: refactor pick_blocked_B_layout function
(cherry picked from commit 08beeba)
1 parent b619675 commit b1c5cb9

1 file changed

Lines changed: 21 additions & 12 deletions

File tree

src/cpu/x64/matmul/brgemm_matmul_utils.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -629,41 +629,50 @@ status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
629629

630630
format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
631631
int n_blk) const {
632+
632633
if (bgmmc.ndims > 3) return format_tag::undef;
633-
if (this->is_int8() || this->is_f8()) switch (n_blk) {
634+
635+
if (is_int8() || is_f8()) {
636+
switch (n_blk) {
634637
case 64: return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
635638
case 48: return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
636639
case 32: return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
637640
case 16: return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
638641
default: return format_tag::undef;
639642
}
643+
}
644+
645+
const bool is_amx_or_avx2_vnni_2 = is_superset(bgmmc.isa, avx512_core_amx)
646+
|| is_superset(bgmmc.isa, avx2_vnni_2);
647+
const bool prefer_amx_or_avx2_vnni_2 = is_f16() || is_f32_f16()
648+
|| is_f32_bf16() || is_f16_with_int_wei()
649+
|| is_f32_with_int_wei();
640650

641-
if (this->is_bf16() || this->is_bf16_with_int_wei()
642-
|| ((this->is_f16() || this->is_f32_f16() || this->is_f32_bf16()
643-
|| this->is_f16_with_int_wei()
644-
|| this->is_f32_with_int_wei())
645-
&& (is_superset(bgmmc.isa, avx512_core_amx)
646-
|| is_superset(bgmmc.isa, avx2_vnni_2))))
651+
if ((prefer_amx_or_avx2_vnni_2 && is_amx_or_avx2_vnni_2) || is_bf16()
652+
|| is_bf16_with_int_wei()) {
647653
switch (n_blk) {
648654
case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
649655
case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
650656
case 32: return bgmmc.ndims == 3 ? aCB16b32c2b : BA16a32b2a;
651657
case 16: return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
652658
default: return format_tag::undef;
653659
}
660+
}
661+
654662
// Note: bf32 assumes f32 blocking
655-
if (this->is_f32() || this->is_bf32() || this->is_f16()
656-
|| this->is_f32_f16() || this->is_f32_bf16()
657-
|| this->is_f16_with_int_wei() || this->is_tf32()
658-
|| (this->is_f32_with_int_wei()
659-
&& is_superset(bgmmc.isa, avx512_core)))
663+
if (is_f32() || is_bf32() || is_f16() || is_f32_f16() || is_f32_bf16()
664+
|| is_f16_with_int_wei() || is_tf32()
665+
|| (is_f32_with_int_wei()
666+
&& is_superset(bgmmc.isa, avx512_core))) {
660667
switch (n_blk) {
661668
case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
662669
case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
663670
case 32: return bgmmc.ndims == 3 ? aCB16b32c : BA16a32b;
664671
case 16: return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
665672
default: return format_tag::undef;
666673
}
674+
}
675+
667676
return format_tag::undef;
668677
}
669678

0 commit comments

Comments
 (0)