xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (revision 71cf48a62a6d2930d4e782545764aeb725df12f7)
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
24 #include "mlir/Interfaces/TilingInterface.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "tile-using-interface"
28 
29 using namespace mlir;
30 
31 scf::SCFTilingOptions &
32 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
33   assert(!tileSizeComputationFunction && "tile sizes already set");
34   SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
35   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
36     OpBuilder::InsertionGuard guard(b);
37     b.setInsertionPointToStart(
38         &op->getParentOfType<func::FuncOp>().getBody().front());
39     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
40       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
41       return v;
42     }));
43   };
44   return *this;
45 }
46 
47 /// Helper method to adjust the interchange vector to match the iteration
48 /// domain.
49 static SmallVector<int64_t>
50 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
51                       size_t iterationDomainSize) {
52   SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
53   if (filledVector.size() < iterationDomainSize) {
54     auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
55     filledVector.append(range.begin(), range.end());
56   }
57   if (filledVector.size() > iterationDomainSize)
58     filledVector.resize(iterationDomainSize);
59   return filledVector;
60 }
61 
62 /// Helper method to apply permutation to a vector
63 template <typename T>
64 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
65                                                ArrayRef<int64_t> interchange) {
66   assert(interchange.size() == vector.size());
67   return llvm::to_vector(
68       llvm::map_range(interchange, [&](int64_t val) { return vector[val]; }));
69 }
70 /// Helper method to apply to invert a permutation.
71 static SmallVector<int64_t>
72 invertPermutationVector(ArrayRef<int64_t> interchange) {
73   SmallVector<int64_t> inversion(interchange.size());
74   for (const auto &pos : llvm::enumerate(interchange)) {
75     inversion[pos.value()] = pos.index();
76   }
77   return inversion;
78 }
79 /// Method to check if an interchange vector is a permutation.
80 static bool isPermutation(ArrayRef<int64_t> interchange) {
81   llvm::SmallDenseSet<int64_t, 4> seenVals;
82   for (auto val : interchange) {
83     if (seenVals.count(val))
84       return false;
85     seenVals.insert(val);
86   }
87   return seenVals.size() == interchange.size();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // tileUsingSCFForOp implementation.
92 //===----------------------------------------------------------------------===//
93 
94 // Check if `stride` evenly divides the trip count `size - offset`.
95 static bool tileDividesIterationDomain(Range loopRange) {
96   Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
97   if (!offsetAsInt)
98     return false;
99   Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
100   if (!sizeAsInt)
101     return false;
102   Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
103   if (!strideAsInt)
104     return false;
105   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
106 }
107 
108 /// Returns the bounded tile size given the current `iv`, `loopRange` and
109 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
110 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
111                                        Range loopRange, Value iv,
112                                        Value tileSize) {
113   Optional<int64_t> ts = getConstantIntValue(tileSize);
114   if (ts && ts.value() == 1)
115     return getAsOpFoldResult(tileSize);
116 
117   if (tileDividesIterationDomain(
118           Range{loopRange.offset, loopRange.size, tileSize}))
119     return tileSize;
120 
121   // The tile size to use (to avoid out of bounds access) is  minimum of
122   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
123   // loop.
124   AffineExpr s0, s1, d0;
125   bindDims(b.getContext(), d0);
126   bindSymbols(b.getContext(), s0, s1);
127   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
128   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
129   return b.create<AffineMinOp>(loc, minMap, ValueRange{iv, tileSize, size})
130       .getResult();
131 }
132 
133 /// Generate an empty loop nest that represents the tiled loop nest shell.
134 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
135 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
136 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
137 /// the
138 ///   tile processed within the inner most loop.
139 static SmallVector<scf::ForOp>
140 generateTileLoopNest(OpBuilder &builder, Location loc,
141                      ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
142                      SmallVector<OpFoldResult> &offsets,
143                      SmallVector<OpFoldResult> &sizes) {
144   assert(!loopRanges.empty() && "expected at least one loop range");
145   assert(loopRanges.size() == tileSizeVals.size() &&
146          "expected as many tile sizes as loop ranges");
147   OpBuilder::InsertionGuard guard(builder);
148   SmallVector<scf::ForOp> loops;
149   offsets.resize(loopRanges.size());
150   sizes.resize(loopRanges.size());
151 
152   for (auto loopRange : llvm::enumerate(loopRanges)) {
153     Value offset =
154         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
155     Value size =
156         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
157     Value tileSize = tileSizeVals[loopRange.index()];
158     // No loops if tile size is zero. Set offset and size to the loop
159     // offset and size.
160     if (matchPattern(tileSize, m_Zero())) {
161       offsets[loopRange.index()] = offset;
162       sizes[loopRange.index()] = size;
163       continue;
164     }
165 
166     auto loop = builder.create<scf::ForOp>(
167         loc, offset, size, tileSize, ValueRange{},
168         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
169             ValueRange /*iterArgs*/) {
170           sizes[loopRange.index()] = getBoundedTileSize(
171               bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
172           builder.create<scf::YieldOp>(loc);
173         });
174     offsets[loopRange.index()] = loop.getInductionVar();
175     loops.push_back(loop);
176     builder.setInsertionPoint(loop.getBody()->getTerminator());
177   }
178   return loops;
179 }
180 
181 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
182 /// construct the destructive update pattern that inserts the yielded
183 /// value into a destination tensor provided by `initValue` at offset
184 /// `tileOffsets` and size `tileSizes`. For example,
185 ///
186 /// ```mlir
187 /// scf.for %iv0 = ... {
188 ///   %0 = tiled_op
189 /// }
190 /// ```
191 ///
192 /// is transformed to
193 ///
194 /// ```mlir
195 /// scf.for %iv0 = ... iter_args(%arg = %0) {
196 ///   %1 = tensor.extract_slice %arg
197 ///   %2 = tiled_op
198 ///   %3 = tensor.insert_slice %2 into %arg
199 ///   scf.yield %3
200 /// }
201 /// ```
202 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
203 static FailureOr<SmallVector<Value>>
204 yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
205                  ValueRange yieldedValues,
206                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
207                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
208                  MutableArrayRef<scf::ForOp> loops) {
209   NewYieldValueFn yieldValueFn =
210       [&](OpBuilder &b, Location loc,
211           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
212     SmallVector<Value> inserts;
213     for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
214       ArrayRef<OpFoldResult> tileOffsets =
215           tileOffsetsList[yieldedValue.index()];
216       ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
217       SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
218                                             b.getIndexAttr(1));
219       Value insert = b.create<tensor::InsertSliceOp>(
220           loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
221           tileOffsets, tileSizes, tileStrides);
222       inserts.push_back(insert);
223     }
224     return inserts;
225   };
226 
227   SmallVector<scf::ForOp> newLoops =
228       replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
229                                    /*replaceIterOperandsUsesInLoop =*/false);
230   for (const auto &loop : llvm::enumerate(loops)) {
231     rewriter.eraseOp(loop.value());
232     loops[loop.index()] = newLoops[loop.index()];
233   }
234   return llvm::to_vector(llvm::map_range(
235       loops.front().getResults().take_back(yieldedValues.size()),
236       [](OpResult r) -> Value { return r; }));
237 }
238 
239 /// If the tiled operation is destination passing style, update the
240 /// slice of the destination used (which refers to the untiled destination)
241 /// to use the corresponding region argument of the innermost loop.
242 ///
243 /// ```mlir
244 /// %0 =
245 /// scf.for %iv0 = ... iter_args(%arg = %0) {
246 ///   %1 = tensor.extract_slice %0
247 ///   %2 = tiled_op
248 ///   %3 = tensor.insert_slice %2 into %arg
249 ///   scf.yield %3
250 /// }
251 /// ```
252 ///
253 /// is transformed to
254 ///
255 /// ```mlir
256 /// scf.for %iv0 = ... iter_args(%arg = %0) {
257 ///   %1 = tensor.extract_slice %arg
258 ///   %2 = tiled_op
259 ///   %3 = tensor.insert_slice %2 into %arg
260 ///   scf.yield %3
261 /// }
262 /// ```
263 static void
264 updateDestinationOperandsForTiledOp(OpBuilder &builder,
265                                     ValueRange tiledOpDestinationValues,
266                                     ValueRange bbArgsList) {
267   for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
268     auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
269     if (!sliceOp)
270       continue;
271     sliceOp.setOperand(0, bbArgsList[destValue.index()]);
272   }
273 }
274 
275 /// Implementation of tiling transformation of `op` that implements the
276 /// `TilingInterface` using `scf.for` to iterate over the tiles.
277 FailureOr<scf::SCFTilingResult>
278 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
279                              const scf::SCFTilingOptions &options) {
280   OpBuilder::InsertionGuard guard(rewriter);
281   rewriter.setInsertionPointAfter(op);
282 
283   if (!options.tileSizeComputationFunction) {
284     return rewriter.notifyMatchFailure(
285         op, "missing tile size computation function");
286   }
287 
288   // Get destination tensors.
289   SmallVector<Value> destinationTensors;
290   if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
291                                              destinationTensors)))
292     return rewriter.notifyMatchFailure(op, "failed to get destinations");
293 
294   // 1. Get the range of the loops that are represented by the operation.
295   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
296   size_t numLoops = iterationDomain.size();
297   if (numLoops == 0) {
298     return rewriter.notifyMatchFailure(
299         op, "unable to tile op with no iteration domain");
300   }
301 
302   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
303   // skips tiling a particular dimension. This convention is significantly
304   // simpler to handle instead of adjusting affine maps to account for missing
305   // dimensions.
306   SmallVector<Value> tileSizeVector =
307       options.tileSizeComputationFunction(rewriter, op);
308   if (tileSizeVector.size() < iterationDomain.size()) {
309     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
310     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
311   }
312 
313   scf::SCFTilingResult tilingResult;
314   SmallVector<OpFoldResult> offsets, sizes;
315   {
316     // If there is an interchange specified, permute the iteration domain and
317     // the tile sizes.
318     SmallVector<int64_t> interchangeVector;
319     if (!options.interchangeVector.empty()) {
320       interchangeVector = fillInterchangeVector(options.interchangeVector,
321                                                 iterationDomain.size());
322     }
323     if (!interchangeVector.empty()) {
324       if (!isPermutation(interchangeVector)) {
325         return rewriter.notifyMatchFailure(
326             op, "invalid intechange vector, not a permutation of the entire "
327                 "iteration space");
328       }
329 
330       iterationDomain =
331           applyPermutationToVector(iterationDomain, interchangeVector);
332       tileSizeVector =
333           applyPermutationToVector(tileSizeVector, interchangeVector);
334     }
335 
336     // 3. Materialize an empty loop nest that iterates over the tiles. These
337     // loops for now do not return any values even if the original operation has
338     // results.
339     tilingResult.loops = generateTileLoopNest(
340         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
341 
342     if (!interchangeVector.empty()) {
343       auto inversePermutation = invertPermutationVector(interchangeVector);
344       offsets = applyPermutationToVector(offsets, inversePermutation);
345       sizes = applyPermutationToVector(sizes, inversePermutation);
346     }
347   }
348 
349   LLVM_DEBUG({
350     if (!tilingResult.loops.empty()) {
351       llvm::dbgs() << "LoopNest shell :\n";
352       tilingResult.loops.front().dump();
353       llvm::dbgs() << "\n";
354     }
355   });
356 
357   // 4. Generate the tiled implementation within the inner most loop.
358   if (!tilingResult.loops.empty())
359     rewriter.setInsertionPoint(
360         tilingResult.loops.back().getBody()->getTerminator());
361   SmallVector<Operation *> tiledImplementation =
362       op.getTiledImplementation(rewriter, offsets, sizes);
363   if (tiledImplementation.size() != 1) {
364     return rewriter.notifyMatchFailure(
365         op, "expected tiled implementation to return a single op");
366   }
367   tilingResult.tiledOp = tiledImplementation[0];
368   if (op->getNumResults() == 0) {
369     // nothing more to do.
370     return tilingResult;
371   }
372 
373   // If loops are empty, the tiled op is used as the replacement for the untiled
374   // op.
375   if (tilingResult.loops.empty()) {
376     tilingResult.replacements = llvm::to_vector(
377         llvm::map_range(tiledImplementation[0]->getResults(),
378                         [](OpResult result) -> Value { return result; }));
379     return tilingResult;
380   }
381 
382   // 5. Yield all the results of the tiled operation. The surrounding loop
383   //    nest is modified to insert a destructive update pattern to yield
384   //    from the loop nest values to replace the untiled op with.
385   int64_t numResults = op->getNumResults();
386   SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
387       resultSizesList(numResults);
388   for (const auto &result : llvm::enumerate(op->getResults())) {
389     if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
390                                         sizes,
391                                         resultOffsetsList[result.index()],
392                                         resultSizesList[result.index()]))) {
393       return rewriter.notifyMatchFailure(
394           op, "failed to get slice of result produced");
395     }
396   }
397 
398   FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
399       rewriter, destinationTensors, tilingResult.tiledOp->getResults(),
400       resultOffsetsList, resultSizesList, tilingResult.loops);
401   if (failed(replacementOr))
402     return rewriter.notifyMatchFailure(op, "failed to yield replacement");
403 
404   if (auto dstOp =
405           dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) {
406     auto innerMostLoop = tilingResult.loops.back();
407     SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
408     assert(destinationTensors.size() ==
409                innerMostLoop.getRegionIterArgs().size() &&
410            "unexpected number of outputs");
411     updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
412                                         innerMostLoop.getRegionIterArgs());
413   }
414 
415   tilingResult.replacements = replacementOr.value();
416 
417   LLVM_DEBUG({
418     if (!tilingResult.loops.empty()) {
419       llvm::dbgs() << "After tiled implementation :\n";
420       tilingResult.loops.front().dump();
421       llvm::dbgs() << "\n";
422     }
423   });
424   return tilingResult;
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
429 //===----------------------------------------------------------------------===//
430 
431 /// Return the untiled producer whose slice is used in a tiled consumer. The
432 /// method traverses the tile loop nest (`loops`) if needed, and returns the
433 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
434 /// indicates that this is a destination operand of the consumer. If there was
435 /// no loop traversal needed, the second value of the returned tuple is empty.
436 static std::tuple<OpResult, Optional<OpOperand *>>
437 getUntiledProducerFromSliceSource(OpOperand *source,
438                                   ArrayRef<scf::ForOp> loops) {
439   Optional<OpOperand *> destinationIterArg;
440   auto loopIt = loops.rbegin();
441   while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
442     scf::ForOp loop = *loopIt;
443     if (iterArg.getOwner()->getParentOp() != loop)
444       break;
445     source = &loop.getOpOperandForRegionIterArg(iterArg);
446     loopIt++;
447   }
448   if (loopIt == loops.rend())
449     destinationIterArg = source;
450   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
451 }
452 
453 /// Implementation of tile consumer and fuse producer greedily.
454 FailureOr<scf::SCFTileAndFuseResult>
455 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
456     RewriterBase &rewriter, TilingInterface consumer,
457     const scf::SCFTileAndFuseOptions &options) {
458   // This transformation is only valid for ops that return values (i.e. not
459   // valid to use with operations that have memref operands).
460   if (!consumer->getNumResults()) {
461     return rewriter.notifyMatchFailure(
462         consumer, "invalid pattern for op with no results");
463   }
464 
465   // 1. First tile the consumer.
466   scf::SCFTileAndFuseResult tileAndFuseResult;
467   llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
468   {
469     FailureOr<scf::SCFTilingResult> tilingResult =
470         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
471     if (failed(tilingResult))
472       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
473     tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
474     tileAndFuseResult.loops = std::move(tilingResult->loops);
475     for (const auto &result : llvm::enumerate(
476              llvm::zip(consumer->getResults(), tilingResult->replacements))) {
477       tileAndFuseResult.replacements[std::get<0>(result.value())] =
478           std::get<1>(result.value());
479       yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
480           result.index())] = result.index();
481     }
482   }
483 
484   // If there are no loops generated, fusion is immaterial.
485   if (tileAndFuseResult.loops.empty())
486     return tileAndFuseResult;
487 
488   // 2. Typically, the operands of the tiled operation are slices of the
489   //    operands of the untiled operation. These are expressed in IR using
490   //    `tensor.extract_slice` operations with source being the operands of the
491   //    untiled operation. Create a worklist of these `tensor.extract_slice`
492   //    operations. If the producers of the source of the `tensor.extract_slice`
493   //    can be tiled such that the tiled value is generated in-place, that
494   //    effectively tiles + fuses the operations.
495   auto addCandidateSlices = [](Operation *fusedOp,
496                                std::deque<tensor::ExtractSliceOp> &candidates) {
497     for (Value operand : fusedOp->getOperands())
498       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
499         candidates.push_back(sliceOp);
500   };
501 
502   std::deque<tensor::ExtractSliceOp> candidates;
503   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
504   OpBuilder::InsertionGuard g(rewriter);
505   while (!candidates.empty()) {
506     // 2a. Traverse the slices in BFS fashion.
507     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
508     candidates.pop_front();
509 
510     // 2b. Get the producer of the source (potentially walking through
511     // `iter_args` of nested `scf.for`)
512     auto [fusableProducer, destinationIterArg] =
513         getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
514                                           tileAndFuseResult.loops);
515     if (!fusableProducer)
516       continue;
517 
518     // 2c. Generate the tiled implementation of the producer of the source
519     rewriter.setInsertionPoint(candidateSliceOp);
520     FailureOr<Value> fusedProducerValue =
521         tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
522                                                      fusableProducer);
523     if (failed(fusedProducerValue))
524       continue;
525     rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
526 
527     // 2d. The operands of the fused producer might themselved be slices of
528     //     values produced by operations that implement the `TilingInterface`.
529     //     Add these operations to the worklist.
530     Operation *fusedProducer = fusedProducerValue->getDefiningOp();
531     tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
532     addCandidateSlices(fusedProducer, candidates);
533 
534     // 2e. If the slice is for a destination operand, for example,
535     //
536     // ```mlir
537     // %0 = linalg.init
538     // %1 = linalg.fill .. outs(%0 : )
539     // %2 = scf.for .. iter_args(%arg0 = %1) {
540     //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
541     //     %4 = tensor.extract_slice %arg1 [..]
542     //     .. = linalg.matmul .. outs(%4 : )
543     //   }
544     // }
545     // ```
546     //
547     // the IR is currently
548     //
549     // ```
550     // %0 = linalg.init
551     // %1 = linalg.fill
552     // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
553     //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
554     //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
555     //     %5 = linalg.fill .. outs(%4 : )
556     //     .. = linalg.matmul .. outs(%5 : )
557     //   }
558     // }
559     // ```
560     //
561     // The untiled `linalg.fill` is still used as the `init_value` since it
562     // was originally a destination operand of the untiled `linalg.matmul`.
563     // When fusing an operand that is a destination operand.
564     //   - Update the iter_arg of the outer most loop to use the destination
565     //     of the untiled producer.
566     //   - Update the destination of the slice of the tiled producer generated
567     //     to use the same basic block argument as the slice that was used to
568     //     generate inplace the tiled implementation of the producer.
569     // With this the IR will be.
570     //
571     // ```
572     // %0 = linalg.init
573     // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
574     //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
575     //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
576     //     %4 = linalg.fill .. outs(%3 : )
577     //     .. = linalg.matmul .. outs(%4 : )
578     //   }
579     // }
580     // ```
581     // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
582     // Update to use that when it does become available.
583     scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
584     Optional<unsigned> iterArgNumber;
585     if (destinationIterArg) {
586       iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
587           *destinationIterArg.value());
588     }
589     if (iterArgNumber) {
590       int64_t resultNumber = fusableProducer.getResultNumber();
591       if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
592               fusableProducer.getOwner())) {
593         outerMostLoop.setIterArg(
594             iterArgNumber.value(),
595             dstOp.getTiedOpOperand(fusableProducer)->get());
596       }
597       if (auto dstOp = fusedProducerValue.value()
598                            .getDefiningOp<DestinationStyleOpInterface>()) {
599         scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
600         updateDestinationOperandsForTiledOp(
601             rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
602             innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
603       }
604     }
605   }
606   return tileAndFuseResult;
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // lowerToLoopsUsingSCFForOp implementation.
611 //===----------------------------------------------------------------------===//
612 
613 FailureOr<SmallVector<scf::ForOp>>
614 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
615                                      TilingInterface op) {
616   // TODO: Handle cases where the op has results if needed.
617   if (op->getNumResults() > 0) {
618     return rewriter.notifyMatchFailure(
619         op, "unable to lower to loops operations with return values");
620   }
621 
622   SmallVector<Range> domain = op.getIterationDomain(rewriter);
623   SmallVector<Value> ivs;
624   SmallVector<scf::ForOp> loops;
625   Location loc = op.getLoc();
626   for (auto loopRange : domain) {
627     Value offsetVal =
628         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
629     Value sizeVal =
630         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
631     Value strideVal =
632         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
633     auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
634                                             strideVal, ValueRange{});
635     loops.push_back(loop);
636     ivs.push_back(loop.getInductionVar());
637     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
638   }
639   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
640     return failure();
641   }
642   return loops;
643 }
644