xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision ea75511319d9dff8c38c8794c3949c40b63a38d7)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/SCF/Utils/Utils.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Interfaces/TilingInterface.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "tile-using-interface"
26 
27 using namespace mlir;
28 
29 scf::SCFTilingOptions &
30 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
31   assert(!tileSizeComputationFunction && "tile sizes already set");
32   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
33   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
34     OpBuilder::InsertionGuard guard(b);
35     b.setInsertionPointToStart(
36         &op->getParentOfType<func::FuncOp>().getBody().front());
37     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
38       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
39       return v;
40     }));
41   };
42   return *this;
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // TileUsingSCFForOp pattern implementation.
47 //===----------------------------------------------------------------------===//
48 
49 /// Generate an empty loop nest that represents the tiled loop nest shell.
50 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
51 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
52 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
53 /// the
54 ///   tile processed within the inner most loop.
55 static SmallVector<scf::ForOp>
56 generateTileLoopNest(OpBuilder &builder, Location loc,
57                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
58                      SmallVector<OpFoldResult> &offsets,
59                      SmallVector<OpFoldResult> &sizes) {
60   assert(!loopRanges.empty() && "expected at least one loop range");
61   assert(loopRanges.size() == tileSizeVals.size() &&
62          "expected as many tile sizes as loop ranges");
63   OpBuilder::InsertionGuard guard(builder);
64   SmallVector<scf::ForOp> loops;
65   offsets.resize(loopRanges.size());
66   sizes.resize(loopRanges.size());
67 
68   // The tile size to use (to avoid out of bounds access) is  minimum of
69   // `tileSize` and `ub - iv`, where `iv` is the induction variable
70   // of the tiled loop.
71   AffineExpr s0, s1, d0;
72   bindDims(builder.getContext(), d0);
73   bindSymbols(builder.getContext(), s0, s1);
74   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
75 
76   for (auto loopRange : llvm::enumerate(loopRanges)) {
77     // No loops if tile size is zero. Set offset and size to the loop
78     // offset and size.
79     if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
80       offsets[loopRange.index()] = loopRange.value().offset;
81       sizes[loopRange.index()] = loopRange.value().size;
82       continue;
83     }
84 
85     auto loop = builder.create<scf::ForOp>(
86         loc, loopRange.value().offset, loopRange.value().size,
87         tileSizeVals[loopRange.index()], ValueRange{},
88         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
89             ValueRange /*iterArgs*/) {
90           Value boundedTileSize = builder.create<AffineMinOp>(
91               bodyLoc, minMap,
92               ValueRange{iv, tileSizeVals[loopRange.index()],
93                          loopRange.value().size});
94           sizes[loopRange.index()] = boundedTileSize;
95           builder.create<scf::YieldOp>(loc);
96         });
97     offsets[loopRange.index()] = loop.getInductionVar();
98     loops.push_back(loop);
99     builder.setInsertionPoint(loop.getBody()->getTerminator());
100   }
101   return loops;
102 }
103 
104 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
105                                           scf::SCFTilingOptions options,
106                                           PatternBenefit benefit)
107     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
108       options(std::move(options)) {}
109 
110 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
111                                           MLIRContext *context,
112                                           scf::SCFTilingOptions options,
113                                           PatternBenefit benefit)
114     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
115       options(std::move(options)) {}
116 
117 FailureOr<scf::SCFTilingResult>
118 scf::TileUsingSCFForOp::returningMatchAndRewrite(
119     TilingInterface op, PatternRewriter &rewriter) const {
120   OpBuilder::InsertionGuard guard(rewriter);
121   rewriter.setInsertionPointAfter(op);
122 
123   if (!options.tileSizeComputationFunction) {
124     return rewriter.notifyMatchFailure(
125         op, "missing tile size computation function");
126   }
127 
128   // 1. Get the range of the loops that are represented by the operation.
129   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
130   size_t numLoops = iterationDomain.size();
131   if (numLoops == 0) {
132     return rewriter.notifyMatchFailure(
133         op, "unable to tile op with no iteration domain");
134   }
135 
136   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
137   // skips tiling a particular dimension. This convention is significantly
138   // simpler to handle instead of adjusting affine maps to account for missing
139   // dimensions.
140   SmallVector<Value, 4> tileSizeVector =
141       options.tileSizeComputationFunction(rewriter, op);
142   if (tileSizeVector.size() < iterationDomain.size()) {
143     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
144     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
145   }
146 
147   scf::SCFTilingResult tilingResult;
148   SmallVector<OpFoldResult> offsets, sizes;
149   {
150     // 3. Materialize an empty loop nest that iterates over the tiles. These
151     // loops for now do not return any values even if the original operation has
152     // results.
153     tilingResult.loops = generateTileLoopNest(
154         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
155 
156     LLVM_DEBUG({
157       if (!tilingResult.loops.empty()) {
158         llvm::errs() << "LoopNest shell :\n";
159         tilingResult.loops.front().dump();
160         llvm::errs() << "\n";
161       }
162     });
163 
164     // 4. Generate the tiled implementation within the inner most loop.
165     if (!tilingResult.loops.empty())
166       rewriter.setInsertionPoint(
167           tilingResult.loops.back().getBody()->getTerminator());
168     SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
169         rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
170     if (tiledImplementation.size() != 1) {
171       return rewriter.notifyMatchFailure(
172           op, "expected tiled implementation to return a single op");
173     }
174     tilingResult.tiledOp = tiledImplementation[0];
175 
176     LLVM_DEBUG({
177       if (!tilingResult.loops.empty()) {
178         llvm::errs() << "After tiled implementation :\n";
179         tilingResult.loops.front().dump();
180         llvm::errs() << "\n";
181       }
182     });
183   }
184 
185   if (op->getNumResults() == 0) {
186     rewriter.eraseOp(op);
187     return tilingResult;
188   }
189 
190   // 5. If the original operations has results, modify the loop nest to yield
191   // the replacement values.
192   SmallVector<Value> replacements;
193   if (tilingResult.loops.empty()) {
194     // 5a. If there were no loops, the tiled implementation results are the
195     // replacements.
196     rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
197     return tilingResult;
198   }
199 
200   // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
201   // replacement values using destructive updates. Use the `TilingInterface`
202   // to get the position of the result tiles and use that to generate the
203   // destructive update pattern, i.e.,
204   //
205   // ```mlir
206   // scf.for %iv0 = ... {
207   //   %0 = tiled_op
208   // }
209   // ```
210   //
211   // is transformed to
212   //
213   // ```mlir
214   // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
215   //   %0 = tiled_op
216   //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
217   //   scf.yield %1
218   // }
219   // ```
220   NewYieldValueFn yieldValueFn =
221       [&](OpBuilder &b, Location loc,
222           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
223     SmallVector<Value> yieldedValues;
224     Attribute one = b.getIndexAttr(1);
225     for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
226       SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
227       if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
228                                           resultTileOffsets,
229                                           resultTileSizes))) {
230         op.emitOpError("unable to get position of result ")
231             << resultNum << " of the tiled implementation";
232         return {};
233       }
234       SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
235                                                   one);
236       Value yieldedValue = b.create<tensor::InsertSliceOp>(
237           op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
238           newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
239           resultTileStrides);
240       yieldedValues.push_back(yieldedValue);
241     }
242     return yieldedValues;
243   };
244   SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
245       rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
246       yieldValueFn);
247   for (auto loop : llvm::enumerate(tilingResult.loops)) {
248     rewriter.eraseOp(loop.value());
249     tilingResult.loops[loop.index()] = newLoops[loop.index()];
250   }
251   rewriter.replaceOp(op, tilingResult.loops.front().getResults());
252   return tilingResult;
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
257 //===----------------------------------------------------------------------===//
258 
259 scf::TileConsumerAndFuseProducersUsingSCFForOp::
260     TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
261                                               scf::SCFTilingOptions options,
262                                               PatternBenefit benefit)
263     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
264       tilingPattern(context, std::move(options)) {}
265 
266 scf::TileConsumerAndFuseProducersUsingSCFForOp::
267     TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
268                                               MLIRContext *context,
269                                               scf::SCFTilingOptions options,
270                                               PatternBenefit benefit)
271     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
272       tilingPattern(context, std::move(options)) {}
273 
274 /// Return the `Value` that is defined by an operation that implements
275 /// the `TilingInterface`. Looks through `iter_args` of scf.for nest
276 /// if required.
277 static Optional<OpResult> getFusableProducer(Value v) {
278   while (auto blockArg = v.dyn_cast<BlockArgument>()) {
279     auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
280     if (!loopOp)
281       return llvm::None;
282     v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
283   }
284   if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
285     return llvm::None;
286   return v.cast<OpResult>();
287 }
288 
289 FailureOr<scf::SCFTileAndFuseResult>
290 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
291     TilingInterface op, PatternRewriter &rewriter) const {
292   // This transformation is only valid for ops that return values (i.e. not
293   // valid to use with operations that have memref operands).
294   if (!op->getNumResults()) {
295     return rewriter.notifyMatchFailure(
296         op, "invalid pattern for op with no results");
297   }
298 
299   // 1. First tile the consumer.
300   SCFTileAndFuseResult tileAndFuseResult;
301   {
302     FailureOr<SCFTilingResult> tilingResult =
303         tilingPattern.returningMatchAndRewrite(op, rewriter);
304     if (failed(tilingResult)) {
305       return failure();
306     }
307     tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
308     tileAndFuseResult.loops = std::move(tilingResult->loops);
309   }
310 
311   // 2. Typically, the operands of the tiled operation are slices of the
312   //    operands of the untiled operation. These are expressed in IR using
313   //    `tensor.extract_slice` operations with source being the operands of the
314   //    untiled operation. Create a worklist of these `tensor.extract_slice`
315   //    operations. If the producers of the source of the `tensor.extract_slice`
316   //    can be tiled such that the tiled value is generated in-place, that
317   //    effectively tiles + fuses the operations.
318   auto addCandidateSlices = [](Operation *fusedOp,
319                                std::deque<tensor::ExtractSliceOp> &candidates) {
320     for (Value operand : fusedOp->getOperands())
321       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
322         candidates.push_back(sliceOp);
323   };
324 
325   std::deque<tensor::ExtractSliceOp> candidates;
326   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
327   OpBuilder::InsertionGuard g(rewriter);
328   while (!candidates.empty()) {
329     // 2a. Traverse the slices in BFS fashion.
330     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
331     candidates.pop_front();
332 
333     // 2b. Get the producer of the source (potentially walking through
334     // `iter_args` of nested `scf.for`)
335     Optional<OpResult> fusableProducer =
336         getFusableProducer(candidateSliceOp.source());
337     if (!fusableProducer)
338       continue;
339 
340     // 2c. Generate the tiled implementation of the producer of the source
341     rewriter.setInsertionPoint(candidateSliceOp);
342     FailureOr<Value> fusedProducerValue =
343         tensor::replaceExtractSliceWithTiledProducer(
344             rewriter, candidateSliceOp, fusableProducer.getValue());
345     if (failed(fusedProducerValue))
346       continue;
347     rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue());
348 
349     // 2d. The operands of the fused producer might themselved be slices of
350     //     values produced by operations that implement the `TilingInterface`.
351     //     Add these operations to the worklist.
352     Operation *fusedProducer = fusedProducerValue->getDefiningOp();
353     tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
354     addCandidateSlices(fusedProducer, candidates);
355 
356     // 2e. If the operation being fused creates a value that is used as `outs`
357     //     in the tiled operation, the result of the unfused operation will be
358     //     used in the `iter_args` of the tiled loop generated. When the
359     //     operation is fused, this use in `iter_args` needs to be modified to
360     //     use the destination of the fused operation. For example, starting
361     //     with
362     //
363     //     ```mlir
364     //     %0 = linalg.init_tensor ...
365     //     %1 = linalg.fill ... outs(%0:...)...
366     //     %2 = linalg.matmul ... outs(%1:...)....
367     //     ```
368     //
369     //     First the `linalg.matmul` gets tiled
370     //
371     //     ```mlir
372     //     %0 = linalg.init_tensor
373     //     %1 = linalg.fill
374     //     %2 = scf.for .... iter_args(%arg0 = %1)...
375     //        ...
376     //        ... = linalg.matmul ...
377     //
378     //     ```
379     //
380     //     When the `linalg.fill` gets fused, the `iter_args` needs to be
381     //     modified
382     //
383     //     ```mlir
384     //     %0 = linalg.init_tensor
385     //     %1 = scf.for ... iter_args(%arg0 = %0)...
386     //        ...
387     //        %2 = linalg.fill ...
388     //        %3 = linalg.matmul ... outs(%2: ...)...
389     //     ```
390     TilingInterface unfusedProducerOp =
391         cast<TilingInterface>(fusableProducer->getOwner());
392     scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
393     SmallVector<Value> unfusedProducerOpDestValues =
394         unfusedProducerOp.getDestinationOperands(rewriter);
395     for (OpOperand &uses : unfusedProducerOp->getUses()) {
396       if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
397         unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
398         unsigned operandNumber = uses.getOperandNumber();
399         outerMostTiledLoop->setOperand(
400             operandNumber, unfusedProducerOpDestValues[resultNumber]);
401       }
402     }
403   }
404   return tileAndFuseResult;
405 }
406