Skip to content

Commit 02dad3d

Browse files
Varchocopybara-github
authored andcommitted
[SDY][PadForDivisibility] Add DotGeneralOpPattern to enforce kZero padding
- Implement DotGeneralOpPattern to enforce kZero padding on operands of stablehlo.dot_general operations. - Exclude DotGeneralOp from GenericOpPattern to avoid ambiguity. - Add MLIR test case verifying correct padding propagation for dot_general. Also adds a general util... - Create ensurePadding helper to insert select operations when cached padding differs from required padding. PiperOrigin-RevId: 919119485
1 parent c3c8777 commit 02dad3d

2 files changed

Lines changed: 226 additions & 5 deletions

File tree

shardy/dialect/sdy/transforms/export/pad_for_divisibility.cc

Lines changed: 126 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ class PaddedTypeConverter : public TypeConverter {
111111
// Known padding value kinds for generated padding values.
112112
enum class PaddingValueKind { kZero, kOne };
113113

114+
// Returns true if the operation has custom padding handling implemented in
115+
// this file and should be excluded from GenericOpPattern.
116+
bool hasCustomPadHandling(Operation* op) {
117+
return isa<stablehlo::SliceOp, stablehlo::DotGeneralOp>(op);
118+
}
119+
114120
class PaddingCache {
115121
public:
116122
// Registers the padding kind for a value. The conversion pattern
@@ -186,6 +192,75 @@ Value createPaddedValue(RankedTensorType paddedType, Value value,
186192
return padOp;
187193
}
188194

195+
// Returns 'inputVal' if the value is not padded or the value already has
196+
// 'requiredKind' as PaddingValueKind. Otherwise, uses compare-and-select to
197+
// produce a new padded value from inputVal with the requiredKind padding and
198+
// returns the new value.
199+
//
200+
// We ensure all dimensions that require padding are padded with requireKind
201+
// unless dimsToEnforce is provided, in which case only the specified
202+
// dimensions are padded.
203+
Value ensurePadding(
204+
Value inputVal, RankedTensorType origType, PaddingValueKind requiredKind,
205+
OpBuilder& b, Location loc, PaddingCache& cache,
206+
std::optional<ArrayRef<int64_t>> dimsToEnforce = std::nullopt) {
207+
// Return early if no padding is applied or the cached padding already
208+
// matches.
209+
auto paddedType = cast<RankedTensorType>(inputVal.getType());
210+
if (origType == paddedType) {
211+
return inputVal;
212+
}
213+
std::optional<PaddingValueKind> currentKind = cache.getPadding(inputVal);
214+
if (currentKind && *currentKind == requiredKind) {
215+
return inputVal;
216+
}
217+
218+
// Build a mask that is `true` for the original (unpadded) data region.
219+
// An element is in the original region if its index along each padded
220+
// dimension is less than the original unpadded size (index < original_size).
221+
Value validDataMask;
222+
for (auto [dim, origSize] : llvm::enumerate(origType.getShape())) {
223+
if (origSize == paddedType.getDimSize(dim) ||
224+
(dimsToEnforce && !llvm::is_contained(*dimsToEnforce, dim))) {
225+
continue;
226+
}
227+
auto iotaType =
228+
RankedTensorType::get(paddedType.getShape(), b.getI32Type());
229+
Value iota = stablehlo::IotaOp::create(b, loc, iotaType, dim);
230+
Value limit = stablehlo::ConstantOp::create(
231+
b, loc,
232+
DenseElementsAttr::get(RankedTensorType::get({}, b.getI32Type()),
233+
b.getI32IntegerAttr(origSize)));
234+
Value broadcastLimit = stablehlo::BroadcastInDimOp::create(
235+
b, loc, iotaType, limit, b.getDenseI64ArrayAttr({}));
236+
Value mask = stablehlo::CompareOp::create(
237+
b, loc, iota, broadcastLimit, stablehlo::ComparisonDirection::LT);
238+
validDataMask = validDataMask
239+
? stablehlo::AndOp::create(b, loc, validDataMask, mask)
240+
: mask;
241+
}
242+
243+
if (!validDataMask) {
244+
return inputVal;
245+
}
246+
247+
// Create the constant with the new padding value and broadcast it to the
248+
// same shape as 'inputVal'.
249+
Value newPaddingScalar =
250+
createConstant(b, loc, paddedType.getElementType(), requiredKind);
251+
Value newPaddingValue = stablehlo::BroadcastInDimOp::create(
252+
b, loc, paddedType, newPaddingScalar, b.getDenseI64ArrayAttr({}));
253+
254+
// Keep the original data from 'inputVal' (where mask is true), and replace
255+
// the padded region with 'newPaddingValue' (where mask is false).
256+
Value select = stablehlo::SelectOp::create(b, loc, validDataMask, inputVal,
257+
newPaddingValue);
258+
if (!dimsToEnforce) {
259+
cache.setPadding(select, requiredKind);
260+
}
261+
return select;
262+
}
263+
189264
// Converts op to its local version by replacing its operands with the already
190265
// converted operands.
191266
LogicalResult padGenericOp(Operation* op, ValueRange operands,
@@ -263,7 +338,7 @@ class GenericOpPattern : public ConversionPattern {
263338
Dialect* dialect = op->getDialect();
264339
if ((dialect && dialect->getNamespace() != "stablehlo" &&
265340
!isa<sdy::ReturnOp>(op)) ||
266-
isa<stablehlo::SliceOp>(op)) {
341+
hasCustomPadHandling(op)) {
267342
return failure();
268343
}
269344
return padGenericOp(op, operands, rewriter,
@@ -427,15 +502,17 @@ class AllSliceOpPattern : public OpConversionPattern<sdy::AllSliceOp> {
427502
return padGenericOp(op, adaptor.getOperands(), rewriter, converter);
428503
}
429504

430-
Value padOp = createPaddedValue(cast<RankedTensorType>(paddedInputType),
431-
input, PaddingValueKind::kZero, symbolTable,
432-
rewriter, cache);
505+
PaddingValueKind paddingKind = PaddingValueKind::kZero;
506+
Value padOp =
507+
createPaddedValue(cast<RankedTensorType>(paddedInputType), input,
508+
paddingKind, symbolTable, rewriter, cache);
433509
OperationState state(op->getLoc(), op->getName());
434510
state.addOperands({padOp});
435511
state.addTypes(
436512
{getPaddedType(op.getResult().getType(), outSharding, symbolTable)});
437513
state.addAttributes(op->getAttrs());
438514
Operation* newOp = rewriter.create(state);
515+
cache.setPadding(newOp->getResult(0), paddingKind);
439516

440517
rewriter.replaceOp(op, newOp->getResults());
441518
return success();
@@ -489,8 +566,51 @@ class StablehloSliceOpPattern : public OpConversionPattern<stablehlo::SliceOp> {
489566
}
490567
};
491568

569+
class StablehloDotGeneralOpPattern
570+
: public OpConversionPattern<stablehlo::DotGeneralOp> {
571+
public:
572+
StablehloDotGeneralOpPattern(TypeConverter& converter, MLIRContext* ctx,
573+
PaddingCache& cache)
574+
: OpConversionPattern(converter, ctx), cache(cache) {}
575+
576+
LogicalResult matchAndRewrite(
577+
stablehlo::DotGeneralOp op, OpAdaptor adaptor,
578+
ConversionPatternRewriter& rewriter) const override {
579+
auto* converter =
580+
static_cast<const PaddedTypeConverter*>(getTypeConverter());
581+
582+
Location loc = op.getLoc();
583+
stablehlo::DotDimensionNumbersAttr dimNums = op.getDotDimensionNumbers();
584+
585+
Value lhs = adaptor.getOperands()[0];
586+
auto lhsOrigType = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
587+
if (!lhsOrigType) {
588+
return failure();
589+
}
590+
Value paddedLhs =
591+
ensurePadding(lhs, lhsOrigType, PaddingValueKind::kZero, rewriter, loc,
592+
cache, dimNums.getLhsContractingDimensions());
593+
594+
Value rhs = adaptor.getOperands()[1];
595+
auto rhsOrigType = dyn_cast<RankedTensorType>(op->getOperand(1).getType());
596+
if (!rhsOrigType) {
597+
return failure();
598+
}
599+
Value paddedRhs =
600+
ensurePadding(rhs, rhsOrigType, PaddingValueKind::kZero, rewriter, loc,
601+
cache, dimNums.getRhsContractingDimensions());
602+
603+
return padGenericOp(op, {paddedLhs, paddedRhs}, rewriter, converter);
604+
}
605+
606+
private:
607+
PaddingCache& cache;
608+
};
609+
492610
struct PadForDivisibilityPass
493611
: public impl::PadForDivisibilityPassBase<PadForDivisibilityPass> {
612+
using PadForDivisibilityPassBase::PadForDivisibilityPassBase;
613+
494614
protected:
495615
void runOnOperation() final {
496616
// FuncOpPattern enforces that function inputs and outputs are always fully
@@ -508,7 +628,8 @@ struct PadForDivisibilityPass
508628
AllGatherOpPattern>(typeConverter, &getContext());
509629
// Sharing the padding cache reference across pattern instances is safe from
510630
// data races because pattern application within a function is sequential.
511-
patterns.add<AllSliceOpPattern>(typeConverter, &getContext(), paddingCache);
631+
patterns.add<AllSliceOpPattern, StablehloDotGeneralOpPattern>(
632+
typeConverter, &getContext(), paddingCache);
512633
ConversionTarget target(getContext());
513634

514635
auto isLegalType = [&](Type type, TensorShardingAttr sharding) {
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// RUN: sdy_opt %s -sdy-pad-for-divisibility | FileCheck %s
2+
3+
sdy.mesh @mesh_4_2 = <["x"=4, "y"=2]>
4+
5+
// CHECK-LABEL: func @padded_contracting_dims_reuse
6+
func.func @padded_contracting_dims_reuse(%arg0: tensor<4x7xf32>, %arg1: tensor<7x5xf32>) -> tensor<4x5xf32> {
7+
// Pad LHS with zero (contracting dimension).
8+
// CHECK: %[[CST0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
9+
// CHECK: %[[PAD0:.*]] = stablehlo.pad %arg0, %[[CST0]], low = [0, 0], high = [0, 1], interior = [0, 0] : (tensor<4x7xf32>, tensor<f32>) -> tensor<4x8xf32>
10+
// CHECK: %[[SLICE0:.*]] = sdy.all_slice [{}, {"y"}] %[[PAD0]] out_sharding=<@mesh_4_2, [{}, {"y"}]> : tensor<4x8xf32>
11+
12+
// Pad RHS with zero (both contracting and non-contracting dimensions).
13+
// CHECK: %[[CST1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
14+
// CHECK: %[[PAD1:.*]] = stablehlo.pad %arg1, %[[CST1]], low = [0, 0], high = [1, 3], interior = [0, 0] : (tensor<7x5xf32>, tensor<f32>) -> tensor<8x8xf32>
15+
// CHECK: %[[SLICE1:.*]] = sdy.all_slice [{"y"}, {"x"}] %[[PAD1]] out_sharding=<@mesh_4_2, [{"y"}, {"x"}]> : tensor<8x8xf32>
16+
17+
// Perform dot_general (result is padded on non-contracting dimension).
18+
// CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[SLICE0]], %[[SLICE1]], contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4_2, [{}, {"x"}]>]>} : (tensor<4x8xf32>, tensor<8x8xf32>) -> tensor<4x8xf32>
19+
20+
// Trim the padded result back to original shape.
21+
// CHECK: %[[TRIM:.*]] = stablehlo.slice %[[DOT]] [0:4, 0:5] : (tensor<4x8xf32>) -> tensor<4x5xf32>
22+
// CHECK: return %[[TRIM]] : tensor<4x5xf32>
23+
24+
%0 = sdy.all_slice [{}, {"y"}] %arg0 out_sharding=<@mesh_4_2, [{}, {"y"}]> : tensor<4x7xf32>
25+
%1 = sdy.all_slice [{"y"}, {"x"}] %arg1 out_sharding=<@mesh_4_2, [{"y"}, {"x"}]> : tensor<7x5xf32>
26+
%2 = stablehlo.dot_general %0, %1, contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4_2, [{}, {"x"}]>]>} : (tensor<4x7xf32>, tensor<7x5xf32>) -> tensor<4x5xf32>
27+
%3 = stablehlo.slice %2 [0:4, 0:5] : (tensor<4x5xf32>) -> tensor<4x5xf32>
28+
return %3 : tensor<4x5xf32>
29+
}
30+
31+
// CHECK-LABEL: func @padded_contracting_dims_not_reuse
32+
func.func @padded_contracting_dims_not_reuse(%arg0: tensor<4x7xf32>, %arg1: tensor<7x5xf32>) -> tensor<4x5xf32> {
33+
// Prepare padded LHS and RHS with unknown padding (via abs).
34+
// CHECK: %[[PAD0:.*]] = stablehlo.pad %arg0, {{.*}}
35+
// CHECK: %[[LHS_SLICE:.*]] = sdy.all_slice [{}, {"y"}] %[[PAD0]]
36+
// CHECK: %[[LHS_ABS:.*]] = stablehlo.abs %[[LHS_SLICE]] {{.*}}
37+
// CHECK: %[[PAD1:.*]] = stablehlo.pad %arg1, {{.*}}
38+
// CHECK: %[[RHS_SLICE:.*]] = sdy.all_slice [{"y"}, {"x"}] %[[PAD1]]
39+
// CHECK: %[[RHS_ABS:.*]] = stablehlo.abs %[[RHS_SLICE]] {{.*}}
40+
41+
// Enforce zero-padding on LHS contracting dim (dim 1).
42+
// CHECK: %[[LHS_IOTA:.*]] = stablehlo.iota{{.*}}dim = 1
43+
// CHECK: %[[LHS_LIMIT:.*]] = stablehlo.constant dense<7> : tensor<i32>
44+
// CHECK: %[[LHS_LIMIT_BCAST:.*]] = stablehlo.broadcast_in_dim %[[LHS_LIMIT]], dims = []
45+
// CHECK: %[[LHS_MASK:.*]] = stablehlo.compare{{.*}}LT, %[[LHS_IOTA]], %[[LHS_LIMIT_BCAST]]
46+
// CHECK: %[[LHS_CST:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
47+
// CHECK: %[[LHS_BCAST:.*]] = stablehlo.broadcast_in_dim %[[LHS_CST]], dims = []
48+
// CHECK: %[[LHS_SELECT:.*]] = stablehlo.select %[[LHS_MASK]], %[[LHS_ABS]], %[[LHS_BCAST]]
49+
50+
// Enforce zero-padding on RHS contracting dim (dim 0).
51+
// CHECK: %[[RHS_IOTA:.*]] = stablehlo.iota{{.*}}dim = 0
52+
// CHECK: %[[RHS_LIMIT:.*]] = stablehlo.constant dense<7> : tensor<i32>
53+
// CHECK: %[[RHS_LIMIT_BCAST:.*]] = stablehlo.broadcast_in_dim %[[RHS_LIMIT]], dims = []
54+
// CHECK: %[[RHS_MASK:.*]] = stablehlo.compare{{.*}}LT, %[[RHS_IOTA]], %[[RHS_LIMIT_BCAST]]
55+
// CHECK: %[[RHS_CST:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
56+
// CHECK: %[[RHS_BCAST:.*]] = stablehlo.broadcast_in_dim %[[RHS_CST]], dims = []
57+
// CHECK: %[[RHS_SELECT:.*]] = stablehlo.select %[[RHS_MASK]], %[[RHS_ABS]], %[[RHS_BCAST]]
58+
// CHECK-NOT: stablehlo.iota {{.*}} dim = 1
59+
60+
// Perform dot_general and trim result.
61+
// CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[LHS_SELECT]], %[[RHS_SELECT]], contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4_2, [{}, {"x"}]>]>} : (tensor<4x8xf32>, tensor<8x8xf32>) -> tensor<4x8xf32>
62+
// CHECK: %[[TRIM:.*]] = stablehlo.slice %[[DOT]] [0:4, 0:5]
63+
// CHECK: return %[[TRIM]]
64+
65+
%0 = sdy.all_slice [{}, {"y"}] %arg0 out_sharding=<@mesh_4_2, [{}, {"y"}]> : tensor<4x7xf32>
66+
%1 = stablehlo.abs %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4_2, [{}, {"y"}]>]>} : tensor<4x7xf32>
67+
%2 = sdy.all_slice [{"y"}, {"x"}] %arg1 out_sharding=<@mesh_4_2, [{"y"}, {"x"}]> : tensor<7x5xf32>
68+
%3 = stablehlo.abs %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4_2, [{"y"}, {"x"}]>]>} : tensor<7x5xf32>
69+
%4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4_2, [{}, {"x"}]>]>} : (tensor<4x7xf32>, tensor<7x5xf32>) -> tensor<4x5xf32>
70+
%5 = stablehlo.slice %4 [0:4, 0:5] : (tensor<4x5xf32>) -> tensor<4x5xf32>
71+
return %5 : tensor<4x5xf32>
72+
}
73+
74+
// CHECK-LABEL: func @padded_non_contracting_dims_any
75+
func.func @padded_non_contracting_dims_any(%arg0: tensor<3x8xf32>, %arg1: tensor<8x5xf32>) -> tensor<3x5xf32> {
76+
// Prepare padded LHS with unknown padding.
77+
// CHECK: %[[PAD0:.*]] = stablehlo.pad %arg0, {{.*}}
78+
// CHECK: %[[LHS_SLICE:.*]] = sdy.all_slice [{"y"}, {}] %[[PAD0]]
79+
// CHECK: %[[LHS_ABS:.*]] = stablehlo.abs %[[LHS_SLICE]] {{.*}}
80+
81+
// Verify no select is generated for non-contracting dim.
82+
// CHECK-NOT: stablehlo.select
83+
// CHECK-NOT: stablehlo.compare
84+
85+
// Prepare padded RHS.
86+
// CHECK: %[[PAD1:.*]] = stablehlo.pad %arg1, {{.*}}
87+
// CHECK: %[[RHS_SLICE:.*]] = sdy.all_slice [{}, {"x"}] %[[PAD1]]
88+
89+
// Perform dot_general and trim result.
90+
// CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[LHS_ABS]], %[[RHS_SLICE]], contracting_dims = [1] x [0] {{.*}}
91+
// CHECK: %[[TRIM:.*]] = stablehlo.slice %[[DOT]] [0:3, 0:5]
92+
// CHECK: return %[[TRIM]]
93+
94+
%0 = sdy.all_slice [{"y"}, {}] %arg0 out_sharding=<@mesh_4_2, [{"y"}, {}]> : tensor<3x8xf32>
95+
%1 = stablehlo.abs %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4_2, [{"y"}, {}]>]>} : tensor<3x8xf32>
96+
%2 = sdy.all_slice [{}, {"x"}] %arg1 out_sharding=<@mesh_4_2, [{}, {"x"}]> : tensor<8x5xf32>
97+
%3 = stablehlo.dot_general %1, %2, contracting_dims = [1] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4_2, [{"y"}, {"x"}]>]>} : (tensor<3x8xf32>, tensor<8x5xf32>) -> tensor<3x5xf32>
98+
%4 = stablehlo.slice %3 [0:3, 0:5] : (tensor<3x5xf32>) -> tensor<3x5xf32>
99+
return %4 : tensor<3x5xf32>
100+
}

0 commit comments

Comments
 (0)