@@ -662,12 +662,14 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
662662 for (int fc = 0 ; fc < nfrag_c; fc++) {
663663 int row_off = fr * frag_rows;
664664 int col_off = fc * frag_cols;
665- os << " for (ushort __i = 0; __i < 8; __i++) { "
666- << " ushort __r = __base_row + (__i >> 2) * 8 + " << row_off << " ; "
667- << " ushort __c = __base_col + (__i & 3) + " << col_off << " ; " ;
668- os << var << " [" << idx << " * " << (nfrag_r * nfrag_c * 8 ) << " + "
669- << elem_offset << " + __i] = "
670- << " __src[__r * " << stride << " + __c]; } " ;
665+ os << " { "
666+ << " ushort __r0 = __base_row + " << row_off << " ; "
667+ << " ushort __r1 = __r0 + 8; "
668+ << " ushort __c0 = __base_col + " << col_off << " ; "
669+ << " *(thread " << dtype << " 4*)(&" << var << " [" << idx << " * " << (nfrag_r * nfrag_c * 8 ) << " + " << elem_offset << " ]) = "
670+ << " *(" << addr_space << " " << dtype << " 4*)(&__src[__r0 * " << stride << " + __c0]); "
671+ << " *(thread " << dtype << " 4*)(&" << var << " [" << idx << " * " << (nfrag_r * nfrag_c * 8 ) << " + " << (elem_offset + 4 ) << " ]) = "
672+ << " *(" << addr_space << " " << dtype << " 4*)(&__src[__r1 * " << stride << " + __c0]); } " ;
671673 elem_offset += 8 ;
672674 }
673675 }
@@ -701,12 +703,14 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
701703 for (int fc = 0 ; fc < nfrag_c; fc++) {
702704 int row_off = fr * frag_rows;
703705 int col_off = fc * frag_cols;
704- os << " for (ushort __i = 0; __i < 8; __i++) { "
705- << " ushort __r = __base_row + (__i >> 2) * 8 + " << row_off << " ; "
706- << " ushort __c = __base_col + (__i & 3) + " << col_off << " ; "
707- << " __dst[__r * " << stride << " + __c] = "
708- << var << " [" << idx << " * " << total_elems << " + "
709- << elem_offset << " + __i]; } " ;
706+ os << " { "
707+ << " ushort __r0 = __base_row + " << row_off << " ; "
708+ << " ushort __r1 = __r0 + 8; "
709+ << " ushort __c0 = __base_col + " << col_off << " ; "
710+ << " *(" << addr_space << " " << dtype << " 4*)(&__dst[__r0 * " << stride << " + __c0]) = "
711+ << " *(thread " << dtype << " 4*)(&" << var << " [" << idx << " * " << total_elems << " + " << elem_offset << " ]); "
712+ << " *(" << addr_space << " " << dtype << " 4*)(&__dst[__r1 * " << stride << " + __c0]) = "
713+ << " *(thread " << dtype << " 4*)(&" << var << " [" << idx << " * " << total_elems << " + " << (elem_offset + 4 ) << " ]); } " ;
710714 elem_offset += 8 ;
711715 }
712716 }
0 commit comments