xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision ce349ff1a483e5d08bdd394dddea63b549e646fd)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
25 #include "mlir/Interfaces/TilingInterface.h"
26 #include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "tile-using-interface"
29 
30 using namespace mlir;
31 
32 scf::SCFTilingOptions &
33 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
34   assert(!tileSizeComputationFunction && "tile sizes already set");
35   SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
36   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
37     OpBuilder::InsertionGuard guard(b);
38     b.setInsertionPointToStart(
39         &op->getParentOfType<func::FuncOp>().getBody().front());
40     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
41       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
42       return v;
43     }));
44   };
45   return *this;
46 }
47 
48 /// Helper method to adjust the interchange vector to match the iteration
49 /// domain.
50 static SmallVector<int64_t>
51 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
52                       size_t iterationDomainSize) {
53   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
54   if (filledVector.size() < iterationDomainSize) {
55     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
56     filledVector.append(range.begin(), range.end());
57   }
58   if (filledVector.size() > iterationDomainSize)
59     filledVector.resize(iterationDomainSize);
60   return filledVector;
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // tileUsingSCFForOp implementation.
65 //===----------------------------------------------------------------------===//
66 
67 // Check if `stride` evenly divides the trip count `size - offset`.
68 static bool tileDividesIterationDomain(Range loopRange) {
69   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
70   if (!offsetAsInt)
71     return false;
72   std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
73   if (!sizeAsInt)
74     return false;
75   std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
76   if (!strideAsInt)
77     return false;
78   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
79 }
80 
81 /// Returns the bounded tile size given the current `iv`, `loopRange` and
82 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
83 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
84                                        Range loopRange, Value iv,
85                                        Value tileSize) {
86   std::optional<int64_t> ts = getConstantIntValue(tileSize);
87   if (ts && ts.value() == 1)
88     return getAsOpFoldResult(tileSize);
89 
90   if (tileDividesIterationDomain(
91           Range{loopRange.offset, loopRange.size, tileSize}))
92     return tileSize;
93 
94   // The tile size to use (to avoid out of bounds access) is  minimum of
95   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
96   // loop.
97   AffineExpr s0, s1, d0;
98   bindDims(b.getContext(), d0);
99   bindSymbols(b.getContext(), s0, s1);
100   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
101   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
102   return makeComposedFoldedAffineMin(
103       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
104 }
105 
106 /// Generate an empty loop nest that represents the tiled loop nest shell.
107 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
108 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
109 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
110 /// the
111 ///   tile processed within the inner most loop.
112 static SmallVector<scf::ForOp>
113 generateTileLoopNest(OpBuilder &builder, Location loc,
114                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
115                      SmallVector<OpFoldResult> &offsets,
116                      SmallVector<OpFoldResult> &sizes) {
117   assert(!loopRanges.empty() && "expected at least one loop range");
118   assert(loopRanges.size() == tileSizeVals.size() &&
119          "expected as many tile sizes as loop ranges");
120   OpBuilder::InsertionGuard guard(builder);
121   SmallVector<scf::ForOp> loops;
122   offsets.resize(loopRanges.size());
123   sizes.resize(loopRanges.size());
124 
125   for (auto loopRange : llvm::enumerate(loopRanges)) {
126     Value offset =
127         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
128     Value size =
129         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
130     Value tileSize = tileSizeVals[loopRange.index()];
131     // No loops if tile size is zero. Set offset and size to the loop
132     // offset and size.
133     if (matchPattern(tileSize, m_Zero())) {
134       offsets[loopRange.index()] = offset;
135       sizes[loopRange.index()] = size;
136       continue;
137     }
138 
139     auto loop = builder.create<scf::ForOp>(
140         loc, offset, size, tileSize, ValueRange{},
141         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
142             ValueRange /*iterArgs*/) {
143           sizes[loopRange.index()] = getBoundedTileSize(
144               bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
145           builder.create<scf::YieldOp>(loc);
146         });
147     offsets[loopRange.index()] = loop.getInductionVar();
148     loops.push_back(loop);
149     builder.setInsertionPoint(loop.getBody()->getTerminator());
150   }
151   return loops;
152 }
153 
154 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
155 /// construct the destructive update pattern that inserts the yielded
156 /// value into a destination tensor provided by `initValue` at offset
157 /// `tileOffsets` and size `tileSizes`. For example,
158 ///
159 /// ```mlir
160 /// scf.for %iv0 = ... {
161 ///   %0 = tiled_op
162 /// }
163 /// ```
164 ///
165 /// is transformed to
166 ///
167 /// ```mlir
168 /// scf.for %iv0 = ... iter_args(%arg = %0) {
169 ///   %1 = tensor.extract_slice %arg
170 ///   %2 = tiled_op
171 ///   %3 = tensor.insert_slice %2 into %arg
172 ///   scf.yield %3
173 /// }
174 /// ```
175 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
176 static SmallVector<Value>
177 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
178                  ValueRange yieldedValues,
179                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
180                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
181                  MutableArrayRef<scf::ForOp> loops) {
182   NewYieldValueFn yieldValueFn =
183       [&](OpBuilder &b, Location loc,
184           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
185     SmallVector<Value> inserts;
186     for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
187       ArrayRef<OpFoldResult> tileOffsets =
188           tileOffsetsList[yieldedValue.index()];
189       ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
190       SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
191                                             b.getIndexAttr(1));
192       Value insert = b.create<tensor::InsertSliceOp>(
193           loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
194           tileOffsets, tileSizes, tileStrides);
195       inserts.push_back(insert);
196     }
197     return inserts;
198   };
199 
200   SmallVector<scf::ForOp> newLoops =
201       replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
202                                    /*replaceIterOperandsUsesInLoop =*/false);
203   for (const auto &loop : llvm::enumerate(loops)) {
204     rewriter.eraseOp(loop.value());
205     loops[loop.index()] = newLoops[loop.index()];
206   }
207   return llvm::to_vector(llvm::map_range(
208       loops.front().getResults().take_back(yieldedValues.size()),
209       [](OpResult r) -> Value { return r; }));
210 }
211 
212 /// If the tiled operation is destination passing style, update the
213 /// slice of the destination used (which refers to the untiled destination)
214 /// to use the corresponding region argument of the innermost loop.
215 ///
216 /// ```mlir
217 /// %0 =
218 /// scf.for %iv0 = ... iter_args(%arg = %0) {
219 ///   %1 = tensor.extract_slice %0
220 ///   %2 = tiled_op
221 ///   %3 = tensor.insert_slice %2 into %arg
222 ///   scf.yield %3
223 /// }
224 /// ```
225 ///
226 /// is transformed to
227 ///
228 /// ```mlir
229 /// scf.for %iv0 = ... iter_args(%arg = %0) {
230 ///   %1 = tensor.extract_slice %arg
231 ///   %2 = tiled_op
232 ///   %3 = tensor.insert_slice %2 into %arg
233 ///   scf.yield %3
234 /// }
235 /// ```
236 static void
237 updateDestinationOperandsForTiledOp(OpBuilder &builder,
238                                     ValueRange tiledOpDestinationValues,
239                                     ValueRange bbArgsList) {
240   for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
241     auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
242     if (!sliceOp)
243       continue;
244     sliceOp.setOperand(0, bbArgsList[destValue.index()]);
245   }
246 }
247 
248 /// Helper method to yield the values of the tiled op, as well as
249 /// update the destination operands of the tiled op, if it is
250 /// a destination passing style op.
251 static SmallVector<Value>
252 yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
253                  Operation *tiledOp,
254                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
255                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
256                  MutableArrayRef<scf::ForOp> loops) {
257   SmallVector<Value> replacements =
258       yieldTiledValues(rewriter, initValues, tiledOp->getResults(),
259                        tileOffsetsList, tileSizesList, loops);
260   if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
261     auto innerMostLoop = loops.back();
262     SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
263     updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
264                                         innerMostLoop.getRegionIterArgs());
265   }
266   return replacements;
267 }
268 
269 /// Implementation of tiling transformation of `op` that implements the
270 /// `TilingInterface` using `scf.for` to iterate over the tiles.
271 FailureOr<scf::SCFTilingResult>
272 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
273                              const scf::SCFTilingOptions &options) {
274   OpBuilder::InsertionGuard guard(rewriter);
275   rewriter.setInsertionPointAfter(op);
276 
277   if (!options.tileSizeComputationFunction) {
278     return rewriter.notifyMatchFailure(
279         op, "missing tile size computation function");
280   }
281 
282   // 1. Get the range of the loops that are represented by the operation.
283   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
284   size_t numLoops = iterationDomain.size();
285   if (numLoops == 0) {
286     return rewriter.notifyMatchFailure(
287         op, "unable to tile op with no iteration domain");
288   }
289 
290   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
291   // skips tiling a particular dimension. This convention is significantly
292   // simpler to handle instead of adjusting affine maps to account for missing
293   // dimensions.
294   SmallVector<Value> tileSizeVector =
295       options.tileSizeComputationFunction(rewriter, op);
296   if (tileSizeVector.size() < iterationDomain.size()) {
297     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
298     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
299   }
300 
301   scf::SCFTilingResult tilingResult;
302   SmallVector<OpFoldResult> offsets, sizes;
303   {
304     // If there is an interchange specified, permute the iteration domain and
305     // the tile sizes.
306     SmallVector<int64_t> interchangeVector;
307     if (!options.interchangeVector.empty()) {
308       interchangeVector = fillInterchangeVector(options.interchangeVector,
309                                                 iterationDomain.size());
310     }
311     if (!interchangeVector.empty()) {
312       if (!isPermutationVector(interchangeVector)) {
313         return rewriter.notifyMatchFailure(
314             op, "invalid intechange vector, not a permutation of the entire "
315                 "iteration space");
316       }
317 
318       applyPermutationToVector(iterationDomain, interchangeVector);
319       applyPermutationToVector(tileSizeVector, interchangeVector);
320     }
321 
322     // 3. Materialize an empty loop nest that iterates over the tiles. These
323     // loops for now do not return any values even if the original operation has
324     // results.
325     tilingResult.loops = generateTileLoopNest(
326         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
327 
328     if (!interchangeVector.empty()) {
329       auto inversePermutation = invertPermutationVector(interchangeVector);
330       applyPermutationToVector(offsets, inversePermutation);
331       applyPermutationToVector(sizes, inversePermutation);
332     }
333   }
334 
335   LLVM_DEBUG({
336     if (!tilingResult.loops.empty()) {
337       llvm::dbgs() << "LoopNest shell :\n";
338       tilingResult.loops.front().dump();
339       llvm::dbgs() << "\n";
340     }
341   });
342 
343   // 4. Generate the tiled implementation within the inner most loop.
344   if (!tilingResult.loops.empty())
345     rewriter.setInsertionPoint(
346         tilingResult.loops.back().getBody()->getTerminator());
347   SmallVector<Operation *> tiledImplementation =
348       op.getTiledImplementation(rewriter, offsets, sizes);
349   tilingResult.tiledOps.append(tiledImplementation);
350   if (op->getNumResults() == 0) {
351     // nothing more to do.
352     return tilingResult;
353   }
354 
355   // If loops are empty, the tiled op is used as the replacement for the untiled
356   // op.
357   if (tilingResult.loops.empty()) {
358     tilingResult.replacements = llvm::to_vector(
359         llvm::map_range(tiledImplementation[0]->getResults(),
360                         [](OpResult result) -> Value { return result; }));
361     return tilingResult;
362   }
363 
364   // 5. Yield all the results of the tiled operation. The surrounding loop
365   //    nest is modified to insert a destructive update pattern to yield
366   //    from the loop nest values to replace the untiled op with.
367   int64_t numResults = op->getNumResults();
368   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
369       resultSizesList(numResults);
370   for (const auto &result : llvm::enumerate(op->getResults())) {
371     if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
372                                         sizes,
373                                         resultOffsetsList[result.index()],
374                                         resultSizesList[result.index()]))) {
375       return rewriter.notifyMatchFailure(
376           op, "failed to get slice of result produced");
377     }
378   }
379 
380   SmallVector<Value> destinationTensors;
381   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
382                                              destinationTensors)))
383     return rewriter.notifyMatchFailure(op, "failed to get destinations");
384 
385   tilingResult.replacements = yieldTiledValues(
386       rewriter, destinationTensors, tilingResult.tiledOps.back(),
387       resultOffsetsList, resultSizesList, tilingResult.loops);
388 
389   LLVM_DEBUG({
390     if (!tilingResult.loops.empty()) {
391       llvm::dbgs() << "After tiled implementation :\n";
392       tilingResult.loops.front().dump();
393       llvm::dbgs() << "\n";
394     }
395   });
396   return tilingResult;
397 }
398 
399 FailureOr<scf::SCFReductionTilingResult>
400 mlir::scf::tileReductionUsingScf(PatternRewriter &b,
401                                  PartialReductionOpInterface op,
402                                  ArrayRef<OpFoldResult> tileSize) {
403   Location loc = op.getLoc();
404   // Ops implementing PartialReductionOpInterface are expected to implement
405   // TilingInterface.
406   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
407   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
408   SmallVector<Value> tileSizeVector =
409       getValueOrCreateConstantIndexOp(b, loc, tileSize);
410   if (tileSizeVector.size() < iterationDomain.size()) {
411     auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
412     tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
413   }
414   if (op->getNumResults() != 1)
415     return b.notifyMatchFailure(
416         op, "don't support ops with multiple results for now");
417   SmallVector<utils::IteratorType> iterators =
418       tilingInterfaceOp.getLoopIteratorTypes();
419   int64_t numReductionDims = llvm::count(
420       tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
421   if (numReductionDims != 1)
422     return b.notifyMatchFailure(
423         op, "only support ops with one reduction dimension.");
424   int reductionDim;
425   for (auto &[idx, iteratorType] :
426        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
427     if (iteratorType == utils::IteratorType::reduction) {
428       reductionDim = idx;
429       break;
430     }
431   }
432   if (static_cast<size_t>(reductionDim) >= tileSize.size())
433     return b.notifyMatchFailure(op, "reduction dimension must be tiled");
434 
435   // 1. create the inital tensor value.
436   FailureOr<Operation *> identityTensor =
437       op.generateInitialTensorForPartialReduction(b, loc, tileSize,
438                                                   reductionDim);
439   if (failed(identityTensor))
440     return b.notifyMatchFailure(op,
441                                 "cannot create a tensor of identity value.");
442   // 2. Create the nested loops.
443   SmallVector<OpFoldResult> offsets, sizes;
444   SmallVector<scf::ForOp> loops = generateTileLoopNest(
445       b, loc, iterationDomain, tileSizeVector, offsets, sizes);
446 
447   // 3. Generate the tiled implementation within the inner most loop.
448   b.setInsertionPoint(loops.back().getBody()->getTerminator());
449   Operation *parallelOp = op.tileToPartialReduction(
450       b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim);
451 
452   SmallVector<OpFoldResult> resultSizesList;
453   for (size_t i = 0; i < offsets.size(); i++)
454     resultSizesList.push_back(
455         b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
456   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
457   SmallVector<Value> replacements = yieldTiledValues(
458       b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
459       resultSizesList, loops);
460 
461   auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
462   auto innerMostLoop = loops.back();
463   SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
464   assert(destinationTensors.size() ==
465              innerMostLoop.getRegionIterArgs().size() &&
466          "unexpected number of outputs");
467   updateDestinationOperandsForTiledOp(b, destinationTensors,
468                                       innerMostLoop.getRegionIterArgs());
469 
470   // 4. Apply the merge reduction to combine all the partial values.
471   b.setInsertionPointAfter(*loops.begin());
472   Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
473   b.replaceOp(op, mergeOp->getResults());
474 
475   SCFReductionTilingResult results;
476   results.initialOp = *identityTensor;
477   results.loops = std::move(loops);
478   results.parallelTiledOp = parallelOp;
479   results.mergeOp = mergeOp;
480   return results;
481 }
482 //===----------------------------------------------------------------------===//
483 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
484 //===----------------------------------------------------------------------===//
485 
486 /// Return the untiled producer whose slice is used in a tiled consumer. The
487 /// method traverses the tile loop nest (`loops`) if needed, and returns the
488 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
489 /// indicates that this is a destination operand of the consumer. If there was
490 /// no loop traversal needed, the second value of the returned tuple is empty.
491 static std::tuple<OpResult, std::optional<OpOperand *>>
492 getUntiledProducerFromSliceSource(OpOperand *source,
493                                   ArrayRef<scf::ForOp> loops) {
494   std::optional<OpOperand *> destinationIterArg;
495   auto loopIt = loops.rbegin();
496   while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
497     scf::ForOp loop = *loopIt;
498     if (iterArg.getOwner()->getParentOp() != loop)
499       break;
500     source = &loop.getOpOperandForRegionIterArg(iterArg);
501     loopIt++;
502   }
503   if (loopIt == loops.rend())
504     destinationIterArg = source;
505   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
506 }
507 
508 static std::optional<Operation *>
509 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
510                            tensor::ExtractSliceOp candidateSliceOp,
511                            MutableArrayRef<scf::ForOp> loops) {
512   // 1. Get the producer of the source (potentially walking through
513   // `iter_args` of nested `scf.for`)
514   auto [fusableProducer, destinationIterArg] =
515       getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
516                                         loops);
517   if (!fusableProducer)
518     return std::nullopt;
519 
520   // 2. Generate the tiled implementation of the producer of the source
521   OpBuilder::InsertionGuard g(rewriter);
522   rewriter.setInsertionPoint(candidateSliceOp);
523   FailureOr<Value> fusedProducerValue =
524       tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
525                                                    fusableProducer);
526   if (failed(fusedProducerValue))
527     return std::nullopt;
528   rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
529 
530   // 3. If the slice is for a destination operand, for example,
531   //
532   // ```mlir
533   // %0 = linalg.init
534   // %1 = linalg.fill .. outs(%0 : )
535   // %2 = scf.for .. iter_args(%arg0 = %1) {
536   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
537   //     %4 = tensor.extract_slice %arg1 [..]
538   //     .. = linalg.matmul .. outs(%4 : )
539   //   }
540   // }
541   // ```
542   //
543   // the IR is currently
544   //
545   // ```
546   // %0 = linalg.init
547   // %1 = linalg.fill
548   // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
549   //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
550   //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
551   //     %5 = linalg.fill .. outs(%4 : )
552   //     .. = linalg.matmul .. outs(%5 : )
553   //   }
554   // }
555   // ```
556   //
557   // The untiled `linalg.fill` is still used as the `init_value` since it
558   // was originally a destination operand of the untiled `linalg.matmul`.
559   // When fusing an operand that is a destination operand.
560   //   - Update the iter_arg of the outer most loop to use the destination
561   //     of the untiled producer.
562   //   - Update the destination of the slice of the tiled producer generated
563   //     to use the same basic block argument as the slice that was used to
564   //     generate inplace the tiled implementation of the producer.
565   // With this the IR will be.
566   //
567   // ```
568   // %0 = linalg.init
569   // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
570   //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
571   //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
572   //     %4 = linalg.fill .. outs(%3 : )
573   //     .. = linalg.matmul .. outs(%4 : )
574   //   }
575   // }
576   // ```
577   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
578   // Update to use that when it does become available.
579   scf::ForOp outerMostLoop = loops.front();
580   Optional<unsigned> iterArgNumber;
581   if (destinationIterArg) {
582     iterArgNumber =
583         outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
584   }
585   if (iterArgNumber) {
586     int64_t resultNumber = fusableProducer.getResultNumber();
587     if (auto dstOp =
588             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
589       outerMostLoop.setIterArg(iterArgNumber.value(),
590                                dstOp.getTiedOpOperand(fusableProducer)->get());
591     }
592     if (auto dstOp = fusedProducerValue.value()
593                          .getDefiningOp<DestinationStyleOpInterface>()) {
594       scf::ForOp innerMostLoop = loops.back();
595       updateDestinationOperandsForTiledOp(
596           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
597           innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
598     }
599   }
600   return fusedProducerValue->getDefiningOp();
601 }
602 
603 /// Implementation of tile consumer and fuse producer greedily.
604 FailureOr<scf::SCFTileAndFuseResult>
605 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
606     RewriterBase &rewriter, TilingInterface consumer,
607     const scf::SCFTileAndFuseOptions &options) {
608   // This transformation is only valid for ops that return values (i.e. not
609   // valid to use with operations that have memref operands).
610   if (!consumer->getNumResults()) {
611     return rewriter.notifyMatchFailure(
612         consumer, "invalid pattern for op with no results");
613   }
614 
615   // 1. First tile the consumer.
616   scf::SCFTileAndFuseResult tileAndFuseResult;
617   llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
618   {
619     FailureOr<scf::SCFTilingResult> tilingResult =
620         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
621     if (failed(tilingResult))
622       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
623     for (auto *tiledOp : tilingResult->tiledOps)
624       tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
625     tileAndFuseResult.loops = std::move(tilingResult->loops);
626     for (const auto &result : llvm::enumerate(
627              llvm::zip(consumer->getResults(), tilingResult->replacements))) {
628       tileAndFuseResult.replacements[std::get<0>(result.value())] =
629           std::get<1>(result.value());
630       yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
631           result.index())] = result.index();
632     }
633   }
634 
635   // If there are no loops generated, fusion is immaterial.
636   if (tileAndFuseResult.loops.empty())
637     return tileAndFuseResult;
638 
639   // 2. Typically, the operands of the tiled operation are slices of the
640   //    operands of the untiled operation. These are expressed in IR using
641   //    `tensor.extract_slice` operations with source being the operands of the
642   //    untiled operation. Create a worklist of these `tensor.extract_slice`
643   //    operations. If the producers of the source of the `tensor.extract_slice`
644   //    can be tiled such that the tiled value is generated in-place, that
645   //    effectively tiles + fuses the operations.
646   auto addCandidateSlices = [](Operation *fusedOp,
647                                std::deque<tensor::ExtractSliceOp> &candidates) {
648     for (Value operand : fusedOp->getOperands())
649       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
650         candidates.push_back(sliceOp);
651   };
652 
653   std::deque<tensor::ExtractSliceOp> candidates;
654   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
655   OpBuilder::InsertionGuard g(rewriter);
656   while (!candidates.empty()) {
657     // Traverse the slices in BFS fashion.
658     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
659     candidates.pop_front();
660 
661     // The operands of the fused producer might themselved be slices of
662     // values produced by operations that implement the `TilingInterface`.
663     // Add these operations to the worklist.
664     Optional<Operation *> fusedProducer = tileAndFuseProducerOfSlice(
665         rewriter, candidateSliceOp, tileAndFuseResult.loops);
666     if (!fusedProducer)
667       continue;
668 
669     tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer.value());
670     addCandidateSlices(fusedProducer.value(), candidates);
671   }
672   return tileAndFuseResult;
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // lowerToLoopsUsingSCFForOp implementation.
677 //===----------------------------------------------------------------------===//
678 
679 FailureOr<SmallVector<scf::ForOp>>
680 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
681                                      TilingInterface op) {
682   // TODO: Handle cases where the op has results if needed.
683   if (op->getNumResults() > 0) {
684     return rewriter.notifyMatchFailure(
685         op, "unable to lower to loops operations with return values");
686   }
687 
688   SmallVector<Range> domain = op.getIterationDomain(rewriter);
689   SmallVector<Value> ivs;
690   SmallVector<scf::ForOp> loops;
691   Location loc = op.getLoc();
692   for (auto loopRange : domain) {
693     Value offsetVal =
694         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
695     Value sizeVal =
696         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
697     Value strideVal =
698         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
699     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
700                                             strideVal, ValueRange{});
701     loops.push_back(loop);
702     ivs.push_back(loop.getInductionVar());
703     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
704   }
705   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
706     return failure();
707   }
708   return loops;
709 }
710