xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp (revision a758bcdbd92efb64a3482eb95d2769d74e33f5bb)
102c66207SMatthias Springer //===- Padding.cpp - Padding of Linalg ops --------------------------------===//
202c66207SMatthias Springer //
302c66207SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
402c66207SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
502c66207SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
602c66207SMatthias Springer //
702c66207SMatthias Springer //===----------------------------------------------------------------------===//
802c66207SMatthias Springer 
902c66207SMatthias Springer #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1002c66207SMatthias Springer 
11431c49d6SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12a5cee3e3SRobert Suderman #include "mlir/Dialect/Complex/IR/Complex.h"
1302c66207SMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
1402c66207SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1502c66207SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
1602c66207SMatthias Springer 
1702c66207SMatthias Springer #define DEBUG_TYPE "linalg-padding"
1802c66207SMatthias Springer 
1902c66207SMatthias Springer using namespace mlir;
2002c66207SMatthias Springer using namespace mlir::linalg;
2102c66207SMatthias Springer 
2202c66207SMatthias Springer #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
2302c66207SMatthias Springer #define DBGSNL() (llvm::dbgs() << "\n")
2402c66207SMatthias Springer 
25369c5d6aSMatthias Springer /// Compute the padded shape of the given operand. The operand is padded to a
267c4c2746SJaved Absar /// static bounding box according to the specified padding options.
27369c5d6aSMatthias Springer static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
28369c5d6aSMatthias Springer                                         OpOperand *opOperand,
29369c5d6aSMatthias Springer                                         const LinalgPaddingOptions &options,
30369c5d6aSMatthias Springer                                         SmallVector<int64_t> &paddedShape,
31369c5d6aSMatthias Springer                                         bool &alreadyHasRequestedShape) {
3202c66207SMatthias Springer   AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
3302c66207SMatthias Springer   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
3402c66207SMatthias Springer 
35369c5d6aSMatthias Springer   // Collect the shape dimensions that are a function of "paddingDimensions",
3602c66207SMatthias Springer   // along with the multiple that they should be padded to ("1" if none).
37369c5d6aSMatthias Springer   alreadyHasRequestedShape = true;
3802c66207SMatthias Springer   DenseMap<int64_t, int64_t> shapeDimToMultiple;
39369c5d6aSMatthias Springer   for (const auto &dimEn : enumerate(options.paddingDimensions)) {
4002c66207SMatthias Springer     for (const auto &en : enumerate(indexingMap.getResults())) {
4102c66207SMatthias Springer       if (en.value().isFunctionOfDim(dimEn.value())) {
4202c66207SMatthias Springer         int64_t dimSize = shape[en.index()];
43369c5d6aSMatthias Springer         if (options.padToMultipleOf.has_value()) {
44369c5d6aSMatthias Springer           shapeDimToMultiple[en.index()] =
45369c5d6aSMatthias Springer               (*options.padToMultipleOf)[dimEn.index()];
46369c5d6aSMatthias Springer         } else {
47369c5d6aSMatthias Springer           shapeDimToMultiple[en.index()] = 1;
48369c5d6aSMatthias Springer         }
4902c66207SMatthias Springer         if (ShapedType::isDynamic(dimSize)) {
5002c66207SMatthias Springer           alreadyHasRequestedShape = false;
5102c66207SMatthias Springer         } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
5202c66207SMatthias Springer           alreadyHasRequestedShape = false;
5302c66207SMatthias Springer         }
5402c66207SMatthias Springer       }
5502c66207SMatthias Springer     }
5602c66207SMatthias Springer   }
5702c66207SMatthias Springer 
5802c66207SMatthias Springer   // Helper function to round a number up to a given multiple.
5902c66207SMatthias Springer   auto ceil = [](int64_t val, int64_t multiple) {
6002c66207SMatthias Springer     return ((val + multiple - 1) / multiple) * multiple;
6102c66207SMatthias Springer   };
6202c66207SMatthias Springer 
6302c66207SMatthias Springer   // Upper bound the sizes to obtain a static bounding box.
64369c5d6aSMatthias Springer   paddedShape.assign(shape.begin(), shape.end());
6502c66207SMatthias Springer   for (int64_t i = 0, e = shape.size(); i < e; ++i) {
6602c66207SMatthias Springer     LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
6702c66207SMatthias Springer     // Skip dimensions that do not require padding.
6802c66207SMatthias Springer     if (!shapeDimToMultiple.contains(i)) {
6902c66207SMatthias Springer       LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
7002c66207SMatthias Springer       continue;
7102c66207SMatthias Springer     }
7202c66207SMatthias Springer     // Otherwise, try to compute a constant upper bound for the size value.
7302c66207SMatthias Springer     FailureOr<int64_t> upperBound =
7402c66207SMatthias Springer         ValueBoundsConstraintSet::computeConstantBound(
7540dd3aa9SMatthias Springer             presburger::BoundType::UB,
7640dd3aa9SMatthias Springer             {opOperand->get(),
7740dd3aa9SMatthias Springer              /*dim=*/i},
7840dd3aa9SMatthias Springer             /*stopCondition=*/nullptr, /*closedUB=*/true);
7902c66207SMatthias Springer     if (failed(upperBound)) {
807c4c2746SJaved Absar       LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
81369c5d6aSMatthias Springer       return failure();
8202c66207SMatthias Springer     }
8302c66207SMatthias Springer     paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
8402c66207SMatthias Springer     LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
8502c66207SMatthias Springer   }
8602c66207SMatthias Springer 
87369c5d6aSMatthias Springer   return success();
88369c5d6aSMatthias Springer }
89369c5d6aSMatthias Springer 
90369c5d6aSMatthias Springer /// Pad the `opOperand` in the "paddingDimensions" using the padding value and
91*a758bcdbSAndrzej Warzyński /// the nofold flag found in "paddingValues" and "nofoldFlags", respectively.
92369c5d6aSMatthias Springer ///
93369c5d6aSMatthias Springer /// Exit early and return the `opOperand` value if it already has the requested
947c4c2746SJaved Absar /// shape. i.e.:
95369c5d6aSMatthias Springer /// - static shape
96369c5d6aSMatthias Springer /// - nofold is not set
97369c5d6aSMatthias Springer /// - dim sizes are multiples of "padToMultipleOf"
98369c5d6aSMatthias Springer ///
99369c5d6aSMatthias Springer /// Otherwise, try to pad the shape dimensions that match the iterator
100369c5d6aSMatthias Springer /// dimensions "paddingDimensions" and return the tensor::PadOp result if
101369c5d6aSMatthias Springer /// padding succeeds or failure otherwise.
102369c5d6aSMatthias Springer static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
103369c5d6aSMatthias Springer     RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
104369c5d6aSMatthias Springer     const LinalgPaddingOptions &options) {
1052049b2adSLorenzo Chelini   assert(
1062049b2adSLorenzo Chelini       (!options.padToMultipleOf.has_value() ||
1072049b2adSLorenzo Chelini        options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
108369c5d6aSMatthias Springer       "invalid number of elements in padToMultipleOf");
109369c5d6aSMatthias Springer 
110369c5d6aSMatthias Springer   // Compute padded shape.
111369c5d6aSMatthias Springer   SmallVector<int64_t> paddedShape;
112369c5d6aSMatthias Springer   bool alreadyHasRequestedShape = false;
113369c5d6aSMatthias Springer   if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
114369c5d6aSMatthias Springer                                 alreadyHasRequestedShape)))
115369c5d6aSMatthias Springer     return rewriter.notifyMatchFailure(opToPad,
116369c5d6aSMatthias Springer                                        "--failed to compute padded shape");
117369c5d6aSMatthias Springer 
118369c5d6aSMatthias Springer   // Return the unpadded operand if padding to a static shape is not needed and
119369c5d6aSMatthias Springer   // if the nofold flag is not set.
120*a758bcdbSAndrzej Warzyński   bool nofold = opOperand->getOperandNumber() < options.nofoldFlags.size()
121*a758bcdbSAndrzej Warzyński                     ? bool(options.nofoldFlags[opOperand->getOperandNumber()])
122369c5d6aSMatthias Springer                     : false;
123369c5d6aSMatthias Springer   if (!nofold && alreadyHasRequestedShape)
124369c5d6aSMatthias Springer     return opOperand->get();
125369c5d6aSMatthias Springer 
126369c5d6aSMatthias Springer   // Fail if `paddingValues` specifies no padding value.
127369c5d6aSMatthias Springer   if (opOperand->getOperandNumber() >= options.paddingValues.size()) {
128369c5d6aSMatthias Springer     return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
129369c5d6aSMatthias Springer   }
130369c5d6aSMatthias Springer   Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];
131a5cee3e3SRobert Suderman 
132a5cee3e3SRobert Suderman   Value paddingValue;
133a5cee3e3SRobert Suderman   if (auto complexTy = dyn_cast<ComplexType>(
134a5cee3e3SRobert Suderman           getElementTypeOrSelf(opOperand->get().getType()))) {
135a5cee3e3SRobert Suderman     auto complexAttr = cast<ArrayAttr>(paddingAttr);
136a5cee3e3SRobert Suderman     paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
137a5cee3e3SRobert Suderman                                                         complexTy, complexAttr);
138a5cee3e3SRobert Suderman   } else {
139a5cee3e3SRobert Suderman     paddingValue = rewriter.create<arith::ConstantOp>(
140369c5d6aSMatthias Springer         opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
141a5cee3e3SRobert Suderman   }
142369c5d6aSMatthias Springer 
14302c66207SMatthias Springer   // Pad the operand to the bounding box defined by `paddedShape`.
14402c66207SMatthias Springer   auto paddedTensorType = RankedTensorType::get(
14502c66207SMatthias Springer       paddedShape, getElementTypeOrSelf(opOperand->get()));
14602c66207SMatthias Springer   LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
14702c66207SMatthias Springer                     << paddedTensorType);
14802c66207SMatthias Springer   return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
14902c66207SMatthias Springer                                opOperand->get(), paddingValue, nofold);
15002c66207SMatthias Springer }
15102c66207SMatthias Springer 
1520e06ec59SMatthias Springer LogicalResult
15302c66207SMatthias Springer linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
1541e84e91eSNicolas Vasilache                           const LinalgPaddingOptions &constOptions,
1550e06ec59SMatthias Springer                           LinalgOp &paddedOp, SmallVector<Value> &replacements,
156977cb4fdSMatthias Springer                           SmallVector<tensor::PadOp> &padOps) {
15702c66207SMatthias Springer   LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
15802c66207SMatthias Springer   Location loc = opToPad->getLoc();
15902c66207SMatthias Springer 
1601e84e91eSNicolas Vasilache   LinalgPaddingOptions options(constOptions);
1611e84e91eSNicolas Vasilache   // Allow inference of pad values if they are not explicitly specified.
1621e84e91eSNicolas Vasilache   // TODO: be mindful about the value depending on the actual operation.
1631e84e91eSNicolas Vasilache   if (options.paddingValues.empty()) {
1641e84e91eSNicolas Vasilache     SmallVector<Type> types(opToPad->getOperandTypes());
1651e84e91eSNicolas Vasilache     llvm::append_range(types, opToPad->getResultTypes());
1661e84e91eSNicolas Vasilache     for (Type t : types) {
1671e84e91eSNicolas Vasilache       options.paddingValues.push_back(
1681e84e91eSNicolas Vasilache           rewriter.getZeroAttr(getElementTypeOrSelf(t)));
1691e84e91eSNicolas Vasilache     }
1701e84e91eSNicolas Vasilache   }
1711e84e91eSNicolas Vasilache 
17202c66207SMatthias Springer   // TODO: there are cases where we may still want to pad to larger sizes.
1730a8e3dd4SMatthias Springer   if (!opToPad.hasPureTensorSemantics())
17402c66207SMatthias Springer     return rewriter.notifyMatchFailure(opToPad,
17502c66207SMatthias Springer                                        "expected operation on tensors");
17602c66207SMatthias Springer 
17702c66207SMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
17802c66207SMatthias Springer   // Set IP after op because we also take the dims of the original output.
17902c66207SMatthias Springer   rewriter.setInsertionPointAfter(opToPad);
18002c66207SMatthias Springer 
18102c66207SMatthias Springer   // Make a copy of the shaped operands and update it.
18202c66207SMatthias Springer   SmallVector<Value> newOperands;
18302c66207SMatthias Springer   newOperands.reserve(opToPad->getNumOperands());
18402c66207SMatthias Springer   for (OpOperand &opOperand : opToPad->getOpOperands()) {
18502c66207SMatthias Springer     FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
186369c5d6aSMatthias Springer         rewriter, opToPad, &opOperand, options);
18702c66207SMatthias Springer     // Exit if `paddingDimensions` cannot be bounded statically.
18802c66207SMatthias Springer     if (failed(paddedOperand)) {
18902c66207SMatthias Springer       LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
19002c66207SMatthias Springer                         << opOperand.get() << " -> FAIL\n");
19102c66207SMatthias Springer       return rewriter.notifyMatchFailure(opToPad,
19202c66207SMatthias Springer                                          "operand cannot be bound statically");
19302c66207SMatthias Springer     }
19402c66207SMatthias Springer     newOperands.push_back(*paddedOperand);
1950e06ec59SMatthias Springer     if (auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
1960e06ec59SMatthias Springer       padOps.push_back(padOp);
19702c66207SMatthias Springer   }
19802c66207SMatthias Springer 
19902c66207SMatthias Springer   ReifiedRankedShapedTypeDims reifiedResultShapes;
20002c66207SMatthias Springer   if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
20102c66207SMatthias Springer     LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
20202c66207SMatthias Springer     return rewriter.notifyMatchFailure(opToPad,
20302c66207SMatthias Springer                                        "failed to reify result shapes");
20402c66207SMatthias Springer   }
20502c66207SMatthias Springer   assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
20602c66207SMatthias Springer          "expected same number of results");
20702c66207SMatthias Springer 
20802c66207SMatthias Springer   // Clone `opToPad` to operate on the statically padded shapes.
20902c66207SMatthias Springer   auto resultTensorTypes =
21002c66207SMatthias Springer       ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes();
21102c66207SMatthias Springer   // clone **should** properly notify the rewriter.
21202c66207SMatthias Springer   paddedOp = clone(rewriter, opToPad, resultTensorTypes, newOperands);
21302c66207SMatthias Springer   LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
21402c66207SMatthias Springer 
21502c66207SMatthias Springer   // Recover the slice out of the new static results. This keeps the original
21602c66207SMatthias Springer   // linalg op around because it uses the dims of the original results.
21702c66207SMatthias Springer   SmallVector<Value> paddedSubtensorResults;
21802c66207SMatthias Springer   paddedSubtensorResults.reserve(opToPad->getNumResults());
21902c66207SMatthias Springer   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
22002c66207SMatthias Springer     Value paddedResult = en.value();
22102c66207SMatthias Springer     int64_t resultNumber = en.index();
22202c66207SMatthias Springer     int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
22302c66207SMatthias Springer     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
22402c66207SMatthias Springer     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
22502c66207SMatthias Springer     paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
22602c66207SMatthias Springer         loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
22702c66207SMatthias Springer         strides));
22802c66207SMatthias Springer   }
229431c49d6SMatthias Springer 
230977cb4fdSMatthias Springer   if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
2310e06ec59SMatthias Springer     replacements = std::move(paddedSubtensorResults);
2320e06ec59SMatthias Springer     return success();
2330e06ec59SMatthias Springer   }
234431c49d6SMatthias Springer 
235431c49d6SMatthias Springer   // Copy back unpadded results to the original destination (i.e., inits of the
236431c49d6SMatthias Springer   // linalg op), so that the destination buffer of the computation does not
2377c4c2746SJaved Absar   // change. If the padding folds away, this will materialize as a memcpy
238431c49d6SMatthias Springer   // between two identical buffers, which will then also fold away.
2390e06ec59SMatthias Springer   assert(static_cast<int64_t>(paddedSubtensorResults.size()) ==
2400e06ec59SMatthias Springer              opToPad.getNumDpsInits() &&
2410e06ec59SMatthias Springer          "expected matching number of results");
242431c49d6SMatthias Springer   for (auto it :
2430b2197b0SMatthias Springer        llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
244977cb4fdSMatthias Springer     if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
245977cb4fdSMatthias Springer       replacements.push_back(rewriter
246977cb4fdSMatthias Springer                                  .create<linalg::CopyOp>(loc, std::get<0>(it),
2470b2197b0SMatthias Springer                                                          std::get<1>(it).get())
248977cb4fdSMatthias Springer                                  .getResult(0));
249977cb4fdSMatthias Springer     } else if (options.copyBackOp ==
25091464e1dSMatthias Springer                LinalgPaddingOptions::CopyBackOp::
25191464e1dSMatthias Springer                    BufferizationMaterializeInDestination) {
25291464e1dSMatthias Springer       replacements.push_back(
2530fcaca2fSMatthias Springer           rewriter
2540fcaca2fSMatthias Springer               .create<bufferization::MaterializeInDestinationOp>(
2550fcaca2fSMatthias Springer                   loc, std::get<0>(it), std::get<1>(it).get())
2560fcaca2fSMatthias Springer               ->getResult(0));
257977cb4fdSMatthias Springer     } else {
258977cb4fdSMatthias Springer       llvm_unreachable("unsupported copy back op");
259977cb4fdSMatthias Springer     }
260431c49d6SMatthias Springer   }
2610e06ec59SMatthias Springer   return success();
26202c66207SMatthias Springer }
26302c66207SMatthias Springer 
26402c66207SMatthias Springer FailureOr<LinalgOp>
26502c66207SMatthias Springer mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
266369c5d6aSMatthias Springer                                   const LinalgPaddingOptions &options) {
267977cb4fdSMatthias Springer   assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
268977cb4fdSMatthias Springer          "invalid options");
269977cb4fdSMatthias Springer 
2700a8e3dd4SMatthias Springer   if (!linalgOp.hasPureTensorSemantics())
27102c66207SMatthias Springer     return rewriter.notifyMatchFailure(
27202c66207SMatthias Springer         linalgOp, "only applies to Linalg ops with tensor semantics");
27302c66207SMatthias Springer 
27402c66207SMatthias Springer   // Pad the operation.
27502c66207SMatthias Springer   LinalgOp paddedOp;
2760e06ec59SMatthias Springer   SmallVector<Value> newResults;
2770e06ec59SMatthias Springer   SmallVector<tensor::PadOp> padOps;
2780e06ec59SMatthias Springer   if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
279977cb4fdSMatthias Springer                                newResults, padOps)))
28002c66207SMatthias Springer     return rewriter.notifyMatchFailure(linalgOp,
28102c66207SMatthias Springer                                        "failed to rewrite as a padded op");
28202c66207SMatthias Springer 
28302c66207SMatthias Springer   // Hoist the padding.
28402c66207SMatthias Springer   for (const auto &en : enumerate(options.hoistPaddings)) {
28502c66207SMatthias Springer     if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
28602c66207SMatthias Springer       break;
28702c66207SMatthias Springer     OpOperand &opOperand = paddedOp->getOpOperand(en.index());
28802c66207SMatthias Springer     auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
28902c66207SMatthias Springer     if (!padOp || en.value() == 0) {
29002c66207SMatthias Springer       (void)rewriter.notifyMatchFailure(linalgOp, "not a tensor.pad -- skip");
29102c66207SMatthias Springer       continue;
29202c66207SMatthias Springer     }
29302c66207SMatthias Springer 
29402c66207SMatthias Springer     // Fail hoisting if the operand shape is not fully static.
29502c66207SMatthias Springer     if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) {
29602c66207SMatthias Springer       (void)rewriter.notifyMatchFailure(linalgOp,
29702c66207SMatthias Springer                                         "non static padding shape -- skip");
29802c66207SMatthias Springer       continue;
29902c66207SMatthias Springer     }
30002c66207SMatthias Springer 
30102c66207SMatthias Springer     tensor::PadOp hoistedOp;
30228039055SHugo Trachino     SmallVector<TransposeOp> transposeOps;
30302c66207SMatthias Springer     SmallVector<int64_t> transposeVector =
30402c66207SMatthias Springer         en.index() < options.transposePaddings.size()
30502c66207SMatthias Springer             ? options.transposePaddings[en.index()]
30602c66207SMatthias Springer             : SmallVector<int64_t>{};
30702c66207SMatthias Springer 
30802c66207SMatthias Springer     FailureOr<Value> newResult = hoistPaddingOnTensors(
30902c66207SMatthias Springer         padOp, en.value(), transposeVector, hoistedOp, transposeOps);
31002c66207SMatthias Springer     if (failed(newResult)) {
31102c66207SMatthias Springer       (void)rewriter.notifyMatchFailure(linalgOp,
31202c66207SMatthias Springer                                         "failed to apply hoistPadding");
31302c66207SMatthias Springer       continue;
31402c66207SMatthias Springer     }
31502c66207SMatthias Springer     rewriter.replaceOp(padOp, *newResult);
31602c66207SMatthias Springer   }
31702c66207SMatthias Springer 
31802c66207SMatthias Springer   // Replace the original operation to pad.
3190e06ec59SMatthias Springer   rewriter.replaceOp(linalgOp, newResults);
32002c66207SMatthias Springer 
32102c66207SMatthias Springer   return paddedOp;
32202c66207SMatthias Springer }
323