//===- Padding.cpp - Padding of Linalg ops --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#define DEBUG_TYPE "linalg-padding"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

/// Compute the padded shape of the given operand. The operand is padded to a
/// static bounding box according to the specified padding options.
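///
/// For example (illustrative shapes): with an identity indexing map,
/// paddingDimensions = {0, 1}, and padToMultipleOf = {16, 16}, an operand of
/// type tensor<?x12xf32> whose dynamic dimension has a constant upper bound
/// of 10 is assigned the padded shape {16, 16}.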
static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
                                        OpOperand *opOperand,
                                        const LinalgPaddingOptions &options,
                                        SmallVector<int64_t> &paddedShape,
                                        bool &alreadyHasRequestedShape) {
  AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of "paddingDimensions",
  // along with the multiple that they should be padded to ("1" if none).
  alreadyHasRequestedShape = true;
  DenseMap<int64_t, int64_t> shapeDimToMultiple;
  for (const auto &dimEn : enumerate(options.paddingDimensions)) {
    for (const auto &en : enumerate(indexingMap.getResults())) {
      if (en.value().isFunctionOfDim(dimEn.value())) {
        int64_t dimSize = shape[en.index()];
        if (options.padToMultipleOf.has_value()) {
          shapeDimToMultiple[en.index()] =
              (*options.padToMultipleOf)[dimEn.index()];
        } else {
          shapeDimToMultiple[en.index()] = 1;
        }
        if (ShapedType::isDynamic(dimSize)) {
          alreadyHasRequestedShape = false;
        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
          alreadyHasRequestedShape = false;
        }
      }
    }
  }

  // Helper function to round a number up to a given multiple.
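  // E.g., ceil(10, 8) == 16 and ceil(16, 8) == 16.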
  auto ceil = [](int64_t val, int64_t multiple) {
    return ((val + multiple - 1) / multiple) * multiple;
  };

  // Upper bound the sizes to obtain a static bounding box.
  paddedShape.assign(shape.begin(), shape.end());
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
    // Skip dimensions that do not require padding.
    if (!shapeDimToMultiple.contains(i)) {
      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        ValueBoundsConstraintSet::computeConstantBound(
            presburger::BoundType::UB,
            {opOperand->get(),
             /*dim=*/i},
            /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(upperBound)) {
      LLVM_DEBUG(
          DBGS() << "----could not compute a bounding box for padding\n");
      return failure();
    }
    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
  }

  return success();
}

/// Pad the `opOperand` in the "paddingDimensions" using the padding value and
/// the nofold flag found in "paddingValues" and "nofoldFlags", respectively.
///
/// Exit early and return the `opOperand` value if it already has the requested
/// shape, i.e.:
/// - static shape
/// - nofold is not set
/// - dim sizes are multiples of "padToMultipleOf"
///
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions "paddingDimensions" and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const LinalgPaddingOptions &options) {
  assert(
      (!options.padToMultipleOf.has_value() ||
       options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
      "invalid number of elements in padToMultipleOf");

  // Compute padded shape.
  SmallVector<int64_t> paddedShape;
  bool alreadyHasRequestedShape = false;
  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
                                alreadyHasRequestedShape)))
    return rewriter.notifyMatchFailure(opToPad,
                                       "--failed to compute padded shape");

  // Return the unpadded operand if padding to a static shape is not needed and
  // if the nofold flag is not set.
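  // Operands without an entry in `nofoldFlags` default to nofold = false.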
  bool nofold = opOperand->getOperandNumber() < options.nofoldFlags.size()
                    ? bool(options.nofoldFlags[opOperand->getOperandNumber()])
                    : false;
  if (!nofold && alreadyHasRequestedShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= options.paddingValues.size()) {
    return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
  }
  Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];

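  // Complex element types take their padding value as an ArrayAttr holding the
  // real and imaginary parts; all other element types use a scalar TypedAttr.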
  Value paddingValue;
  if (auto complexTy = dyn_cast<ComplexType>(
          getElementTypeOrSelf(opOperand->get().getType()))) {
    auto complexAttr = cast<ArrayAttr>(paddingAttr);
    paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
                                                        complexTy, complexAttr);
  } else {
    paddingValue = rewriter.create<arith::ConstantOp>(
        opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
  }

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                    << paddedTensorType << "\n");
  return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

LogicalResult
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                          const LinalgPaddingOptions &constOptions,
                          LinalgOp &paddedOp, SmallVector<Value> &replacements,
                          SmallVector<tensor::PadOp> &padOps) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
  Location loc = opToPad->getLoc();

  LinalgPaddingOptions options(constOptions);
  // Allow inference of pad values if they are not explicitly specified.
  // TODO: be mindful about the value depending on the actual operation.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(types, opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }

  // TODO: there are cases where we may still want to pad to larger sizes.
  if (!opToPad.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);

  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        rewriter, opToPad, &opOperand, options);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand)) {
      LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
                        << opOperand.get() << " -> FAIL\n");
      return rewriter.notifyMatchFailure(opToPad,
                                         "operand cannot be bound statically");
    }
    newOperands.push_back(*paddedOperand);
    if (auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
  }

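  // Reify the result shapes of the original op; they are used below to extract
  // slices of the original (unpadded) sizes out of the padded results.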
  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes();
  // clone **should** properly notify the rewriter.
  paddedOp = clone(rewriter, opToPad, resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
        strides));
  }

  if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
    replacements = std::move(paddedSubtensorResults);
    return success();
  }

  // Copy back unpadded results to the original destination (i.e., inits of the
  // linalg op), so that the destination buffer of the computation does not
  // change. If the padding folds away, this will materialize as a memcpy
  // between two identical buffers, which will then also fold away.
  assert(static_cast<int64_t>(paddedSubtensorResults.size()) ==
             opToPad.getNumDpsInits() &&
         "expected matching number of results");
  for (auto it :
       llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
    if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
      replacements.push_back(rewriter
                                 .create<linalg::CopyOp>(loc, std::get<0>(it),
                                                         std::get<1>(it).get())
                                 .getResult(0));
    } else if (options.copyBackOp ==
               LinalgPaddingOptions::CopyBackOp::
                   BufferizationMaterializeInDestination) {
      replacements.push_back(
          rewriter
              .create<bufferization::MaterializeInDestinationOp>(
                  loc, std::get<0>(it), std::get<1>(it).get())
              ->getResult(0));
    } else {
      llvm_unreachable("unsupported copy back op");
    }
  }
  return success();
}

FailureOr<LinalgOp>
mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
                                  const LinalgPaddingOptions &options) {
  assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
         "invalid options");

  if (!linalgOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        linalgOp, "only applies to Linalg ops with tensor semantics");

  // Pad the operation.
  LinalgOp paddedOp;
  SmallVector<Value> newResults;
  SmallVector<tensor::PadOp> padOps;
  if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
                               newResults, padOps)))
    return rewriter.notifyMatchFailure(linalgOp,
                                       "failed to rewrite as a padded op");

  // Hoist the padding.
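  // Each entry of `hoistPaddings` is the number of enclosing loops to hoist
  // the corresponding operand's tensor.pad above; 0 disables hoisting for that
  // operand, and extra entries beyond the number of operands are ignored.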
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
      break;
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0) {
      (void)rewriter.notifyMatchFailure(linalgOp, "not a tensor.pad -- skip");
      continue;
    }

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "non static padding shape -- skip");
      continue;
    }

    tensor::PadOp hoistedOp;
    SmallVector<TransposeOp> transposeOps;
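    // Transpose the hoisted pad according to the matching entry of
    // `transposePaddings`, if provided; an empty vector means no transposition.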
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "failed to apply hoistPadding");
      continue;
    }
    rewriter.replaceOp(padOp, *newResult);
  }

  // Replace the original, to-be-padded operation with the new results.
  rewriter.replaceOp(linalgOp, newResults);

  return paddedOp;
}