//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

/// Emits one `affine.apply` per result of `map`, canonicalizing each
/// single-result submap and its operands along the way.
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size() &&
         "expected as many operand values as map inputs");
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals);
    affine::canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}
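
// As an illustrative sketch (SSA names hypothetical), applying the map
// (d0, d1) -> (d0 + d1, d0) to values %i, %j yields one affine.apply per
// result expression:
//
//   %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
//   %1 = affine.apply affine_map<(d0) -> (d0)>(%i)
//
// Canonicalization drops the unused dimension from the second apply; such
// trivial applies are later cleaned up by the FoldAffineOp pattern below.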

/// Inlines the single-block region of `op` at the current insertion point,
/// mapping its block arguments to `indexedValues`, and emits one store per
/// terminator operand into the corresponding output buffer at `indexing`.
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  IRMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}
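
// For intuition, a sketch (IR and names hypothetical): given a region such as
//
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %c, %0 : f32
//     linalg.yield %1 : f32
//
// with %a, %b, %c mapped to previously loaded scalars %va, %vb, %vc, the body
// is cloned at the insertion point and the yielded value is stored back:
//
//   %0 = arith.mulf %va, %vb : f32
//   %1 = arith.addf %vc, %0 : f32
//   memref.store %1, %out[%i, %j] : memref<?x?xf32>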

/// Holds the input and output indices computed for a SingleInputPoolingOp
/// `op` by `getInputAndOutputIndices` below.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
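
// As a hedged example, a 1-D pooling-style op over iteration space (w, kw)
// typically carries the maps input: (w, kw) -> (w + kw), window:
// (w, kw) -> (kw), output: (w, kw) -> (w); maps[0] and maps[2] above would
// then produce indices {%w + %kw} and {%w}, while the shape-only window
// operand (maps[1]) is never loaded from.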

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp->getNumOperands());

  auto allIvsPlusDims = SmallVector<Value>(allIvs);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or for scalars access the operand itself.
  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    if (!isa<MemRefType>(outputOperand.get().getType()))
      continue;
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims));
    outputBuffers.push_back(outputOperand.get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
                                                LinalgOp linalgOp,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](affine::AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
    for (Region *r : loopOp.getLoopRegions())
      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
  }
}
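
// For example (a sketch, names hypothetical), after lowering a rank-2 op the
// body may contain
//
//   scf.for %i = ... {
//     scf.for %j = ... {
//       %0 = linalg.index 0 : index
//       %1 = linalg.index 1 : index
//       ...
//
// and this helper replaces %0 with %i and %1 with %j.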

template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineLoadOp, memref::LoadOp>;
  using StoreOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineStoreOp, memref::StoreOp>;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
  return loops;
}
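
// End-to-end sketch (IR abbreviated, bounds and names hypothetical): with
// LoopTy = scf::ForOp, a buffer-semantics matmul such as
//
//   linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
//                 outs(%C : memref<?x?xf32>)
//
// lowers to a perfectly nested loop:
//
//   scf.for %i = %c0 to %M step %c1 {
//     scf.for %j = %c0 to %N step %c1 {
//       scf.for %k = %c0 to %K step %c1 {
//         %a = memref.load %A[%i, %k] : memref<?x?xf32>
//         %b = memref.load %B[%k, %j] : memref<?x?xf32>
//         %c = memref.load %C[%i, %j] : memref<?x?xf32>
//         %0 = arith.mulf %a, %b : f32
//         %1 = arith.addf %c, %0 : f32
//         memref.store %1, %C[%i, %j] : memref<?x?xf32>
//       }
//     }
//   }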

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp || !linalgOp.hasPureBufferSemantics()) {
      return rewriter.notifyMatchFailure(
          op, "expected linalg op with buffer semantics");
    }
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (isa<AffineDimExpr, AffineSymbolExpr>(expr)) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};
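
// Illustrative folds performed by this pattern (a sketch):
//
//   affine.apply affine_map<() -> (42)>()          // -> arith.constant 42 : index
//   affine.apply affine_map<(d0) -> (d0)>(%i)      // -> %i
//   affine.apply affine_map<()[s0] -> (s0)>()[%n]  // -> %n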

template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
  MLIRContext *context = enclosingOp->getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsGreedily(enclosingOp, std::move(patterns));
}

struct LowerToAffineLoops
    : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
  using impl::ConvertLinalgToAffineLoopsPassBase<
      LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
  }
};

struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
  using impl::ConvertLinalgToLoopsPassBase<
      LowerToLoops>::ConvertLinalgToLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

struct LowerToParallelLoops
    : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
  using impl::ConvertLinalgToParallelLoopsPassBase<
      LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}
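
// Usage sketch (hypothetical caller, not part of this file): these entry
// points are typically invoked from a rewrite pattern with the insertion
// point set at the op being lowered, e.g.
//
//   FailureOr<LinalgLoops> loops =
//       linalg::linalgOpToLoops(rewriter, linalgOp);
//   if (failed(loops))
//     return failure();
//   // The op's body has been inlined into the loops; remove the original.
//   rewriter.eraseOp(linalgOp);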
379