xref: /llvm-project/mlir/lib/Dialect/SCF/Utils/Utils.cpp (revision d056c756aea3fc709cf7d6bc8acabe9a8c7218db)
1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Utils/Utils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 #include "mlir/Transforms/RegionUtils.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <cstdint>
33 
34 using namespace mlir;
35 
36 #define DEBUG_TYPE "scf-utils"
37 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
38 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
39 
/// Replaces the perfectly nested loop nest `loopNest` (outermost loop first)
/// with an equivalent nest in which every loop additionally carries
/// `newIterOperands` as iter_args. `newYieldValuesFn` builds the values
/// yielded by the innermost loop; each enclosing loop forwards the extra
/// results of its inner loop. Returns the new loops, outermost first.
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
    bool replaceIterOperandsUsesInLoop) {
  if (loopNest.empty())
    return {};
  // This method is recursive (to make it more readable). Adding an
  // assertion here to limit the recursion. (See
  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
  assert(loopNest.size() <= 10 &&
         "exceeded recursion limit when yielding value from loop nest");

  // To yield a value from a perfectly nested loop nest, the following
  // pattern needs to be created, i.e. starting with
  //
  // ```mlir
  //  scf.for .. {
  //    scf.for .. {
  //      scf.for .. {
  //        %value = ...
  //      }
  //    }
  //  }
  // ```
  //
  // needs to be modified to
  //
  // ```mlir
  // %0 = scf.for .. iter_args(%arg0 = %init) {
  //   %1 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %2 = scf.for .. iter_args(%arg2 = %arg1) {
  //       %value = ...
  //       scf.yield %value
  //     }
  //     scf.yield %2
  //   }
  //   scf.yield %1
  // }
  // ```
  //
  // The inner most loop is handled using the `replaceWithAdditionalYields`
  // that works on a single loop.
  if (loopNest.size() == 1) {
    auto innerMostLoop =
        cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
            rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
            newYieldValuesFn));
    return {innerMostLoop};
  }
  // The outer loops are modified by calling this method recursively
  // - The return value of the inner loop is the value yielded by this loop.
  // - The region iter args of this loop are the init_args for the inner loop.
  SmallVector<scf::ForOp> newLoopNest;
  NewYieldValuesFn fn =
      [&](OpBuilder &innerBuilder, Location loc,
          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
    // Recurse on the remaining (inner) loops, threading this loop's new
    // region iter_args in as the inner loops' init operands. As a side
    // effect this populates `newLoopNest` with the rebuilt inner loops.
    newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
                                               innerNewBBArgs, newYieldValuesFn,
                                               replaceIterOperandsUsesInLoop);
    // Yield the extra results of the (new) immediately-nested loop.
    return llvm::to_vector(llvm::map_range(
        newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
        [](OpResult r) -> Value { return r; }));
  };
  scf::ForOp outerMostLoop =
      cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
          rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
  // Prepend the rebuilt outermost loop so the result is ordered outermost
  // first, matching the input ordering.
  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
  return newLoopNest;
}
109 
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types are the types of the yielded operands of
/// the single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
/// provided, it will be set to point to the operation that calls the outlined
/// function.
117 // TODO: support more than single-block regions.
118 // TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  // Only single-block regions are supported (see TODO above).
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());

  // Values defined above the region become extra trailing arguments of the
  // outlined function.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(regions's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  // The outlined function returns exactly what the region's terminator
  // yields.
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
                                    originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    // The call forwards the reconstructed block arguments followed by the
    // captured values, matching the outlined function's signature.
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    IRMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      // Re-materialize constant-index captures inside the function instead of
      // passing them as arguments, so downstream canonicalizations can fold
      // them.
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        IRMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}
222 
223 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
224                                 func::FuncOp *thenFn, StringRef thenFnName,
225                                 func::FuncOp *elseFn, StringRef elseFnName) {
226   IRRewriter rewriter(b);
227   Location loc = ifOp.getLoc();
228   FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
229   if (thenFn && !ifOp.getThenRegion().empty()) {
230     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
231         rewriter, loc, ifOp.getThenRegion(), thenFnName);
232     if (failed(outlinedFuncOpOrFailure))
233       return failure();
234     *thenFn = *outlinedFuncOpOrFailure;
235   }
236   if (elseFn && !ifOp.getElseRegion().empty()) {
237     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
238         rewriter, loc, ifOp.getElseRegion(), elseFnName);
239     if (failed(outlinedFuncOpOrFailure))
240       return failure();
241     *elseFn = *outlinedFuncOpOrFailure;
242   }
243   return success();
244 }
245 
246 bool mlir::getInnermostParallelLoops(Operation *rootOp,
247                                      SmallVectorImpl<scf::ParallelOp> &result) {
248   assert(rootOp != nullptr && "Root operation must not be a nullptr.");
249   bool rootEnclosesPloops = false;
250   for (Region &region : rootOp->getRegions()) {
251     for (Block &block : region.getBlocks()) {
252       for (Operation &op : block) {
253         bool enclosesPloops = getInnermostParallelLoops(&op, result);
254         rootEnclosesPloops |= enclosesPloops;
255         if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
256           rootEnclosesPloops = true;
257 
258           // Collect parallel loop if it is an innermost one.
259           if (!enclosesPloops)
260             result.push_back(ploop);
261         }
262       }
263     }
264   }
265   return rootEnclosesPloops;
266 }
267 
268 // Build the IR that performs ceil division of a positive value by a constant:
269 //    ceildiv(a, B) = divis(a + (B-1), B)
270 // where divis is rounding-to-zero division.
271 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
272                              int64_t divisor) {
273   assert(divisor > 0 && "expected positive divisor");
274   assert(dividend.getType().isIntOrIndex() &&
275          "expected integer or index-typed value");
276 
277   Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
278       loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
279   Value divisorCst = builder.create<arith::ConstantOp>(
280       loc, builder.getIntegerAttr(dividend.getType(), divisor));
281   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
282   return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
283 }
284 
285 // Build the IR that performs ceil division of a positive value by another
286 // positive value:
287 //    ceildiv(a, b) = divis(a + (b - 1), b)
288 // where divis is rounding-to-zero division.
289 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
290                              Value divisor) {
291   assert(dividend.getType().isIntOrIndex() &&
292          "expected integer or index-typed value");
293   Value cstOne = builder.create<arith::ConstantOp>(
294       loc, builder.getOneAttr(dividend.getType()));
295   Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
296   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
297   return builder.create<arith::DivUIOp>(loc, sum, divisor);
298 }
299 
300 /// Returns the trip count of `forOp` if its' low bound, high bound and step are
301 /// constants, or optional otherwise. Trip count is computed as
302 /// ceilDiv(highBound - lowBound, step).
303 static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
304   std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
305   std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
306   std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
307   if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
308     return {};
309 
310   // Constant loop bounds computation.
311   int64_t lbCst = lbCstOp.value();
312   int64_t ubCst = ubCstOp.value();
313   int64_t stepCst = stepCstOp.value();
314   assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
315          "expected positive loop bounds and step");
316   return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
317 }
318 
319 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
320 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
321 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
322 /// unrolled iteration using annotateFn.
static void generateUnrolledLoop(
    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
    ValueRange iterArgs, ValueRange yieldedValues) {
  // Builder to insert unrolled bodies just before the terminator of the body of
  // 'forOp'.
  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);

  // Fall back to a no-op annotation callback when none was provided.
  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
  if (!annotateFn)
    annotateFn = defaultAnnotateFn;

  // Keep a pointer to the last non-terminator operation in the original block
  // so that we know what to clone (since we are doing this in-place).
  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);

  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
  SmallVector<Value, 4> lastYielded(yieldedValues);

  for (unsigned i = 1; i < unrollFactor; i++) {
    IRMapping operandMap;

    // Prepare operand map: the iter_args of copy `i` are the values yielded
    // by copy `i - 1`.
    operandMap.map(iterArgs, lastYielded);

    // If the induction variable is used, create a remapping to the value for
    // this unrolled instance.
    if (!forOpIV.use_empty()) {
      Value ivUnroll = ivRemapFn(i, forOpIV, builder);
      operandMap.map(forOpIV, ivUnroll);
    }

    // Clone the original body of 'forOp'.
    for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
      Operation *clonedOp = builder.clone(*it, operandMap);
      annotateFn(i, clonedOp, builder);
    }

    // Update yielded values so the next copy (or the final yield) chains off
    // this copy's results.
    for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
      lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
  }

  // Make sure we annotate the Ops in the original body. We do this last so that
  // any annotations are not copied into the cloned Ops above.
  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
    annotateFn(0, &*it, builder);

  // Update operands of the yield statement.
  loopBodyBlock->getTerminator()->setOperands(lastYielded);
}
375 
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
/// epilogue loop, if the loop is unrolled.
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty (contains only the terminator).
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return UnrolledLoopInfo{forOp, std::nullopt};

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  IRRewriter rewriter(forOp.getContext());
  auto loc = forOp.getLoc();
  Value step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
  if (constTripCount) {
    // Constant loop bounds computation.
    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
    if (unrollFactor == 1) {
      // Nothing to unroll; at most promote a known single-iteration loop.
      if (*constTripCount == 1 &&
          failed(forOp.promoteIfSingleIteration(rewriter)))
        return failure();
      return UnrolledLoopInfo{forOp, std::nullopt};
    }

    // The main loop runs the largest multiple of `unrollFactor` iterations;
    // the remainder (if any) goes to the epilogue loop.
    int64_t tripCountEvenMultiple =
        *constTripCount - (*constTripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
          loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
                                            upperBoundUnrolledCst));
    else
      upperBoundUnrolled = forOp.getUpperBound();

    // Create constant for 'stepUnrolled'.
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantOp>(
                             loc, boundsBuilder.getIntegerAttr(
                                      step.getType(), stepUnrolledCst));
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
        loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        loc, lowerBound,
        boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
  }

  UnrolledLoopInfo resultLoops;

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPointAfter(forOp);
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results: downstream consumers now read the
    // epilogue's results, which see the full iteration space.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(results, epilogueResults)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
    }
    // The epilogue's init args are the main loop's results.
    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                               epilogueForOp.getInitArgs().size(), results);
    // Record the epilogue only if it survives single-iteration promotion.
    if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
      resultLoops.epilogueLoopOp = epilogueForOp;
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            loc, step,
            b.create<arith::ConstantOp>(loc,
                                        b.getIntegerAttr(iv.getType(), i)));
        return b.create<arith::AddIOp>(loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  if (forOp.promoteIfSingleIteration(rewriter).failed())
    resultLoops.mainLoopOp = forOp;
  return resultLoops;
}
500 
501 /// Check if bounds of all inner loops are defined outside of `forOp`
502 /// and return false if not.
503 static bool areInnerBoundsInvariant(scf::ForOp forOp) {
504   auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
505     if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
506         !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
507         !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
508       return WalkResult::interrupt();
509 
510     return WalkResult::advance();
511   });
512   return !walkResult.wasInterrupted();
513 }
514 
515 /// Unrolls and jams this loop by the specified factor.
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
                                          uint64_t unrollJamFactor) {
  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  if (unrollJamFactor == 1)
    return success();

  // If any control operand of any inner loop of `forOp` is defined within
  // `forOp`, no unroll jam.
  if (!areInnerBoundsInvariant(forOp)) {
    LDBG("failed to unroll and jam: inner bounds are not invariant");
    return failure();
  }

  // Currently, `scf.for` ops that produce results are not supported.
  if (forOp->getNumResults() > 0) {
    LDBG("failed to unroll and jam: unsupported loop with results");
    return failure();
  }

  // Currently, only constant trip count that divided by the unroll factor is
  // supported.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount.has_value()) {
    // If the trip count is dynamic, do not unroll & jam.
    LDBG("failed to unroll and jam: trip count could not be determined");
    return failure();
  }
  if (unrollJamFactor > *tripCount) {
    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
         "count");
    unrollJamFactor = *tripCount;
  } else if (*tripCount % unrollJamFactor != 0) {
    LDBG("failed to unroll and jam: unsupported trip count that is not a "
         "multiple of unroll jam factor");
    return failure();
  }

  // Nothing in the loop body other than the terminator.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Gather all sub-blocks to jam upon the loop being unrolled.
  JamBlockGatherer<scf::ForOp> jbg;
  jbg.walk(forOp);
  auto &subBlocks = jbg.subBlocks;

  // Collect inner loops.
  SmallVector<scf::ForOp> innerLoops;
  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });

  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
  // iteration. There are (`unrollJamFactor` - 1) iterations.
  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);

  // For any loop with iter_args, replace it with a new loop that has
  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
  // operands.
  SmallVector<scf::ForOp> newInnerLoops;
  IRRewriter rewriter(forOp.getContext());
  for (scf::ForOp oldForOp : innerLoops) {
    SmallVector<Value> dupIterOperands, dupYieldOperands;
    ValueRange oldIterOperands = oldForOp.getInits();
    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
    ValueRange oldYieldOperands =
        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
    // Get additional iterOperands, iterArgs, and yield operands. We will
    // fix iterOperands and yield operands after cloning of sub-blocks.
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
    }
    // Create a new loop with additional iterOperands, iter_args and yield
    // operands. This new loop will take the loop body of the original loop.
    bool forOpReplaced = oldForOp == forOp;
    scf::ForOp newForOp =
        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
              return dupYieldOperands;
            }));
    newInnerLoops.push_back(newForOp);
    // `forOp` has been replaced with a new loop.
    if (forOpReplaced)
      forOp = newForOp;
    // Update `operandMaps` for `newForOp` iterArgs and results.
    ValueRange newIterArgs = newForOp.getRegionIterArgs();
    unsigned oldNumIterArgs = oldIterArgs.size();
    ValueRange newResults = newForOp.getResults();
    unsigned oldNumResults = newResults.size() / unrollJamFactor;
    assert(oldNumIterArgs == oldNumResults &&
           "oldNumIterArgs must be the same as oldNumResults");
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
        // to those in the `i`th new set.
        operandMaps[i - 1].map(newIterArgs[j],
                               newIterArgs[i * oldNumIterArgs + j]);
        operandMaps[i - 1].map(newResults[j],
                               newResults[i * oldNumResults + j]);
      }
    }
  }

  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
  rewriter.setInsertionPoint(forOp);
  int64_t step = forOp.getConstantStep()->getSExtValue();
  auto newStep = rewriter.createOrFold<arith::MulIOp>(
      forOp.getLoc(), forOp.getStep(),
      rewriter.createOrFold<arith::ConstantOp>(
          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
  forOp.setStep(newStep);
  auto forOpIV = forOp.getInductionVar();

  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
    for (auto &subBlock : subBlocks) {
      // Builder to insert unroll-jammed bodies. Insert right at the end of
      // sub-block.
      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forOpIV.use_empty()) {
        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
        auto ivTag = builder.createOrFold<arith::ConstantOp>(
            forOp.getLoc(), builder.getIndexAttr(step * i));
        auto ivUnroll =
            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
        operandMaps[i - 1].map(forOpIV, ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
        builder.clone(*it, operandMaps[i - 1]);
    }
    // Fix iterOperands and yield op operands of newly created loops.
    for (auto newForOp : newInnerLoops) {
      unsigned oldNumIterOperands =
          newForOp.getNumRegionIterArgs() / unrollJamFactor;
      unsigned numControlOperands = newForOp.getNumControlOperands();
      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
      assert(oldNumIterOperands == oldNumYieldOperands &&
             "oldNumIterOperands must be the same as oldNumYieldOperands");
      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
        // The `i`th duplication of an old iterOperand or yield op operand
        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
        // if such mapped value exists.
        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
                            operandMaps[i - 1].lookupOrDefault(
                                newForOp.getOperand(numControlOperands + j)));
        yieldOp.setOperand(
            i * oldNumYieldOperands + j,
            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
      }
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  (void)forOp.promoteIfSingleIteration(rewriter);
  return success();
}
679 
680 Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
681                                            OpFoldResult lb, OpFoldResult ub,
682                                            OpFoldResult step) {
683   Range normalizedLoopBounds;
684   normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
685   normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
686   AffineExpr s0, s1, s2;
687   bindSymbols(rewriter.getContext(), s0, s1, s2);
688   AffineExpr e = (s1 - s0).ceilDiv(s2);
689   normalizedLoopBounds.size =
690       affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
691   return normalizedLoopBounds;
692 }
693 
694 Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
695                                      OpFoldResult lb, OpFoldResult ub,
696                                      OpFoldResult step) {
697   if (getType(lb).isIndex()) {
698     return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
699   }
700   // For non-index types, generate `arith` instructions
701   // Check if the loop is already known to have a constant zero lower bound or
702   // a constant one step.
703   bool isZeroBased = false;
704   if (auto lbCst = getConstantIntValue(lb))
705     isZeroBased = lbCst.value() == 0;
706 
707   bool isStepOne = false;
708   if (auto stepCst = getConstantIntValue(step))
709     isStepOne = stepCst.value() == 1;
710 
711   Type rangeType = getType(lb);
712   assert(rangeType == getType(ub) && rangeType == getType(step) &&
713          "expected matching types");
714 
715   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
716   // assuming the step is strictly positive.  Update the bounds and the step
717   // of the loop to go from 0 to the number of iterations, if necessary.
718   if (isZeroBased && isStepOne)
719     return {lb, ub, step};
720 
721   OpFoldResult diff = ub;
722   if (!isZeroBased) {
723     diff = rewriter.createOrFold<arith::SubIOp>(
724         loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
725         getValueOrCreateConstantIntOp(rewriter, loc, lb));
726   }
727   OpFoldResult newUpperBound = diff;
728   if (!isStepOne) {
729     newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
730         loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
731         getValueOrCreateConstantIntOp(rewriter, loc, step));
732   }
733 
734   OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
735   OpFoldResult newStep = rewriter.getOneAttr(rangeType);
736 
737   return {newLowerBound, newUpperBound, newStep};
738 }
739 
/// Denormalize (reverse the normalization of) the `index`-typed induction
/// variable of a loop: compute `normalizedIv * origStep + origLb` with a
/// composed affine apply and substitute it for all other uses of
/// `normalizedIv`.
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
                                                     Location loc,
                                                     Value normalizedIv,
                                                     OpFoldResult origLb,
                                                     OpFoldResult origStep) {
  // e = d0 * s1 + s0 with d0 = normalizedIv, s0 = origLb, s1 = origStep.
  AffineExpr d0, s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  bindDims(rewriter.getContext(), d0);
  AffineExpr e = d0 * s1 + s0;
  OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
      rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
  Value denormalizedIvVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
  SmallPtrSet<Operation *, 1> preservedUses;
  // If an `affine.apply` operation is generated for the denormalization, its
  // own use of `normalizedIv` must not be replaced — it is how the
  // denormalized value is computed. No such op is generated when
  // `origLb == 0` and `origStep == 1`, since the apply folds away.
  if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
    if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
      preservedUses.insert(preservedUse);
    }
  }
  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
}
764 
765 void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
766                                         Value normalizedIv, OpFoldResult origLb,
767                                         OpFoldResult origStep) {
768   if (getType(origLb).isIndex()) {
769     return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
770                                                     origLb, origStep);
771   }
772   Value denormalizedIv;
773   SmallPtrSet<Operation *, 2> preserve;
774   bool isStepOne = isConstantIntValue(origStep, 1);
775   bool isZeroBased = isConstantIntValue(origLb, 0);
776 
777   Value scaled = normalizedIv;
778   if (!isStepOne) {
779     Value origStepValue =
780         getValueOrCreateConstantIntOp(rewriter, loc, origStep);
781     scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
782     preserve.insert(scaled.getDefiningOp());
783   }
784   denormalizedIv = scaled;
785   if (!isZeroBased) {
786     Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
787     denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
788     preserve.insert(denormalizedIv.getDefiningOp());
789   }
790 
791   rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
792 }
793 
794 static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
795                                         ArrayRef<OpFoldResult> values) {
796   assert(!values.empty() && "unexecpted empty array");
797   AffineExpr s0, s1;
798   bindSymbols(rewriter.getContext(), s0, s1);
799   AffineExpr mul = s0 * s1;
800   OpFoldResult products = rewriter.getIndexAttr(1);
801   for (auto v : values) {
802     products = affine::makeComposedFoldedAffineApply(
803         rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
804   }
805   return products;
806 }
807 
808 /// Helper function to multiply a sequence of values.
809 static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
810                                        ArrayRef<Value> values) {
811   assert(!values.empty() && "unexpected empty list");
812   if (getType(values.front()).isIndex()) {
813     SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
814     OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
815     return getValueOrCreateConstantIndexOp(rewriter, loc, product);
816   }
817   std::optional<Value> productOf;
818   for (auto v : values) {
819     auto vOne = getConstantIntValue(v);
820     if (vOne && vOne.value() == 1)
821       continue;
822     if (productOf)
823       productOf =
824           rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
825     else
826       productOf = v;
827   }
828   if (!productOf) {
829     productOf = rewriter
830                     .create<arith::ConstantOp>(
831                         loc, rewriter.getOneAttr(getType(values.front())))
832                     .getResult();
833   }
834   return productOf.value();
835 }
836 
/// For each original loop, the value of the
/// induction variable can be obtained by dividing the induction variable of
/// the linearized loop by the total number of iterations of the loops nested
/// in it modulo the number of iterations in this loop (remove the values
/// related to the outer loops):
///   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
/// Compute these iteratively from the innermost loop by creating a "running
/// quotient" of division by the range.
/// Returns the delinearized induction variables (outermost dimension first)
/// and the set of newly created ops whose uses of `linearizedIv` must be
/// preserved when the caller replaces its remaining uses.
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
                             Value linearizedIv, ArrayRef<Value> ubs) {

  // For `index`-typed IVs, delegate to the dedicated delinearize op.
  if (linearizedIv.getType().isIndex()) {
    Operation *delinearizedOp =
        rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
                                                          ubs);
    auto resultVals = llvm::map_to_vector(
        delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
    return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
  }

  SmallVector<Value> delinearizedIvs(ubs.size());
  SmallPtrSet<Operation *, 2> preservedUsers;

  // Record which upper bounds are statically 1: those dimensions contribute
  // nothing to the linearized IV and their delinearized IV is constant 0.
  llvm::BitVector isUbOne(ubs.size());
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    auto ubCst = getConstantIntValue(ub);
    if (ubCst && ubCst.value() == 1)
      isUbOne.set(index);
  }

  // Prune the lead ubs that are all ones.
  unsigned numLeadingOneUbs = 0;
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    if (!isUbOne.test(index)) {
      break;
    }
    delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(ub.getType()));
    numLeadingOneUbs++;
  }

  // Walk dimensions from innermost to outermost, keeping `previous` as the
  // running quotient of `linearizedIv` by the ranges already peeled off.
  Value previous = linearizedIv;
  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
    // `idx` is the dimension handled in this iteration (innermost first).
    unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
    // Divide off the range of the previously-handled (inner) dimension,
    // unless it was 1, in which case the division would be a no-op.
    if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
      previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
      preservedUsers.insert(previous.getDefiningOp());
    }
    Value iv = previous;
    // All but the outermost remaining dimension take the quotient modulo
    // their own range; unit ranges directly yield a constant 0.
    if (i != e - 1) {
      if (!isUbOne.test(idx)) {
        iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
        preservedUsers.insert(iv.getDefiningOp());
      } else {
        iv = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getZeroAttr(ubs[idx].getType()));
      }
    }
    delinearizedIvs[idx] = iv;
  }
  return {delinearizedIvs, preservedUsers};
}
900 
/// Coalesce the perfect loop nest `loops` (outermost loop first) into its
/// single outermost loop, which afterwards iterates over the product of the
/// normalized trip counts. Fails if fewer than two loops are provided.
LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
                                  MutableArrayRef<scf::ForOp> loops) {
  if (loops.size() < 2)
    return failure();

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
  // allows the following code to assume upperBound is the number of iterations.
  for (auto loop : loops) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(outermost);
    Value lb = loop.getLowerBound();
    Value ub = loop.getUpperBound();
    Value step = loop.getStep();
    auto newLoopRange =
        emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);

    rewriter.modifyOpInPlace(loop, [&]() {
      loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.offset));
      loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.size));
      loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                 newLoopRange.stride));
    });
    // Compensate inside the body for the changed iteration space.
    rewriter.setInsertionPointToStart(innermost.getBody());
    denormalizeInductionVariable(rewriter, loop.getLoc(),
                                 loop.getInductionVar(), lb, step);
  }

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outermost);
  Location loc = outermost.getLoc();
  SmallVector<Value> upperBounds = llvm::map_to_vector(
      loops, [](auto loop) { return loop.getUpperBound(); });
  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
  outermost.setUpperBound(upperBound);

  // 3. Recover each original induction variable from the coalesced one and
  // substitute it into the (still nested) bodies.
  rewriter.setInsertionPointToStart(innermost.getBody());
  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
      rewriter, loc, outermost.getInductionVar(), upperBounds);
  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
                                preservedUsers);

  // 4. Inline the inner loop bodies into the outer ones, innermost first,
  // rewiring iter_args and yielded values along the way.
  for (int i = loops.size() - 1; i > 0; --i) {
    auto outerLoop = loops[i - 1];
    auto innerLoop = loops[i];

    Operation *innerTerminator = innerLoop.getBody()->getTerminator();
    auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
    assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
    for (Value &yieldedVal : yieldedVals) {
      // The yielded value may be an iteration argument of the inner loop
      // which is about to be inlined.
      auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
      if (iter != innerLoop.getRegionIterArgs().end()) {
        unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
        // `outerLoop` iter args identical to the `innerLoop` init args.
        assert(iterArgIndex < innerLoop.getInitArgs().size());
        yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
      }
    }
    rewriter.eraseOp(innerTerminator);

    // The inlined block receives the recovered IV plus the outer iter args.
    SmallVector<Value> innerBlockArgs;
    innerBlockArgs.push_back(delinearizeIvs[i]);
    llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
    rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
                               Block::iterator(innerLoop), innerBlockArgs);
    rewriter.replaceOp(innerLoop, yieldedVals);
  }
  return success();
}
978 
979 LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
980   if (loops.empty()) {
981     return failure();
982   }
983   IRRewriter rewriter(loops.front().getContext());
984   return coalesceLoops(rewriter, loops);
985 }
986 
/// Walk the perfectly nested band rooted at `op` and coalesce every maximal
/// sub-band whose bounds are defined above the band's first loop and whose
/// iter_args form an unbroken chain. Returns success if at least one band
/// was coalesced.
LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
  LogicalResult result(failure());
  SmallVector<scf::ForOp> loops;
  getPerfectlyNestedLoops(loops, op);

  // Look for a band of loops that can be coalesced, i.e. perfectly nested
  // loops with bounds defined above some loop.

  // 1. For each loop, find above which parent loop its bounds operands are
  // defined.
  SmallVector<unsigned> operandsDefinedAbove(loops.size());
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    // Default: the loop's bounds are only known to be defined above itself.
    operandsDefinedAbove[i] = i;
    for (unsigned j = 0; j < i; ++j) {
      SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
                                           loops[i].getUpperBound(),
                                           loops[i].getStep()};
      if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
        operandsDefinedAbove[i] = j;
        break;
      }
    }
  }

  // 2. For each inner loop check that the iter_args for the immediately outer
  // loop are the init for the immediately inner loop and that the yields of the
  // return of the inner loop is the yield for the immediately outer loop. Keep
  // track of where the chain starts from for each loop.
  SmallVector<unsigned> iterArgChainStart(loops.size());
  iterArgChainStart[0] = 0;
  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
    // By default set the start of the chain to itself.
    iterArgChainStart[i] = i;
    auto outerloop = loops[i - 1];
    auto innerLoop = loops[i];
    if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
      continue;
    }
    if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
      continue;
    }
    auto outerloopTerminator = outerloop.getBody()->getTerminator();
    if (!llvm::equal(outerloopTerminator->getOperands(),
                     innerLoop.getResults())) {
      continue;
    }
    // The chain is unbroken: inherit the outer loop's chain start.
    iterArgChainStart[i] = iterArgChainStart[i - 1];
  }

  // 3. Identify bands of loops such that the operands of all of them are
  // defined above the first loop in the band.  Traverse the nest bottom-up
  // so that modifications don't invalidate the inner loops.
  for (unsigned end = loops.size(); end > 0; --end) {
    unsigned start = 0;
    for (; start < end - 1; ++start) {
      auto maxPos =
          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
                            std::next(operandsDefinedAbove.begin(), end));
      // Every loop in [start, end) must have bounds defined at or above
      // `start` for the band to be coalescable.
      if (maxPos > start)
        continue;
      // The iter_arg chain of the innermost loop must also reach `start`.
      if (iterArgChainStart[end - 1] > start)
        continue;
      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
      if (succeeded(coalesceLoops(band)))
        result = success();
      break;
    }
    // If a band was found and transformed, keep looking at the loops above
    // the outermost transformed loop.
    if (start != end - 1)
      end = start + 1;
  }
  return result;
}
1061 
/// Collapse the dimensions of the scf.parallel `loops` according to
/// `combinedDimensions`: each inner vector lists the original dimensions to
/// merge into one dimension of a newly created scf.parallel. The original
/// induction variables are recovered inside the new loop via div/mod of the
/// combined induction variable, and the old loop is erased.
void mlir::collapseParallelLoops(
    RewriterBase &rewriter, scf::ParallelOp loops,
    ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    llvm::sort(dims);

  // Normalize ParallelOp's iteration pattern.
  SmallVector<Value, 3> normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder::InsertionGuard g2(rewriter);
    rewriter.setInsertionPoint(loops);
    Value lb = loops.getLowerBound()[i];
    Value ub = loops.getUpperBound()[i];
    Value step = loops.getStep()[i];
    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
        rewriter, loops.getLoc(), newLoopRange.size));

    // Compensate for the normalization inside the body.
    rewriter.setInsertionPointToStart(loops.getBody());
    denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
                                 step);
  }

  // Combine iteration spaces.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (auto &sortedDimension : sortedDimensions) {
    // The combined range is the product of the normalized ranges.
    Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    for (auto idx : sortedDimension) {
      newUpperBound = rewriter.create<arith::MulIOp>(
          loc, newUpperBound, normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to get the relevant range of values in the
  // new induction value that represent each range of the original induction
  // value. The remainders then determine based on that range, which iteration
  // of the original induction value this represents. This is a normalized value
  // that is un-normalized already by the previous logic.
  auto newPloop = rewriter.create<scf::ParallelOp>(
      loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration
            Value iv = insideBuilder.create<arith::RemSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}
1148 
1149 // Hoist the ops within `outer` that appear before `inner`.
1150 // Such ops include the ops that have been introduced by parametric tiling.
1151 // Ops that come from triangular loops (i.e. that belong to the program slice
1152 // rooted at `outer`) and ops that have side effects cannot be hoisted.
1153 // Return failure when any op fails to hoist.
1154 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1155   SetVector<Operation *> forwardSlice;
1156   ForwardSliceOptions options;
1157   options.filter = [&inner](Operation *op) {
1158     return op != inner.getOperation();
1159   };
1160   getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1161   LogicalResult status = success();
1162   SmallVector<Operation *, 8> toHoist;
1163   for (auto &op : outer.getBody()->without_terminator()) {
1164     // Stop when encountering the inner loop.
1165     if (&op == inner.getOperation())
1166       break;
1167     // Skip over non-hoistable ops.
1168     if (forwardSlice.count(&op) > 0) {
1169       status = failure();
1170       continue;
1171     }
1172     // Skip intermediate scf::ForOp, these are not considered a failure.
1173     if (isa<scf::ForOp>(op))
1174       continue;
1175     // Skip other ops with regions.
1176     if (op.getNumRegions() > 0) {
1177       status = failure();
1178       continue;
1179     }
1180     // Skip if op has side effects.
1181     // TODO: loads to immutable memory regions are ok.
1182     if (!isMemoryEffectFree(&op)) {
1183       status = failure();
1184       continue;
1185     }
1186     toHoist.push_back(&op);
1187   }
1188   auto *outerForOp = outer.getOperation();
1189   for (auto *op : toHoist)
1190     op->moveBefore(outerForOp);
1191   return status;
1192 }
1193 
1194 // Traverse the interTile and intraTile loops and try to hoist ops such that
1195 // bands of perfectly nested loops are isolated.
1196 // Return failure if either perfect interTile or perfect intraTile bands cannot
1197 // be formed.
1198 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1199   LogicalResult status = success();
1200   const Loops &interTile = tileLoops.first;
1201   const Loops &intraTile = tileLoops.second;
1202   auto size = interTile.size();
1203   assert(size == intraTile.size());
1204   if (size <= 1)
1205     return success();
1206   for (unsigned s = 1; s < size; ++s)
1207     status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1208                                : failure();
1209   for (unsigned s = 1; s < size; ++s)
1210     status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1211                                : failure();
1212   return status;
1213 }
1214 
1215 /// Collect perfectly nested loops starting from `rootForOps`.  Loops are
1216 /// perfectly nested if each loop is the first and only non-terminator operation
1217 /// in the parent loop.  Collect at most `maxLoops` loops and append them to
1218 /// `forOps`.
1219 template <typename T>
1220 static void getPerfectlyNestedLoopsImpl(
1221     SmallVectorImpl<T> &forOps, T rootForOp,
1222     unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1223   for (unsigned i = 0; i < maxLoops; ++i) {
1224     forOps.push_back(rootForOp);
1225     Block &body = rootForOp.getRegion().front();
1226     if (body.begin() != std::prev(body.end(), 2))
1227       return;
1228 
1229     rootForOp = dyn_cast<T>(&body.front());
1230     if (!rootForOp)
1231       return;
1232   }
1233 }
1234 
/// Stripmine `forOp` by `factor`: multiply its step by `factor` and, inside
/// each loop in `targets`, create a new inner loop covering one chunk
/// `[iv, min(iv + newStep, ub))` with the original step that takes over the
/// target's body. Returns the newly created inner loops, one per target.
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  // The outer loop now advances `factor` original steps at a time.
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    // Clamp the chunk's upper bound at the outer loop's bound so the final
    // (possibly partial) chunk does not overrun.
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    Value ub =
        b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}
