@@ -111,6 +111,12 @@ class PaddedTypeConverter : public TypeConverter {
111111// Known padding value kinds for generated padding values.
112112enum class PaddingValueKind { kZero , kOne };
113113
114+ // Returns true if the operation has custom padding handling implemented in
115+ // this file and should be excluded from GenericOpPattern.
116+ bool hasCustomPadHandling (Operation* op) {
117+ return isa<stablehlo::SliceOp, stablehlo::DotGeneralOp>(op);
118+ }
119+
114120class PaddingCache {
115121 public:
116122 // Registers the padding kind for a value. The conversion pattern
@@ -186,6 +192,75 @@ Value createPaddedValue(RankedTensorType paddedType, Value value,
186192 return padOp;
187193}
188194
195+ // Returns 'inputVal' if the value is not padded or the value already has
196+ // 'requiredKind' as PaddingValueKind. Otherwise, uses compare-and-select to
197+ // produce a new padded value from inputVal with the requiredKind padding and
198+ // returns the new value.
199+ //
200+ // We ensure all dimensions that require padding are padded with requireKind
201+ // unless dimsToEnforce is provided, in which case only the specified
202+ // dimensions are padded.
203+ Value ensurePadding (
204+ Value inputVal, RankedTensorType origType, PaddingValueKind requiredKind,
205+ OpBuilder& b, Location loc, PaddingCache& cache,
206+ std::optional<ArrayRef<int64_t >> dimsToEnforce = std::nullopt ) {
207+ // Return early if no padding is applied or the cached padding already
208+ // matches.
209+ auto paddedType = cast<RankedTensorType>(inputVal.getType ());
210+ if (origType == paddedType) {
211+ return inputVal;
212+ }
213+ std::optional<PaddingValueKind> currentKind = cache.getPadding (inputVal);
214+ if (currentKind && *currentKind == requiredKind) {
215+ return inputVal;
216+ }
217+
218+ // Build a mask that is `true` for the original (unpadded) data region.
219+ // An element is in the original region if its index along each padded
220+ // dimension is less than the original unpadded size (index < original_size).
221+ Value validDataMask;
222+ for (auto [dim, origSize] : llvm::enumerate (origType.getShape ())) {
223+ if (origSize == paddedType.getDimSize (dim) ||
224+ (dimsToEnforce && !llvm::is_contained (*dimsToEnforce, dim))) {
225+ continue ;
226+ }
227+ auto iotaType =
228+ RankedTensorType::get (paddedType.getShape (), b.getI32Type ());
229+ Value iota = stablehlo::IotaOp::create (b, loc, iotaType, dim);
230+ Value limit = stablehlo::ConstantOp::create (
231+ b, loc,
232+ DenseElementsAttr::get (RankedTensorType::get ({}, b.getI32Type ()),
233+ b.getI32IntegerAttr (origSize)));
234+ Value broadcastLimit = stablehlo::BroadcastInDimOp::create (
235+ b, loc, iotaType, limit, b.getDenseI64ArrayAttr ({}));
236+ Value mask = stablehlo::CompareOp::create (
237+ b, loc, iota, broadcastLimit, stablehlo::ComparisonDirection::LT);
238+ validDataMask = validDataMask
239+ ? stablehlo::AndOp::create (b, loc, validDataMask, mask)
240+ : mask;
241+ }
242+
243+ if (!validDataMask) {
244+ return inputVal;
245+ }
246+
247+ // Create the constant with the new padding value and broadcast it to the
248+ // same shape as 'inputVal'.
249+ Value newPaddingScalar =
250+ createConstant (b, loc, paddedType.getElementType (), requiredKind);
251+ Value newPaddingValue = stablehlo::BroadcastInDimOp::create (
252+ b, loc, paddedType, newPaddingScalar, b.getDenseI64ArrayAttr ({}));
253+
254+ // Keep the original data from 'inputVal' (where mask is true), and replace
255+ // the padded region with 'newPaddingValue' (where mask is false).
256+ Value select = stablehlo::SelectOp::create (b, loc, validDataMask, inputVal,
257+ newPaddingValue);
258+ if (!dimsToEnforce) {
259+ cache.setPadding (select, requiredKind);
260+ }
261+ return select;
262+ }
263+
189264// Converts op to its local version by replacing its operands with the already
190265// converted operands.
191266LogicalResult padGenericOp (Operation* op, ValueRange operands,
@@ -263,7 +338,7 @@ class GenericOpPattern : public ConversionPattern {
263338 Dialect* dialect = op->getDialect ();
264339 if ((dialect && dialect->getNamespace () != " stablehlo" &&
265340 !isa<sdy::ReturnOp>(op)) ||
266- isa<stablehlo::SliceOp> (op)) {
341+ hasCustomPadHandling (op)) {
267342 return failure ();
268343 }
269344 return padGenericOp (op, operands, rewriter,
@@ -427,15 +502,17 @@ class AllSliceOpPattern : public OpConversionPattern<sdy::AllSliceOp> {
427502 return padGenericOp (op, adaptor.getOperands (), rewriter, converter);
428503 }
429504
430- Value padOp = createPaddedValue (cast<RankedTensorType>(paddedInputType),
431- input, PaddingValueKind::kZero , symbolTable,
432- rewriter, cache);
505+ PaddingValueKind paddingKind = PaddingValueKind::kZero ;
506+ Value padOp =
507+ createPaddedValue (cast<RankedTensorType>(paddedInputType), input,
508+ paddingKind, symbolTable, rewriter, cache);
433509 OperationState state (op->getLoc (), op->getName ());
434510 state.addOperands ({padOp});
435511 state.addTypes (
436512 {getPaddedType (op.getResult ().getType (), outSharding, symbolTable)});
437513 state.addAttributes (op->getAttrs ());
438514 Operation* newOp = rewriter.create (state);
515+ cache.setPadding (newOp->getResult (0 ), paddingKind);
439516
440517 rewriter.replaceOp (op, newOp->getResults ());
441518 return success ();
@@ -489,8 +566,51 @@ class StablehloSliceOpPattern : public OpConversionPattern<stablehlo::SliceOp> {
489566 }
490567};
491568
569+ class StablehloDotGeneralOpPattern
570+ : public OpConversionPattern<stablehlo::DotGeneralOp> {
571+ public:
572+ StablehloDotGeneralOpPattern (TypeConverter& converter, MLIRContext* ctx,
573+ PaddingCache& cache)
574+ : OpConversionPattern(converter, ctx), cache(cache) {}
575+
576+ LogicalResult matchAndRewrite (
577+ stablehlo::DotGeneralOp op, OpAdaptor adaptor,
578+ ConversionPatternRewriter& rewriter) const override {
579+ auto * converter =
580+ static_cast <const PaddedTypeConverter*>(getTypeConverter ());
581+
582+ Location loc = op.getLoc ();
583+ stablehlo::DotDimensionNumbersAttr dimNums = op.getDotDimensionNumbers ();
584+
585+ Value lhs = adaptor.getOperands ()[0 ];
586+ auto lhsOrigType = dyn_cast<RankedTensorType>(op->getOperand (0 ).getType ());
587+ if (!lhsOrigType) {
588+ return failure ();
589+ }
590+ Value paddedLhs =
591+ ensurePadding (lhs, lhsOrigType, PaddingValueKind::kZero , rewriter, loc,
592+ cache, dimNums.getLhsContractingDimensions ());
593+
594+ Value rhs = adaptor.getOperands ()[1 ];
595+ auto rhsOrigType = dyn_cast<RankedTensorType>(op->getOperand (1 ).getType ());
596+ if (!rhsOrigType) {
597+ return failure ();
598+ }
599+ Value paddedRhs =
600+ ensurePadding (rhs, rhsOrigType, PaddingValueKind::kZero , rewriter, loc,
601+ cache, dimNums.getRhsContractingDimensions ());
602+
603+ return padGenericOp (op, {paddedLhs, paddedRhs}, rewriter, converter);
604+ }
605+
606+ private:
607+ PaddingCache& cache;
608+ };
609+
492610struct PadForDivisibilityPass
493611 : public impl::PadForDivisibilityPassBase<PadForDivisibilityPass> {
612+ using PadForDivisibilityPassBase::PadForDivisibilityPassBase;
613+
494614 protected:
495615 void runOnOperation () final {
496616 // FuncOpPattern enforces that function inputs and outputs are always fully
@@ -508,7 +628,8 @@ struct PadForDivisibilityPass
508628 AllGatherOpPattern>(typeConverter, &getContext ());
509629 // Sharing the padding cache reference across pattern instances is safe from
510630 // data races because pattern application within a function is sequential.
511- patterns.add <AllSliceOpPattern>(typeConverter, &getContext (), paddingCache);
631+ patterns.add <AllSliceOpPattern, StablehloDotGeneralOpPattern>(
632+ typeConverter, &getContext (), paddingCache);
512633 ConversionTarget target (getContext ());
513634
514635 auto isLegalType = [&](Type type, TensorShardingAttr sharding) {
0 commit comments