Skip to content

Commit da2c1a2

Browse files
committed
vectorize cooperative_tensor load/store to use 4-wide memory ops
Use half4/float4 pointer casts for cooperative_tensor element loads and stores instead of per-element scalar loops. Each 16x16 fragment load or store now emits two 4-wide memory operations (2 rows × 4 cols) instead of an 8-iteration scalar loop.
1 parent 55bc892 commit da2c1a2

1 file changed

Lines changed: 16 additions & 12 deletions

File tree

src/target/codegen_metal.cc

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -662,12 +662,14 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
662662
for (int fc = 0; fc < nfrag_c; fc++) {
663663
int row_off = fr * frag_rows;
664664
int col_off = fc * frag_cols;
665-
os << "for (ushort __i = 0; __i < 8; __i++) { "
666-
<< "ushort __r = __base_row + (__i >> 2) * 8 + " << row_off << "; "
667-
<< "ushort __c = __base_col + (__i & 3) + " << col_off << "; ";
668-
os << var << "[" << idx << " * " << (nfrag_r * nfrag_c * 8) << " + "
669-
<< elem_offset << " + __i] = "
670-
<< "__src[__r * " << stride << " + __c]; } ";
665+
os << "{ "
666+
<< "ushort __r0 = __base_row + " << row_off << "; "
667+
<< "ushort __r1 = __r0 + 8; "
668+
<< "ushort __c0 = __base_col + " << col_off << "; "
669+
<< "*(thread " << dtype << "4*)(&" << var << "[" << idx << " * " << (nfrag_r * nfrag_c * 8) << " + " << elem_offset << "]) = "
670+
<< "*(" << addr_space << " " << dtype << "4*)(&__src[__r0 * " << stride << " + __c0]); "
671+
<< "*(thread " << dtype << "4*)(&" << var << "[" << idx << " * " << (nfrag_r * nfrag_c * 8) << " + " << (elem_offset + 4) << "]) = "
672+
<< "*(" << addr_space << " " << dtype << "4*)(&__src[__r1 * " << stride << " + __c0]); } ";
671673
elem_offset += 8;
672674
}
673675
}
@@ -701,12 +703,14 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
701703
for (int fc = 0; fc < nfrag_c; fc++) {
702704
int row_off = fr * frag_rows;
703705
int col_off = fc * frag_cols;
704-
os << "for (ushort __i = 0; __i < 8; __i++) { "
705-
<< "ushort __r = __base_row + (__i >> 2) * 8 + " << row_off << "; "
706-
<< "ushort __c = __base_col + (__i & 3) + " << col_off << "; "
707-
<< "__dst[__r * " << stride << " + __c] = "
708-
<< var << "[" << idx << " * " << total_elems << " + "
709-
<< elem_offset << " + __i]; } ";
706+
os << "{ "
707+
<< "ushort __r0 = __base_row + " << row_off << "; "
708+
<< "ushort __r1 = __r0 + 8; "
709+
<< "ushort __c0 = __base_col + " << col_off << "; "
710+
<< "*(" << addr_space << " " << dtype << "4*)(&__dst[__r0 * " << stride << " + __c0]) = "
711+
<< "*(thread " << dtype << "4*)(&" << var << "[" << idx << " * " << total_elems << " + " << elem_offset << "]); "
712+
<< "*(" << addr_space << " " << dtype << "4*)(&__dst[__r1 * " << stride << " + __c0]) = "
713+
<< "*(thread " << dtype << "4*)(&" << var << "[" << idx << " * " << total_elems << " + " << (elem_offset + 4) << "]); } ";
710714
elem_offset += 8;
711715
}
712716
}

0 commit comments

Comments
 (0)