1268 
1269 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1270 // Returns the new for operation, nested immediately under `target`.
1271 template <typename SizeType>
1272 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1273                                 scf::ForOp target) {
1274   // TODO: Use cheap structural assertions that targets are nested under
1275   // forOp and that targets are not nested under each other when DominanceInfo
1276   // exposes the capability. It seems overkill to construct a whole function
1277   // dominance tree at this point.
1278   auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1279   assert(res.size() == 1 && "Expected 1 inner forOp");
1280   return res[0];
1281 }
1282 
1283 SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
1284                                  ArrayRef<Value> sizes,
1285                                  ArrayRef<scf::ForOp> targets) {
1286   SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
1287   SmallVector<scf::ForOp, 8> currentTargets(targets);
1288   for (auto it : llvm::zip(forOps, sizes)) {
1289     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1290     res.push_back(step);
1291     currentTargets = step;
1292   }
1293   return res;
1294 }
1295 
1296 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
1297                  scf::ForOp target) {
1298   SmallVector<scf::ForOp, 8> res;
1299   for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
1300     assert(loops.size() == 1);
1301     res.push_back(loops[0]);
1302   }
1303   return res;
1304 }
1305 
1306 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
1307   // Collect perfectly nested loops.  If more size values provided than nested
1308   // loops available, truncate `sizes`.
1309   SmallVector<scf::ForOp, 4> forOps;
1310   forOps.reserve(sizes.size());
1311   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1312   if (forOps.size() < sizes.size())
1313     sizes = sizes.take_front(forOps.size());
1314 
1315   return ::tile(forOps, sizes, forOps.back());
1316 }
1317 
/// Collect the perfectly nested loop band rooted at `root` into
/// `nestedLoops`, outermost loop first (no limit on depth).
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}
1322 
/// Tile the perfectly nested band rooted at `rootForOp` so that the i-th
/// outer loop executes `sizes[i]` iterations, then attempt to hoist ops so
/// the inter-tile and intra-tile bands are each perfectly nested. Returns the
/// pair (inter-tile loops, intra-tile loops).
TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
                                       ArrayRef<int64_t> sizes) {
  // Collect perfectly nested loops.  If more size values provided than nested
  // loops available, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  // Compute the tile sizes such that i-th outer loop executes size[i]
  // iterations.  Given that the loop current executes
  //   numIterations = ceildiv((upperBound - lowerBound), step)
  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
  SmallVector<Value, 4> tileSizes;
  tileSizes.reserve(sizes.size());
  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
    assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");

    auto forOp = forOps[i];
    OpBuilder builder(forOp);
    auto loc = forOp.getLoc();
    // numIterations = ceildiv(ub - lb, step), assuming a positive step.
    Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
                                               forOp.getLowerBound());
    Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
    Value iterationsPerBlock =
        ceilDivPositive(builder, loc, numIterations, sizes[i]);
    tileSizes.push_back(iterationsPerBlock);
  }

  // Call parametric tiling with the given sizes.
  auto intraTile = tile(forOps, tileSizes, forOps.back());
  TileLoops tileLoops = std::make_pair(forOps, intraTile);

  // TODO: for now we just ignore the result of band isolation.
  // In the future, mapping decisions may be impacted by the ability to
  // isolate perfectly nested bands.
  (void)tryIsolateBands(tileLoops);

  return tileLoops;
}
1364 
/// Fuse two independent (sibling) scf.forall loops: create a new scf.forall
/// after `source` with `source`'s iteration space and the concatenation of
/// `target`'s and `source`'s shared outputs, clone both bodies into it
/// (`target`'s first), merge both in_parallel terminators, and replace the
/// results of the original loops with the corresponding slices of the fused
/// loop's results. Assumes the two loops have matching iteration spaces.
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                      scf::ForallOp source,
                                                      RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused shared_outs.
  SmallVector<Value> fusedOuts;
  llvm::append_range(fusedOuts, target.getOutputs());
  llvm::append_range(fusedOuts, source.getOutputs());

  // Create a new scf.forall op after the source loop.
  rewriter.setInsertionPointAfter(source);
  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
      source.getMixedStep(), fusedOuts, source.getMapping());

  // Map control operands: both loops' induction variables map onto the fused
  // loop's.
  IRMapping mapping;
  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());

  // Map shared outs: target's iter args take the leading fused iter args,
  // source's take the trailing ones.
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Append everything except the terminator into the fused operation.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Fuse the old terminator in_parallel ops into the new one.
  scf::InParallelOp targetTerm = target.getTerminator();
  scf::InParallelOp sourceTerm = source.getTerminator();
  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
  rewriter.setInsertionPointToStart(fusedTerm.getBody());
  for (Operation &op : targetTerm.getYieldingOps())
    rewriter.clone(op, mapping);
  for (Operation &op : sourceTerm.getYieldingOps())
    rewriter.clone(op, mapping);

  // Replace old loops by substituting their uses by results of the fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}
