@@ -629,41 +629,50 @@ status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
629629
630630format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout (
631631 int n_blk) const {
632+
632633 if (bgmmc.ndims > 3 ) return format_tag::undef;
633- if (this ->is_int8 () || this ->is_f8 ()) switch (n_blk) {
634+
635+ if (is_int8 () || is_f8 ()) {
636+ switch (n_blk) {
634637 case 64 : return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
635638 case 48 : return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
636639 case 32 : return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
637640 case 16 : return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
638641 default : return format_tag::undef;
639642 }
643+ }
644+
645+ const bool is_amx_or_avx2_vnni_2 = is_superset (bgmmc.isa , avx512_core_amx)
646+ || is_superset (bgmmc.isa , avx2_vnni_2);
647+ const bool prefer_amx_or_avx2_vnni_2 = is_f16 () || is_f32_f16 ()
648+ || is_f32_bf16 () || is_f16_with_int_wei ()
649+ || is_f32_with_int_wei ();
640650
641- if (this ->is_bf16 () || this ->is_bf16_with_int_wei ()
642- || ((this ->is_f16 () || this ->is_f32_f16 () || this ->is_f32_bf16 ()
643- || this ->is_f16_with_int_wei ()
644- || this ->is_f32_with_int_wei ())
645- && (is_superset (bgmmc.isa , avx512_core_amx)
646- || is_superset (bgmmc.isa , avx2_vnni_2))))
651+ if ((prefer_amx_or_avx2_vnni_2 && is_amx_or_avx2_vnni_2) || is_bf16 ()
652+ || is_bf16_with_int_wei ()) {
647653 switch (n_blk) {
648654 case 64 : return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
649655 case 48 : return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
650656 case 32 : return bgmmc.ndims == 3 ? aCB16b32c2b : BA16a32b2a;
651657 case 16 : return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
652658 default : return format_tag::undef;
653659 }
660+ }
661+
654662 // Note: bf32 assumes f32 blocking
655- if (this ->is_f32 () || this ->is_bf32 () || this ->is_f16 ()
656- || this ->is_f32_f16 () || this ->is_f32_bf16 ()
657- || this ->is_f16_with_int_wei () || this ->is_tf32 ()
658- || (this ->is_f32_with_int_wei ()
659- && is_superset (bgmmc.isa , avx512_core)))
663+ if (is_f32 () || is_bf32 () || is_f16 () || is_f32_f16 () || is_f32_bf16 ()
664+ || is_f16_with_int_wei () || is_tf32 ()
665+ || (is_f32_with_int_wei ()
666+ && is_superset (bgmmc.isa , avx512_core))) {
660667 switch (n_blk) {
661668 case 64 : return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
662669 case 48 : return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
663670 case 32 : return bgmmc.ndims == 3 ? aCB16b32c : BA16a32b;
664671 case 16 : return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
665672 default : return format_tag::undef;
666673 }
674+ }
675+
667676 return format_tag::undef;
668677}
669678
0 commit comments