xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision c584771f54cf94bb396c22f5cca895dd3f23c245)
//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/SCF/Utils/Utils.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Interfaces/TilingInterface.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "tile-using-interface"
26 
27 using namespace mlir;
28 
29 scf::SCFTilingOptions &
30 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
31   assert(!tileSizeComputationFunction && "tile sizes already set");
32   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
33   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
34     OpBuilder::InsertionGuard guard(b);
35     b.setInsertionPointToStart(
36         &op->getParentOfType<func::FuncOp>().getBody().front());
37     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
38       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
39       return v;
40     }));
41   };
42   return *this;
43 }
44 
45 /// Generate an empty loop nest that represents the tiled loop nest shell.
46 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
47 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
48 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
49 /// the
50 ///   tile processed within the inner most loop.
51 static SmallVector<scf::ForOp>
52 generateTileLoopNest(OpBuilder &builder, Location loc,
53                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
54                      SmallVector<OpFoldResult> &offsets,
55                      SmallVector<OpFoldResult> &sizes) {
56   assert(!loopRanges.empty() && "expected at least one loop range");
57   assert(loopRanges.size() == tileSizeVals.size() &&
58          "expected as many tile sizes as loop ranges");
59   OpBuilder::InsertionGuard guard(builder);
60   SmallVector<scf::ForOp> loops;
61   offsets.resize(loopRanges.size());
62   sizes.resize(loopRanges.size());
63 
64   // The tile size to use (to avoid out of bounds access) is  minimum of
65   // `tileSize` and `ub - iv`, where `iv` is the induction variable
66   // of the tiled loop.
67   AffineExpr s0, s1, d0;
68   bindDims(builder.getContext(), d0);
69   bindSymbols(builder.getContext(), s0, s1);
70   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
71 
72   for (auto loopRange : llvm::enumerate(loopRanges)) {
73     // No loops if tile size is zero. Set offset and size to the loop
74     // offset and size.
75     if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
76       offsets[loopRange.index()] = loopRange.value().offset;
77       sizes[loopRange.index()] = loopRange.value().size;
78       continue;
79     }
80 
81     auto loop = builder.create<scf::ForOp>(
82         loc, loopRange.value().offset, loopRange.value().size,
83         tileSizeVals[loopRange.index()], ValueRange{},
84         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
85             ValueRange /*iterArgs*/) {
86           Value boundedTileSize = builder.create<AffineMinOp>(
87               bodyLoc, minMap,
88               ValueRange{iv, tileSizeVals[loopRange.index()],
89                          loopRange.value().size});
90           sizes[loopRange.index()] = boundedTileSize;
91           builder.create<scf::YieldOp>(loc);
92         });
93     offsets[loopRange.index()] = loop.getInductionVar();
94     loops.push_back(loop);
95     builder.setInsertionPoint(loop.getBody()->getTerminator());
96   }
97   return loops;
98 }
99 
100 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
101                                           scf::SCFTilingOptions options,
102                                           PatternBenefit benefit)
103     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
104       options(std::move(options)) {}
105 
106 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
107                                           MLIRContext *context,
108                                           scf::SCFTilingOptions options,
109                                           PatternBenefit benefit)
110     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
111       options(std::move(options)) {}
112 
113 FailureOr<scf::SCFTilingResult>
114 scf::TileUsingSCFForOp::returningMatchAndRewrite(
115     TilingInterface op, PatternRewriter &rewriter) const {
116   OpBuilder::InsertionGuard guard(rewriter);
117   rewriter.setInsertionPointAfter(op);
118 
119   if (!options.tileSizeComputationFunction) {
120     return rewriter.notifyMatchFailure(
121         op, "missing tile size computation function");
122   }
123 
124   // 1. Get the range of the loops that are represented by the operation.
125   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
126   size_t numLoops = iterationDomain.size();
127   if (numLoops == 0) {
128     return rewriter.notifyMatchFailure(
129         op, "unable to tile op with no iteration domain");
130   }
131 
132   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
133   // skips tiling a particular dimension. This convention is significantly
134   // simpler to handle instead of adjusting affine maps to account for missing
135   // dimensions.
136   SmallVector<Value, 4> tileSizeVector =
137       options.tileSizeComputationFunction(rewriter, op);
138   if (tileSizeVector.size() < iterationDomain.size()) {
139     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
140     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
141   }
142 
143   scf::SCFTilingResult tilingResult;
144   SmallVector<OpFoldResult> offsets, sizes;
145   {
146     // 3. Materialize an empty loop nest that iterates over the tiles. These
147     // loops for now do not return any values even if the original operation has
148     // results.
149     tilingResult.loops = generateTileLoopNest(
150         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
151 
152     LLVM_DEBUG({
153       if (!tilingResult.loops.empty()) {
154         llvm::errs() << "LoopNest shell :\n";
155         tilingResult.loops.front().dump();
156         llvm::errs() << "\n";
157       }
158     });
159 
160     // 4. Generate the tiled implementation within the inner most loop.
161     if (!tilingResult.loops.empty())
162       rewriter.setInsertionPoint(
163           tilingResult.loops.back().getBody()->getTerminator());
164     SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
165         rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
166     if (tiledImplementation.size() != 1) {
167       return rewriter.notifyMatchFailure(
168           op, "expected tiled implementation to return a single op");
169     }
170     tilingResult.tiledOp = tiledImplementation[0];
171 
172     LLVM_DEBUG({
173       if (!tilingResult.loops.empty()) {
174         llvm::errs() << "After tiled implementation :\n";
175         tilingResult.loops.front().dump();
176         llvm::errs() << "\n";
177       }
178     });
179   }
180 
181   if (op->getNumResults() == 0) {
182     rewriter.eraseOp(op);
183     return tilingResult;
184   }
185 
186   // 5. If the original operations has results, modify the loop nest to yield
187   // the replacement values.
188   SmallVector<Value> replacements;
189   if (tilingResult.loops.empty()) {
190     // 5a. If there were no loops, the tiled implementation results are the
191     // replacements.
192     rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
193     return tilingResult;
194   }
195 
196   // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
197   // replacement values using destructive updates. Use the `TilingInterface`
198   // to get the position of the result tiles and use that to generate the
199   // destructive update pattern, i.e.,
200   //
201   // ```mlir
202   // scf.for %iv0 = ... {
203   //   %0 = tiled_op
204   // }
205   // ```
206   //
207   // is transformed to
208   //
209   // ```mlir
210   // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
211   //   %0 = tiled_op
212   //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
213   //   scf.yield %1
214   // }
215   // ```
216   NewYieldValueFn yieldValueFn =
217       [&](OpBuilder &b, Location loc,
218           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
219     SmallVector<Value> yieldedValues;
220     Attribute one = b.getIndexAttr(1);
221     for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
222       SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
223       if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
224                                           resultTileOffsets,
225                                           resultTileSizes))) {
226         op.emitOpError("unable to get position of result ")
227             << resultNum << " of the tiled implementation";
228         return {};
229       }
230       SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
231                                                   one);
232       Value yieldedValue = b.create<tensor::InsertSliceOp>(
233           op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
234           newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
235           resultTileStrides);
236       yieldedValues.push_back(yieldedValue);
237     }
238     return yieldedValues;
239   };
240   SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
241       rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
242       yieldValueFn);
243   for (auto loop : llvm::enumerate(tilingResult.loops)) {
244     rewriter.eraseOp(loop.value());
245     tilingResult.loops[loop.index()] = newLoops[loop.index()];
246   }
247   rewriter.replaceOp(op, tilingResult.loops.front().getResults());
248   return tilingResult;
249 }
250