1416 
1417 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
1418                                                 scf::ForOp source,
1419                                                 RewriterBase &rewriter) {
1420   unsigned numTargetOuts = target.getNumResults();
1421   unsigned numSourceOuts = source.getNumResults();
1422 
1423   // Create fused init_args, with target's init_args before source's init_args.
1424   SmallVector<Value> fusedInitArgs;
1425   llvm::append_range(fusedInitArgs, target.getInitArgs());
1426   llvm::append_range(fusedInitArgs, source.getInitArgs());
1427 
1428   // Create a new scf.for op after the source loop (with scf.yield terminator
1429   // (without arguments) only in case its init_args is empty).
1430   rewriter.setInsertionPointAfter(source);
1431   scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
1432       source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1433       source.getStep(), fusedInitArgs);
1434 
1435   // Map original induction variables and operands to those of the fused loop.
1436   IRMapping mapping;
1437   mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1438   mapping.map(target.getRegionIterArgs(),
1439               fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1440   mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1441   mapping.map(source.getRegionIterArgs(),
1442               fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1443 
1444   // Merge target's body into the new (fused) for loop and then source's body.
1445   rewriter.setInsertionPointToStart(fusedLoop.getBody());
1446   for (Operation &op : target.getBody()->without_terminator())
1447     rewriter.clone(op, mapping);
1448   for (Operation &op : source.getBody()->without_terminator())
1449     rewriter.clone(op, mapping);
1450 
1451   // Build fused yield results by appropriately mapping original yield operands.
1452   SmallVector<Value> yieldResults;
1453   for (Value operand : target.getBody()->getTerminator()->getOperands())
1454     yieldResults.push_back(mapping.lookupOrDefault(operand));
1455   for (Value operand : source.getBody()->getTerminator()->getOperands())
1456     yieldResults.push_back(mapping.lookupOrDefault(operand));
1457   if (!yieldResults.empty())
1458     rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
1459 
1460   // Replace old loops by substituting their uses by results of the fused loop.
1461   rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1462   rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1463 
1464   return fusedLoop;
1465 }
1466 
1467 FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1468                                                  scf::ForallOp forallOp) {
1469   SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1470   SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1471   SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1472 
1473   if (llvm::all_of(
1474           lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
1475       llvm::all_of(
1476           steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
1477     return forallOp;
1478   }
1479 
1480   SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
1481   for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1482     Range normalizedLoopParams =
1483         emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
1484     newLbs.push_back(normalizedLoopParams.offset);
1485     newUbs.push_back(normalizedLoopParams.size);
1486     newSteps.push_back(normalizedLoopParams.stride);
1487   }
1488 
1489   auto normalizedForallOp = rewriter.create<scf::ForallOp>(
1490       forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
1491       forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
1492 
1493   rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1494                               normalizedForallOp.getBodyRegion(),
1495                               normalizedForallOp.getBodyRegion().begin());
1496 
1497   rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
1498   return success();
1499 }
1500