xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 67bcf9825ad8e41c93c4db06e98f62ee66b4e2cc)
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Interfaces/TilingInterface.h"
24 #include "llvm/Support/Debug.h"
25 
26 #define DEBUG_TYPE "tile-using-interface"
27 
28 using namespace mlir;
29 
30 scf::SCFTilingOptions &
31 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
32   assert(!tileSizeComputationFunction && "tile sizes already set");
33   SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
34   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
35     OpBuilder::InsertionGuard guard(b);
36     b.setInsertionPointToStart(
37         &op->getParentOfType<func::FuncOp>().getBody().front());
38     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
39       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
40       return v;
41     }));
42   };
43   return *this;
44 }
45 
46 /// Helper method to adjust the interchange vector to match the iteration
47 /// domain.
48 static SmallVector<unsigned>
49 fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
50                       size_t iterationDomainSize) {
51   SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector);
52   if (filledVector.size() < iterationDomainSize) {
53     auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize);
54     filledVector.append(range.begin(), range.end());
55   }
56   if (filledVector.size() > iterationDomainSize)
57     filledVector.resize(iterationDomainSize);
58   return filledVector;
59 }
60 
61 /// Helper method to apply permutation to a vector
62 template <typename T>
63 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
64                                                ArrayRef<unsigned> interchange) {
65   assert(interchange.size() == vector.size());
66   return llvm::to_vector(
67       llvm::map_range(interchange, [&](unsigned val) { return vector[val]; }));
68 }
69 /// Helper method to apply to invert a permutation.
70 static SmallVector<unsigned>
71 invertPermutationVector(ArrayRef<unsigned> interchange) {
72   SmallVector<unsigned> inversion(interchange.size());
73   for (const auto &pos : llvm::enumerate(interchange)) {
74     inversion[pos.value()] = pos.index();
75   }
76   return inversion;
77 }
78 /// Method to check if an interchange vector is a permutation.
79 static bool isPermutation(ArrayRef<unsigned> interchange) {
80   llvm::SmallDenseSet<unsigned, 4> seenVals;
81   for (auto val : interchange) {
82     if (seenVals.count(val))
83       return false;
84     seenVals.insert(val);
85   }
86   return seenVals.size() == interchange.size();
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // tileUsingSCFForOp implementation.
91 //===----------------------------------------------------------------------===//
92 
93 // Check if `stride` evenly divides the trip count `size - offset`.
94 static bool tileDividesIterationDomain(Range loopRange) {
95   Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
96   if (!offsetAsInt)
97     return false;
98   Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
99   if (!sizeAsInt)
100     return false;
101   Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
102   if (!strideAsInt)
103     return false;
104   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
105 }
106 
107 /// Generate an empty loop nest that represents the tiled loop nest shell.
108 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
109 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
110 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
111 /// the
112 ///   tile processed within the inner most loop.
113 static SmallVector<scf::ForOp>
114 generateTileLoopNest(OpBuilder &builder, Location loc,
115                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
116                      SmallVector<OpFoldResult> &offsets,
117                      SmallVector<OpFoldResult> &sizes) {
118   assert(!loopRanges.empty() && "expected at least one loop range");
119   assert(loopRanges.size() == tileSizeVals.size() &&
120          "expected as many tile sizes as loop ranges");
121   OpBuilder::InsertionGuard guard(builder);
122   SmallVector<scf::ForOp> loops;
123   offsets.resize(loopRanges.size());
124   sizes.resize(loopRanges.size());
125 
126   // The tile size to use (to avoid out of bounds access) is  minimum of
127   // `tileSize` and `ub - iv`, where `iv` is the induction variable
128   // of the tiled loop.
129   AffineExpr s0, s1, d0;
130   bindDims(builder.getContext(), d0);
131   bindSymbols(builder.getContext(), s0, s1);
132   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
133 
134   for (auto loopRange : llvm::enumerate(loopRanges)) {
135     Value offset =
136         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
137     Value size =
138         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
139     // No loops if tile size is zero. Set offset and size to the loop
140     // offset and size.
141     if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
142       offsets[loopRange.index()] = offset;
143       sizes[loopRange.index()] = size;
144       continue;
145     }
146 
147     auto loop = builder.create<scf::ForOp>(
148         loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{},
149         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
150             ValueRange /*iterArgs*/) {
151           bool canAvoidMap = tileDividesIterationDomain(
152               Range{loopRange.value().offset, loopRange.value().size,
153                     tileSizeVals[loopRange.index()]});
154           Value boundedTileSize =
155               (canAvoidMap)
156                   ? tileSizeVals[loopRange.index()]
157                   : builder.create<AffineMinOp>(
158                         bodyLoc, minMap,
159                         ValueRange{iv, tileSizeVals[loopRange.index()], size});
160           sizes[loopRange.index()] = boundedTileSize;
161           builder.create<scf::YieldOp>(loc);
162         });
163     offsets[loopRange.index()] = loop.getInductionVar();
164     loops.push_back(loop);
165     builder.setInsertionPoint(loop.getBody()->getTerminator());
166   }
167   return loops;
168 }
169 
170 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
171 /// construct the destructive update pattern that inserts the yielded
172 /// value into a destination tensor provided by `initValue` at offset
173 /// `tileOffsets` and size `tileSizes`. For example,
174 ///
175 /// ```mlir
176 /// scf.for %iv0 = ... {
177 ///   %0 = tiled_op
178 /// }
179 /// ```
180 ///
181 /// is transformed to
182 ///
183 /// ```mlir
184 /// scf.for %iv0 = ... iter_args(%arg = %0) {
185 ///   %1 = tensor.extract_slice %arg
186 ///   %2 = tiled_op
187 ///   %3 = tensor.insert_slice %2 into %arg
188 ///   scf.yield %3
189 /// }
190 /// ```
191 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
192 static FailureOr<SmallVector<Value>>
193 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
194                  ValueRange yieldedValues,
195                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
196                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
197                  MutableArrayRef<scf::ForOp> loops) {
198   NewYieldValueFn yieldValueFn =
199       [&](OpBuilder &b, Location loc,
200           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
201     SmallVector<Value> inserts;
202     for (auto yieldedValue : llvm::enumerate(yieldedValues)) {
203       ArrayRef<OpFoldResult> tileOffsets =
204           tileOffsetsList[yieldedValue.index()];
205       ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
206       SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
207                                             b.getIndexAttr(1));
208       Value insert = b.create<tensor::InsertSliceOp>(
209           loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
210           tileOffsets, tileSizes, tileStrides);
211       inserts.push_back(insert);
212     }
213     return inserts;
214   };
215 
216   SmallVector<scf::ForOp> newLoops =
217       replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
218                                    /*replaceIterOperandsUsesInLoop =*/false);
219   for (const auto &loop : llvm::enumerate(loops)) {
220     rewriter.eraseOp(loop.value());
221     loops[loop.index()] = newLoops[loop.index()];
222   }
223   return llvm::to_vector(llvm::map_range(
224       loops.front().getResults().take_back(yieldedValues.size()),
225       [](OpResult r) -> Value { return r; }));
226 }
227 
228 /// If the tiled operation is destination passing style, update the
229 /// slice of the destination used (which refers to the untiled destination)
230 /// to use the corresponding region argument of the innermost loop.
231 ///
232 /// ```mlir
233 /// %0 =
234 /// scf.for %iv0 = ... iter_args(%arg = %0) {
235 ///   %1 = tensor.extract_slice %0
236 ///   %2 = tiled_op
237 ///   %3 = tensor.insert_slice %2 into %arg
238 ///   scf.yield %3
239 /// }
240 /// ```
241 ///
242 /// is transformed to
243 ///
244 /// ```mlir
245 /// scf.for %iv0 = ... iter_args(%arg = %0) {
246 ///   %1 = tensor.extract_slice %arg
247 ///   %2 = tiled_op
248 ///   %3 = tensor.insert_slice %2 into %arg
249 ///   scf.yield %3
250 /// }
251 /// ```
252 static void
253 updateDestinationOperandsForTiledOp(OpBuilder &builder,
254                                     ValueRange tiledOpDestinationValues,
255                                     ValueRange bbArgsList) {
256   for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
257     auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
258     if (!sliceOp)
259       continue;
260     sliceOp.setOperand(0, bbArgsList[destValue.index()]);
261   }
262 }
263 
264 /// Implementation of tiling transformation of `op` that implements the
265 /// `TilingInterface` using `scf.for` to iterate over the tiles.
266 FailureOr<scf::SCFTilingResult>
267 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
268                              scf::SCFTilingOptions options) {
269   OpBuilder::InsertionGuard guard(rewriter);
270   rewriter.setInsertionPointAfter(op);
271 
272   if (!options.tileSizeComputationFunction) {
273     return rewriter.notifyMatchFailure(
274         op, "missing tile size computation function");
275   }
276 
277   // 1. Get the range of the loops that are represented by the operation.
278   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
279   size_t numLoops = iterationDomain.size();
280   if (numLoops == 0) {
281     return rewriter.notifyMatchFailure(
282         op, "unable to tile op with no iteration domain");
283   }
284 
285   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
286   // skips tiling a particular dimension. This convention is significantly
287   // simpler to handle instead of adjusting affine maps to account for missing
288   // dimensions.
289   SmallVector<Value> tileSizeVector =
290       options.tileSizeComputationFunction(rewriter, op);
291   if (tileSizeVector.size() < iterationDomain.size()) {
292     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
293     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
294   }
295 
296   scf::SCFTilingResult tilingResult;
297   SmallVector<OpFoldResult> offsets, sizes;
298   {
299     // If there is an interchange specified, permute the iteration domain and
300     // the tile sizes.
301     SmallVector<unsigned> interchangeVector;
302     if (!options.interchangeVector.empty()) {
303       interchangeVector = fillInterchangeVector(options.interchangeVector,
304                                                 iterationDomain.size());
305     }
306     if (!interchangeVector.empty()) {
307       if (!isPermutation(interchangeVector)) {
308         return rewriter.notifyMatchFailure(
309             op, "invalid intechange vector, not a permutation of the entire "
310                 "iteration space");
311       }
312 
313       iterationDomain =
314           applyPermutationToVector(iterationDomain, interchangeVector);
315       tileSizeVector =
316           applyPermutationToVector(tileSizeVector, interchangeVector);
317     }
318 
319     // 3. Materialize an empty loop nest that iterates over the tiles. These
320     // loops for now do not return any values even if the original operation has
321     // results.
322     tilingResult.loops = generateTileLoopNest(
323         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
324 
325     if (!interchangeVector.empty()) {
326       auto inversePermutation = invertPermutationVector(interchangeVector);
327       offsets = applyPermutationToVector(offsets, inversePermutation);
328       sizes = applyPermutationToVector(sizes, inversePermutation);
329     }
330   }
331 
332   LLVM_DEBUG({
333     if (!tilingResult.loops.empty()) {
334       llvm::dbgs() << "LoopNest shell :\n";
335       tilingResult.loops.front().dump();
336       llvm::dbgs() << "\n";
337     }
338   });
339 
340   // 4. Generate the tiled implementation within the inner most loop.
341   if (!tilingResult.loops.empty())
342     rewriter.setInsertionPoint(
343         tilingResult.loops.back().getBody()->getTerminator());
344   SmallVector<Operation *> tiledImplementation =
345       op.getTiledImplementation(rewriter, offsets, sizes);
346   if (tiledImplementation.size() != 1) {
347     return rewriter.notifyMatchFailure(
348         op, "expected tiled implementation to return a single op");
349   }
350   tilingResult.tiledOp = tiledImplementation[0];
351   if (op->getNumResults() == 0) {
352     // nothing more to do.
353     return tilingResult;
354   }
355 
356   // If loops are empty, the tiled op is used as the replacement for the untiled
357   // op.
358   if (tilingResult.loops.empty()) {
359     tilingResult.replacements = llvm::to_vector(
360         llvm::map_range(tiledImplementation[0]->getResults(),
361                         [](OpResult result) -> Value { return result; }));
362     return tilingResult;
363   }
364 
365   // 5. Yield all the results of the tiled operation. The surrounding loop
366   //    nest is modified to insert a destructive update pattern to yield
367   //    from the loop nest values to replace the untiled op with.
368   unsigned numResults = op->getNumResults();
369   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
370       resultSizesList(numResults);
371   for (auto result : llvm::enumerate(op->getResults())) {
372     if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
373                                         sizes,
374                                         resultOffsetsList[result.index()],
375                                         resultSizesList[result.index()]))) {
376       return rewriter.notifyMatchFailure(
377           op, "failed to get slice of result produced");
378     }
379   }
380 
381   FailureOr<SmallVector<Value>> replacementOr =
382       yieldTiledValues(rewriter, op.getDestinationOperands(rewriter),
383                        tilingResult.tiledOp->getResults(), resultOffsetsList,
384                        resultSizesList, tilingResult.loops);
385   if (failed(replacementOr))
386     return rewriter.notifyMatchFailure(op, "failed to yield replacement");
387   if (auto tiledInterfaceOp = dyn_cast<TilingInterface>(tilingResult.tiledOp)) {
388     auto innerMostLoop = tilingResult.loops.back();
389     updateDestinationOperandsForTiledOp(
390         rewriter, tiledInterfaceOp.getDestinationOperands(rewriter),
391         innerMostLoop.getRegionIterArgs());
392   }
393 
394   tilingResult.replacements = replacementOr.value();
395 
396   LLVM_DEBUG({
397     if (!tilingResult.loops.empty()) {
398       llvm::dbgs() << "After tiled implementation :\n";
399       tilingResult.loops.front().dump();
400       llvm::dbgs() << "\n";
401     }
402   });
403   return tilingResult;
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
408 //===----------------------------------------------------------------------===//
409 
410 /// Return the untiled producer whose slice is used in a tiled consumer. The
411 /// method traverses the tile loop nest (`loops`) if needed, and returns the
412 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
413 /// indicates that this is a destination operand of the consumer. If there was
414 /// no loop traversal needed, the second value of the returned tuple is empty.
415 static std::tuple<OpResult, Optional<OpOperand *>>
416 getUntiledProducerFromSliceSource(OpOperand *source,
417                                   ArrayRef<scf::ForOp> loops) {
418   Optional<OpOperand *> destinationIterArg;
419   auto loopIt = loops.rbegin();
420   while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
421     scf::ForOp loop = *loopIt;
422     if (iterArg.getOwner()->getParentOp() != loop)
423       break;
424     source = &loop.getOpOperandForRegionIterArg(iterArg);
425     loopIt++;
426   }
427   if (loopIt == loops.rend())
428     destinationIterArg = source;
429   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
430 }
431 
432 /// Implementation of tile consumer and fuse producer greedily.
433 FailureOr<scf::SCFTileAndFuseResult>
434 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
435     RewriterBase &rewriter, TilingInterface consumer,
436     scf::SCFTileAndFuseOptions options) {
437   // This transformation is only valid for ops that return values (i.e. not
438   // valid to use with operations that have memref operands).
439   if (!consumer->getNumResults()) {
440     return rewriter.notifyMatchFailure(
441         consumer, "invalid pattern for op with no results");
442   }
443 
444   // 1. First tile the consumer.
445   scf::SCFTileAndFuseResult tileAndFuseResult;
446   llvm::SmallDenseMap<Value, unsigned> yieldedValueToResultNumber;
447   {
448     FailureOr<scf::SCFTilingResult> tilingResult =
449         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
450     if (failed(tilingResult))
451       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
452     tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
453     tileAndFuseResult.loops = std::move(tilingResult->loops);
454     for (auto result : llvm::enumerate(
455              llvm::zip(consumer->getResults(), tilingResult->replacements))) {
456       tileAndFuseResult.replacements[std::get<0>(result.value())] =
457           std::get<1>(result.value());
458       yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
459           result.index())] = result.index();
460     }
461   }
462 
463   // If there are no loops generated, fusion is immaterial.
464   if (tileAndFuseResult.loops.empty())
465     return tileAndFuseResult;
466 
467   // 2. Typically, the operands of the tiled operation are slices of the
468   //    operands of the untiled operation. These are expressed in IR using
469   //    `tensor.extract_slice` operations with source being the operands of the
470   //    untiled operation. Create a worklist of these `tensor.extract_slice`
471   //    operations. If the producers of the source of the `tensor.extract_slice`
472   //    can be tiled such that the tiled value is generated in-place, that
473   //    effectively tiles + fuses the operations.
474   auto addCandidateSlices = [](Operation *fusedOp,
475                                std::deque<tensor::ExtractSliceOp> &candidates) {
476     for (Value operand : fusedOp->getOperands())
477       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
478         candidates.push_back(sliceOp);
479   };
480 
481   std::deque<tensor::ExtractSliceOp> candidates;
482   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
483   OpBuilder::InsertionGuard g(rewriter);
484   while (!candidates.empty()) {
485     // 2a. Traverse the slices in BFS fashion.
486     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
487     candidates.pop_front();
488 
489     // 2b. Get the producer of the source (potentially walking through
490     // `iter_args` of nested `scf.for`)
491     auto [fusableProducer, destinationIterArg] =
492         getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
493                                           tileAndFuseResult.loops);
494     if (!fusableProducer)
495       continue;
496 
497     // 2c. Generate the tiled implementation of the producer of the source
498     rewriter.setInsertionPoint(candidateSliceOp);
499     FailureOr<Value> fusedProducerValue =
500         tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
501                                                      fusableProducer);
502     if (failed(fusedProducerValue))
503       continue;
504     rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
505 
506     // 2d. The operands of the fused producer might themselved be slices of
507     //     values produced by operations that implement the `TilingInterface`.
508     //     Add these operations to the worklist.
509     Operation *fusedProducer = fusedProducerValue->getDefiningOp();
510     tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
511     addCandidateSlices(fusedProducer, candidates);
512 
513     // 2e. If the slice is for a destination operand, for example,
514     //
515     // ```mlir
516     // %0 = linalg.init
517     // %1 = linalg.fill .. outs(%0 : )
518     // %2 = scf.for .. iter_args(%arg0 = %1) {
519     //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
520     //     %4 = tensor.extract_slice %arg1 [..]
521     //     .. = linalg.matmul .. outs(%4 : )
522     //   }
523     // }
524     // ```
525     //
526     // the IR is currently
527     //
528     // ```
529     // %0 = linalg.init
530     // %1 = linalg.fill
531     // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
532     //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
533     //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
534     //     %5 = linalg.fill .. outs(%4 : )
535     //     .. = linalg.matmul .. outs(%5 : )
536     //   }
537     // }
538     // ```
539     //
540     // The untiled `linalg.fill` is still used as the `init_value` since it
541     // was originally a destination operand of the untiled `linalg.matmul`.
542     // When fusing an operand that is a destination operand.
543     //   - Update the iter_arg of the outer most loop to use the destination
544     //     of the untiled producer.
545     //   - Update the destination of the slice of the tiled producer generated
546     //     to use the same basic block argument as the slice that was used to
547     //     generate inplace the tiled implementation of the producer.
548     // With this the IR will be.
549     //
550     // ```
551     // %0 = linalg.init
552     // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
553     //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
554     //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
555     //     %4 = linalg.fill .. outs(%3 : )
556     //     .. = linalg.matmul .. outs(%4 : )
557     //   }
558     // }
559     // ```
560     // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
561     // Update to use that when it does become available.
562     scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
563     Optional<unsigned> iterArgNumber;
564     if (destinationIterArg) {
565       iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
566           *destinationIterArg.value());
567     }
568     if (iterArgNumber) {
569       unsigned resultNumber = fusableProducer.getResultNumber();
570       if (auto producerOp =
571               dyn_cast<TilingInterface>(fusableProducer.getOwner())) {
572         SmallVector<Value> destination =
573             producerOp.getDestinationOperands(rewriter);
574         outerMostLoop.setIterArg(iterArgNumber.value(),
575                                  destination[resultNumber]);
576       }
577       if (auto tiledAndFusedInterfaceOp =
578               fusedProducerValue.value().getDefiningOp<TilingInterface>()) {
579         scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
580         SmallVector<Value> destination =
581             tiledAndFusedInterfaceOp.getDestinationOperands(rewriter);
582         updateDestinationOperandsForTiledOp(
583             rewriter, destination[resultNumber],
584             innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
585       }
586     }
587   }
588   return tileAndFuseResult;
589 }
590 
591 //===----------------------------------------------------------------------===//
592 // lowerToLoopsUsingSCFForOp implementation.
593 //===----------------------------------------------------------------------===//
594 
595 FailureOr<SmallVector<scf::ForOp>>
596 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
597                                      TilingInterface op) {
598   // TODO: Handle cases where the op has results if needed.
599   if (op->getNumResults() > 0) {
600     return rewriter.notifyMatchFailure(
601         op, "unable to lower to loops operations with return values");
602   }
603 
604   SmallVector<Range> domain = op.getIterationDomain(rewriter);
605   SmallVector<Value> ivs;
606   SmallVector<scf::ForOp> loops;
607   Location loc = op.getLoc();
608   for (auto loopRange : domain) {
609     Value offsetVal =
610         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
611     Value sizeVal =
612         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
613     Value strideVal =
614         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
615     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
616                                             strideVal, ValueRange{});
617     loops.push_back(loop);
618     ivs.push_back(loop.getInductionVar());
619     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
620   }
621   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
622     return failure();
623   }
624   return loops;
625 }
626