@@ -21,8 +21,6 @@ namespace mlir {
2121namespace triton {
2222namespace gpu {
2323
24- namespace {
25-
2624// Get the highest version supported for the hardware and the dot.
2725static int getMMAVersionSafe (int computeCapability, DotOp op) {
2826 // List supported mma version in order of preference.
@@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
4745 return 0 ;
4846}
4947
50- SmallVector<unsigned > warpsPerTileV2 (DotOp dotOp, const ArrayRef< int64_t > shape,
51- int numWarps) {
48+ SmallVector<unsigned >
49+ warpsPerTileV2 (Operation *dotOp, const ArrayRef< int64_t > shape, int numWarps) {
5250 auto rank = shape.size ();
5351 // Early exit for batched matmul
5452 if (rank == 3 )
@@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
112110}
113111
114112SmallVector<unsigned , 2 >
115- warpsPerTileV3 (DotOp dotOp, const ArrayRef<int64_t > shape, int numWarps,
113+ warpsPerTileV3 (Operation * dotOp, const ArrayRef<int64_t > shape, int numWarps,
116114 const SmallVector<unsigned , 3 > &instrShape) {
117115 SetVector<Operation *> slices;
118- mlir::getForwardSlice (dotOp. getResult (), &slices);
116+ mlir::getForwardSlice (dotOp-> getResult (0 ), &slices);
119117 // Contains a chained dot. We prefer to assign warps to one axis
120118 // to facilitate use cases like flash attention, allowing reductions within
121119 // the same warp.
@@ -181,6 +179,21 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
181179 auto newType = MemDescType::get (argType.getShape (), argType.getElementType (),
182180 newLayout, SharedMemorySpace);
183181 rewriter.setInsertionPointAfterValue (arg);
182+
183+ // LocalAllocOp lowering doesn't support going from DotOperandEncoding
184+ // to SharedEncoding.
185+ if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
186+ argType.getEncoding ())) {
187+ // Create a layout conversion from DotOperandEncoding to BlockedEncoding
188+ // then pass it to the LocalAllocOp.
189+ auto newArgType = RankedTensorType::get (
190+ argType.getShape (), argType.getElementType (), dotOpEnc.getParent ());
191+ auto dotOperandToBlockedCvt =
192+ rewriter.create <ConvertLayoutOp>(arg.getLoc (), newArgType, arg);
193+ return rewriter.create <LocalAllocOp>(arg.getLoc (), newType,
194+ dotOperandToBlockedCvt);
195+ }
196+
184197 return rewriter.create <LocalAllocOp>(arg.getLoc (), newType, arg);
185198}
186199
@@ -204,7 +217,7 @@ getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
204217}
205218
206219SmallVector<unsigned , 3 >
207- getWarpsPerTile (DotOp dotOp, const ArrayRef<int64_t > shape, int version,
220+ getWarpsPerTile (Operation* dotOp, const ArrayRef<int64_t > shape, int version,
208221 int numWarps, const SmallVector<unsigned , 3 > &instrShape) {
209222 switch (version) {
210223 case 2 :
@@ -218,6 +231,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
218231}
219232
220233static bool bwdFilter (Operation *op) {
234+ // Assigning a dot operand layout to predicates is not currently supported
235+ // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
236+ // condition limits visibility of the original bit-width so that predicates
237+ // are not considered; hence, kwidth can never be 32.
238+ if (isa<arith::UIToFPOp>(op)) {
239+ Type srcType = getElementTypeOrSelf (op->getOperand (0 ));
240+ if (srcType.isInteger (1 ))
241+ return false ;
242+ }
243+
221244 return op->getNumOperands () == 1 &&
222245 (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
223246 isPureUnaryInlineAsm (op) ||
@@ -237,7 +260,7 @@ static bool bwdFilter(Operation *op) {
237260// result, kwidth can be the bitwidth of the lower precision primitive.
238261// Conversely, in the downcasting scenario, no reordering is performed,
239262// making it directly use the lower precision primitive.
240- static int computeOrigBitWidth (Value x) {
263+ int computeOrigBitWidth (Value x) {
241264 int finalBitWidth = getElementTypeOrSelf (x).getIntOrFloatBitWidth ();
242265 int origBitWidth = finalBitWidth;
243266 SetVector<Operation *> slice;
@@ -257,6 +280,9 @@ static int computeOrigBitWidth(Value x) {
257280 }
258281 return origBitWidth;
259282}
283+ // Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
284+ // extension.
285+ namespace {
260286
261287class BlockedToMMA : public mlir ::OpRewritePattern<DotOp> {
262288 int computeCapability;
@@ -1147,6 +1173,11 @@ class TritonGPUAccelerateMatmulPass
11471173 }
11481174};
11491175
1176+ Value getSharedMemMMAOperand (Value v, mlir::PatternRewriter &rewriter,
1177+ int opIdx, bool allowTranspose) {
1178+ return getSharedMemoryMMAOperand (v, rewriter, opIdx, allowTranspose);
1179+ }
1180+
11501181} // namespace gpu
11511182} // namespace triton
11521183} // namespace mlir
0 commit comments