xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 9e6585030533e901a8c24dcb05b38d3f0d10331f)
//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/SCF/Utils/Utils.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Interfaces/TilingInterface.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "tile-using-interface"
26 
27 using namespace mlir;
28 
29 scf::SCFTilingOptions &
30 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
31   assert(!tileSizeComputationFunction && "tile sizes already set");
32   SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
33   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
34     OpBuilder::InsertionGuard guard(b);
35     b.setInsertionPointToStart(
36         &op->getParentOfType<func::FuncOp>().getBody().front());
37     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
38       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
39       return v;
40     }));
41   };
42   return *this;
43 }
44 
45 /// Helper method to adjust the interchange vector to match the iteration
46 /// domain.
47 static SmallVector<unsigned>
48 fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
49                       size_t iterationDomainSize) {
50   SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector);
51   if (filledVector.size() < iterationDomainSize) {
52     auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize);
53     filledVector.append(range.begin(), range.end());
54   }
55   if (filledVector.size() > iterationDomainSize)
56     filledVector.resize(iterationDomainSize);
57   return filledVector;
58 }
59 
60 /// Helper method to apply permutation to a vector
61 template <typename T>
62 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
63                                                ArrayRef<unsigned> interchange) {
64   assert(interchange.size() == vector.size());
65   return llvm::to_vector(
66       llvm::map_range(interchange, [&](unsigned val) { return vector[val]; }));
67 }
68 /// Helper method to apply to invert a permutation.
69 static SmallVector<unsigned>
70 invertPermutationVector(ArrayRef<unsigned> interchange) {
71   SmallVector<unsigned> inversion(interchange.size());
72   for (auto pos : llvm::enumerate(interchange)) {
73     inversion[pos.value()] = pos.index();
74   }
75   return inversion;
76 }
77 /// Method to check if an interchange vector is a permutation.
78 static bool isPermutation(ArrayRef<unsigned> interchange) {
79   llvm::SmallDenseSet<unsigned, 4> seenVals;
80   for (auto val : interchange) {
81     if (seenVals.count(val))
82       return false;
83     seenVals.insert(val);
84   }
85   return seenVals.size() == interchange.size();
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // TileUsingSCFForOp pattern implementation.
90 //===----------------------------------------------------------------------===//
91 
92 /// Generate an empty loop nest that represents the tiled loop nest shell.
93 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
94 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
95 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
96 /// the
97 ///   tile processed within the inner most loop.
98 static SmallVector<scf::ForOp>
99 generateTileLoopNest(OpBuilder &builder, Location loc,
100                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
101                      SmallVector<OpFoldResult> &offsets,
102                      SmallVector<OpFoldResult> &sizes) {
103   assert(!loopRanges.empty() && "expected at least one loop range");
104   assert(loopRanges.size() == tileSizeVals.size() &&
105          "expected as many tile sizes as loop ranges");
106   OpBuilder::InsertionGuard guard(builder);
107   SmallVector<scf::ForOp> loops;
108   offsets.resize(loopRanges.size());
109   sizes.resize(loopRanges.size());
110 
111   // The tile size to use (to avoid out of bounds access) is  minimum of
112   // `tileSize` and `ub - iv`, where `iv` is the induction variable
113   // of the tiled loop.
114   AffineExpr s0, s1, d0;
115   bindDims(builder.getContext(), d0);
116   bindSymbols(builder.getContext(), s0, s1);
117   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
118 
119   for (auto loopRange : llvm::enumerate(loopRanges)) {
120     // No loops if tile size is zero. Set offset and size to the loop
121     // offset and size.
122     if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
123       offsets[loopRange.index()] = loopRange.value().offset;
124       sizes[loopRange.index()] = loopRange.value().size;
125       continue;
126     }
127 
128     auto loop = builder.create<scf::ForOp>(
129         loc, loopRange.value().offset, loopRange.value().size,
130         tileSizeVals[loopRange.index()], ValueRange{},
131         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
132             ValueRange /*iterArgs*/) {
133           Value boundedTileSize = builder.create<AffineMinOp>(
134               bodyLoc, minMap,
135               ValueRange{iv, tileSizeVals[loopRange.index()],
136                          loopRange.value().size});
137           sizes[loopRange.index()] = boundedTileSize;
138           builder.create<scf::YieldOp>(loc);
139         });
140     offsets[loopRange.index()] = loop.getInductionVar();
141     loops.push_back(loop);
142     builder.setInsertionPoint(loop.getBody()->getTerminator());
143   }
144   return loops;
145 }
146 
147 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
148                                           scf::SCFTilingOptions options,
149                                           PatternBenefit benefit)
150     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
151       options(std::move(options)) {}
152 
153 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
154                                           MLIRContext *context,
155                                           scf::SCFTilingOptions options,
156                                           PatternBenefit benefit)
157     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
158       options(std::move(options)) {}
159 
160 FailureOr<scf::SCFTilingResult>
161 scf::TileUsingSCFForOp::returningMatchAndRewrite(
162     TilingInterface op, PatternRewriter &rewriter) const {
163   OpBuilder::InsertionGuard guard(rewriter);
164   rewriter.setInsertionPointAfter(op);
165 
166   if (!options.tileSizeComputationFunction) {
167     return rewriter.notifyMatchFailure(
168         op, "missing tile size computation function");
169   }
170 
171   // 1. Get the range of the loops that are represented by the operation.
172   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
173   size_t numLoops = iterationDomain.size();
174   if (numLoops == 0) {
175     return rewriter.notifyMatchFailure(
176         op, "unable to tile op with no iteration domain");
177   }
178 
179   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
180   // skips tiling a particular dimension. This convention is significantly
181   // simpler to handle instead of adjusting affine maps to account for missing
182   // dimensions.
183   SmallVector<Value> tileSizeVector =
184       options.tileSizeComputationFunction(rewriter, op);
185   if (tileSizeVector.size() < iterationDomain.size()) {
186     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
187     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
188   }
189 
190   scf::SCFTilingResult tilingResult;
191   SmallVector<OpFoldResult> offsets, sizes;
192   {
193     // If there is an interchange specified, permute the iteration domain and
194     // the tile sizes.
195     SmallVector<unsigned> interchangeVector;
196     if (!options.interchangeVector.empty()) {
197       interchangeVector = fillInterchangeVector(options.interchangeVector,
198                                                 iterationDomain.size());
199     }
200     if (!interchangeVector.empty()) {
201       if (!isPermutation(interchangeVector)) {
202         return rewriter.notifyMatchFailure(
203             op, "invalid intechange vector, not a permutation of the entire "
204                 "iteration space");
205       }
206 
207       iterationDomain =
208           applyPermutationToVector(iterationDomain, interchangeVector);
209       tileSizeVector =
210           applyPermutationToVector(tileSizeVector, interchangeVector);
211     }
212 
213     // 3. Materialize an empty loop nest that iterates over the tiles. These
214     // loops for now do not return any values even if the original operation has
215     // results.
216     tilingResult.loops = generateTileLoopNest(
217         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
218 
219     if (!interchangeVector.empty()) {
220       auto inversePermutation = invertPermutationVector(interchangeVector);
221       offsets = applyPermutationToVector(offsets, inversePermutation);
222       sizes = applyPermutationToVector(sizes, inversePermutation);
223     }
224 
225     LLVM_DEBUG({
226       if (!tilingResult.loops.empty()) {
227         llvm::errs() << "LoopNest shell :\n";
228         tilingResult.loops.front().dump();
229         llvm::errs() << "\n";
230       }
231     });
232 
233     // 4. Generate the tiled implementation within the inner most loop.
234     if (!tilingResult.loops.empty())
235       rewriter.setInsertionPoint(
236           tilingResult.loops.back().getBody()->getTerminator());
237     SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
238         rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
239     if (tiledImplementation.size() != 1) {
240       return rewriter.notifyMatchFailure(
241           op, "expected tiled implementation to return a single op");
242     }
243     tilingResult.tiledOp = tiledImplementation[0];
244 
245     LLVM_DEBUG({
246       if (!tilingResult.loops.empty()) {
247         llvm::errs() << "After tiled implementation :\n";
248         tilingResult.loops.front().dump();
249         llvm::errs() << "\n";
250       }
251     });
252   }
253 
254   if (op->getNumResults() == 0) {
255     rewriter.eraseOp(op);
256     return tilingResult;
257   }
258 
259   // 5. If the original operations has results, modify the loop nest to yield
260   // the replacement values.
261   SmallVector<Value> replacements;
262   if (tilingResult.loops.empty()) {
263     // 5a. If there were no loops, the tiled implementation results are the
264     // replacements.
265     rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
266     return tilingResult;
267   }
268 
269   // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
270   // replacement values using destructive updates. Use the `TilingInterface`
271   // to get the position of the result tiles and use that to generate the
272   // destructive update pattern, i.e.,
273   //
274   // ```mlir
275   // scf.for %iv0 = ... {
276   //   %0 = tiled_op
277   // }
278   // ```
279   //
280   // is transformed to
281   //
282   // ```mlir
283   // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
284   //   %0 = tiled_op
285   //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
286   //   scf.yield %1
287   // }
288   // ```
289   NewYieldValueFn yieldValueFn =
290       [&](OpBuilder &b, Location loc,
291           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
292     SmallVector<Value> yieldedValues;
293     Attribute one = b.getIndexAttr(1);
294     for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
295       SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
296       if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
297                                           resultTileOffsets,
298                                           resultTileSizes))) {
299         op.emitOpError("unable to get position of result ")
300             << resultNum << " of the tiled implementation";
301         return {};
302       }
303       SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
304                                                   one);
305       Value yieldedValue = b.create<tensor::InsertSliceOp>(
306           op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
307           newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
308           resultTileStrides);
309       yieldedValues.push_back(yieldedValue);
310     }
311     return yieldedValues;
312   };
313   SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
314       rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
315       yieldValueFn);
316   for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
317     rewriter.eraseOp(loop.value());
318     tilingResult.loops[loop.index()] = newLoops[loop.index()];
319   }
320   rewriter.replaceOp(op, tilingResult.loops.front().getResults());
321   return tilingResult;
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
326 //===----------------------------------------------------------------------===//
327 
328 scf::TileConsumerAndFuseProducersUsingSCFForOp::
329     TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
330                                               scf::SCFTilingOptions options,
331                                               PatternBenefit benefit)
332     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
333       tilingPattern(context, std::move(options)) {}
334 
335 scf::TileConsumerAndFuseProducersUsingSCFForOp::
336     TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
337                                               MLIRContext *context,
338                                               scf::SCFTilingOptions options,
339                                               PatternBenefit benefit)
340     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
341       tilingPattern(context, std::move(options)) {}
342 
343 /// Return the `Value` that is defined by an operation that implements
344 /// the `TilingInterface`. Looks through `iter_args` of scf.for nest
345 /// if required.
346 static Optional<OpResult> getFusableProducer(Value v) {
347   while (auto blockArg = v.dyn_cast<BlockArgument>()) {
348     auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
349     if (!loopOp)
350       return llvm::None;
351     v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
352   }
353   if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
354     return llvm::None;
355   return v.cast<OpResult>();
356 }
357 
358 // Replace iter args of the outer most loop with region args of the inner most
359 // one.
360 static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
361                             PatternRewriter &rewriter) {
362   assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
363          "expect same number of iter args");
364   Block *block = &(*innerFor.getRegion().begin());
365   for (auto it :
366        llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
367     Value source = std::get<0>(it);
368     Value target = std::get<1>(it);
369     source.replaceUsesWithIf(target, [&](OpOperand &use) {
370       return use.getOwner()->getBlock() == block;
371     });
372   }
373 }
374 
375 FailureOr<scf::SCFTileAndFuseResult>
376 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
377     TilingInterface op, PatternRewriter &rewriter) const {
378   // This transformation is only valid for ops that return values (i.e. not
379   // valid to use with operations that have memref operands).
380   if (!op->getNumResults()) {
381     return rewriter.notifyMatchFailure(
382         op, "invalid pattern for op with no results");
383   }
384 
385   // 1. First tile the consumer.
386   SCFTileAndFuseResult tileAndFuseResult;
387   {
388     FailureOr<SCFTilingResult> tilingResult =
389         tilingPattern.returningMatchAndRewrite(op, rewriter);
390     if (failed(tilingResult)) {
391       return failure();
392     }
393     tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
394     tileAndFuseResult.loops = std::move(tilingResult->loops);
395   }
396 
397   // 2. Typically, the operands of the tiled operation are slices of the
398   //    operands of the untiled operation. These are expressed in IR using
399   //    `tensor.extract_slice` operations with source being the operands of the
400   //    untiled operation. Create a worklist of these `tensor.extract_slice`
401   //    operations. If the producers of the source of the `tensor.extract_slice`
402   //    can be tiled such that the tiled value is generated in-place, that
403   //    effectively tiles + fuses the operations.
404   auto addCandidateSlices = [](Operation *fusedOp,
405                                std::deque<tensor::ExtractSliceOp> &candidates) {
406     for (Value operand : fusedOp->getOperands())
407       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
408         candidates.push_back(sliceOp);
409   };
410 
411   std::deque<tensor::ExtractSliceOp> candidates;
412   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
413   OpBuilder::InsertionGuard g(rewriter);
414   while (!candidates.empty()) {
415     // 2a. Traverse the slices in BFS fashion.
416     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
417     candidates.pop_front();
418 
419     // 2b. Get the producer of the source (potentially walking through
420     // `iter_args` of nested `scf.for`)
421     Optional<OpResult> fusableProducer =
422         getFusableProducer(candidateSliceOp.getSource());
423     if (!fusableProducer)
424       continue;
425 
426     // 2c. Generate the tiled implementation of the producer of the source
427     rewriter.setInsertionPoint(candidateSliceOp);
428     FailureOr<Value> fusedProducerValue =
429         tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
430                                                      fusableProducer.value());
431     if (failed(fusedProducerValue))
432       continue;
433     rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
434 
435     // 2d. The operands of the fused producer might themselved be slices of
436     //     values produced by operations that implement the `TilingInterface`.
437     //     Add these operations to the worklist.
438     Operation *fusedProducer = fusedProducerValue->getDefiningOp();
439     tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
440     addCandidateSlices(fusedProducer, candidates);
441 
442     // 2e. If the operation being fused creates a value that is used as `outs`
443     //     in the tiled operation, the result of the unfused operation will be
444     //     used in the `iter_args` of the tiled loop generated. When the
445     //     operation is fused, this use in `iter_args` needs to be modified to
446     //     use the destination of the fused operation. For example, starting
447     //     with
448     //
449     //     ```mlir
450     //     %0 = linalg.init_tensor ...
451     //     %1 = linalg.fill ... outs(%0:...)...
452     //     %2 = linalg.matmul ... outs(%1:...)....
453     //     ```
454     //
455     //     First the `linalg.matmul` gets tiled
456     //
457     //     ```mlir
458     //     %0 = linalg.init_tensor
459     //     %1 = linalg.fill
460     //     %2 = scf.for .... iter_args(%arg0 = %1)...
461     //        ...
462     //        ... = linalg.matmul ...
463     //
464     //     ```
465     //
466     //     When the `linalg.fill` gets fused, the `iter_args` needs to be
467     //     modified
468     //
469     //     ```mlir
470     //     %0 = linalg.init_tensor
471     //     %1 = scf.for ... iter_args(%arg0 = %0)...
472     //        ...
473     //        %2 = linalg.fill ...
474     //        %3 = linalg.matmul ... outs(%2: ...)...
475     //     ```
476     TilingInterface unfusedProducerOp =
477         cast<TilingInterface>(fusableProducer->getOwner());
478     scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
479     SmallVector<Value> unfusedProducerOpDestValues =
480         unfusedProducerOp.getDestinationOperands(rewriter);
481     for (OpOperand &uses : unfusedProducerOp->getUses()) {
482       if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
483         unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
484         unsigned operandNumber = uses.getOperandNumber();
485         outerMostTiledLoop->setOperand(
486             operandNumber, unfusedProducerOpDestValues[resultNumber]);
487       }
488     }
489   }
490   replaceIterArgs(tileAndFuseResult.loops.front(),
491                   tileAndFuseResult.loops.back(), rewriter);
492   return tileAndFuseResult;
493 }
494