xref: /llvm-project/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (revision 35c5e56b6113b468b521c071ac141b4bb94da1d7)
1 //===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/LoopUtils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
18 #include "mlir/Dialect/Affine/Utils.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <optional>
30 
31 #define DEBUG_TYPE "loop-utils"
32 
33 using namespace mlir;
34 using namespace affine;
35 using namespace presburger;
36 using llvm::SmallMapVector;
37 
38 /// Computes the cleanup loop lower bound of the loop being unrolled with
39 /// the specified unroll factor; this bound will also be the upper bound of the main
40 /// part of the unrolled loop. Computes the bound as an AffineMap with its
41 /// operands or a null map when the trip count can't be expressed as an affine
42 /// expression.
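/// For example (an illustrative sketch, not taken from a test), for
///   affine.for %i = 0 to 37 { ... }
/// being unrolled by a factor of 4, the trip count is 37 and the cleanup lower
/// bound computed here is 0 + (37 - 37 mod 4) = 36, so the main unrolled loop
/// covers [0, 36) and the cleanup loop covers [36, 37).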
43 static void
44 getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
45                          AffineMap &cleanupLbMap,
46                          SmallVectorImpl<Value> &cleanupLbOperands) {
47   AffineMap tripCountMap;
48   SmallVector<Value, 4> tripCountOperands;
49   getTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);
50   // Trip count can't be computed.
51   if (!tripCountMap) {
52     cleanupLbMap = AffineMap();
53     return;
54   }
55 
56   OpBuilder b(forOp);
57   auto lbMap = forOp.getLowerBoundMap();
58   auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
59                                     forOp.getLowerBoundOperands());
60 
61   // For each upper bound expr, get the range.
62   // Eg: affine.for %i = lb to min (ub1, ub2),
63   // where the tripCountExprs yield (tr1, tr2), we create affine.apply's:
64   // lb + (tr1 - tr1 % ufactor) * step, lb + (tr2 - tr2 % ufactor) * step; the
65   // results of all these affine.apply's make up the cleanup loop lower bound.
66   SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
67   SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults());
68   int64_t step = forOp.getStepAsInt();
69   for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
70     auto tripCountExpr = tripCountMap.getResult(i);
71     bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
72     auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
73                                   tripCountMap.getNumSymbols(), bumpExprs[i]);
74     bumpValues[i] =
75         b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
76   }
77 
78   SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
79   for (unsigned i = 0, e = bumpExprs.size(); i < e; i++)
80     newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1);
81 
82   cleanupLbOperands.clear();
83   cleanupLbOperands.push_back(lb);
84   cleanupLbOperands.append(bumpValues.begin(), bumpValues.end());
85   cleanupLbMap = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs,
86                                 b.getContext());
87   // Simplify the cleanupLbMap + cleanupLbOperands.
88   fullyComposeAffineMapAndOperands(&cleanupLbMap, &cleanupLbOperands);
89   cleanupLbMap = simplifyAffineMap(cleanupLbMap);
90   canonicalizeMapAndOperands(&cleanupLbMap, &cleanupLbOperands);
91   // Remove any affine.apply's that became dead from the simplification above.
92   for (auto v : bumpValues)
93     if (v.use_empty())
94       v.getDefiningOp()->erase();
95 
96   if (lb.use_empty())
97     lb.erase();
98 }
99 
100 /// Helper to replace uses of loop carried values (iter_args) and loop
101 /// yield values while promoting single iteration affine.for ops.
102 static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
103   // Replace uses of iter arguments with iter operands (initial values).
104   auto iterOperands = forOp.getInits();
105   auto iterArgs = forOp.getRegionIterArgs();
106   for (auto e : llvm::zip(iterOperands, iterArgs))
107     std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
108 
109   // Replace uses of loop results with the values yielded by the loop.
110   auto outerResults = forOp.getResults();
111   auto innerResults = forOp.getBody()->getTerminator()->getOperands();
112   for (auto e : llvm::zip(outerResults, innerResults))
113     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
114 }
115 
116 /// Promotes the loop body of a forOp to its containing block if the forOp
117 /// was known to have a single iteration.
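/// For example (an illustrative sketch), with a constant lower bound,
///   affine.for %i = 0 to 1 {
///     "some.op"(%i) : (index) -> ()
///   }
/// is promoted to a single "some.op"(%c0) in the parent block, where %c0 is an
/// arith.constant 0 : index hoisted to the entry of the enclosing function.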
118 LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
119   std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
120   if (!tripCount || *tripCount != 1)
121     return failure();
122 
123   // TODO: extend this for arbitrary affine bounds.
124   if (forOp.getLowerBoundMap().getNumResults() != 1)
125     return failure();
126 
127   // Replace all uses of the IV with its single iteration value.
128   auto iv = forOp.getInductionVar();
129   auto *parentBlock = forOp->getBlock();
130   if (!iv.use_empty()) {
131     if (forOp.hasConstantLowerBound()) {
132       auto func = forOp->getParentOfType<FunctionOpInterface>();
133       OpBuilder builder(forOp->getContext());
134       if (func)
135         builder.setInsertionPointToStart(&func.getFunctionBody().front());
136       else
137         builder.setInsertionPoint(forOp);
138       auto constOp = builder.create<arith::ConstantIndexOp>(
139           forOp.getLoc(), forOp.getConstantLowerBound());
140       iv.replaceAllUsesWith(constOp);
141     } else {
142       auto lbOperands = forOp.getLowerBoundOperands();
143       auto lbMap = forOp.getLowerBoundMap();
144       OpBuilder builder(forOp);
145       if (lbMap == builder.getDimIdentityMap()) {
146         // No need of generating an affine.apply.
147         iv.replaceAllUsesWith(lbOperands[0]);
148       } else {
149         auto affineApplyOp =
150             builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
151         iv.replaceAllUsesWith(affineApplyOp);
152       }
153     }
154   }
155 
156   replaceIterArgsAndYieldResults(forOp);
157 
158   // Move the loop body operations, except for its terminator, to the loop's
159   // containing block.
160   forOp.getBody()->back().erase();
161   parentBlock->getOperations().splice(Block::iterator(forOp),
162                                       forOp.getBody()->getOperations());
163   forOp.erase();
164   return success();
165 }
166 
167 /// Generates an affine.for op with the specified lower and upper bounds
168 /// while generating the right IV remappings to realize shifts for operations in
169 /// its body. The operations that go into the loop body are specified in
170 /// opGroupQueue starting from the specified offset, and in that order. The
171 /// first element of the pair specifies the shift applied to that group of
172 /// operations; the shift is multiplied by the loop step before being applied.
173 /// Returns nullptr if the generated loop simplifies to a single iteration one.
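/// For instance (an illustrative sketch), for an op group with shift 2 in a
/// source loop of step 4, the body of the generated loop gets
///   %iv_remap = affine.apply affine_map<(d0) -> (d0 - 8)>(%new_iv)
/// and the group's ops are cloned with uses of the source IV remapped to
/// %iv_remap (the names %iv_remap/%new_iv are made up for the example).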
174 static AffineForOp generateShiftedLoop(
175     AffineMap lbMap, AffineMap ubMap,
176     const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> &opGroupQueue,
177     unsigned offset, AffineForOp srcForOp, OpBuilder b) {
178   auto lbOperands = srcForOp.getLowerBoundOperands();
179   auto ubOperands = srcForOp.getUpperBoundOperands();
180 
181   assert(lbMap.getNumInputs() == lbOperands.size());
182   assert(ubMap.getNumInputs() == ubOperands.size());
183 
184   auto loopChunk =
185       b.create<AffineForOp>(srcForOp.getLoc(), lbOperands, lbMap, ubOperands,
186                             ubMap, srcForOp.getStepAsInt());
187   auto loopChunkIV = loopChunk.getInductionVar();
188   auto srcIV = srcForOp.getInductionVar();
189 
190   IRMapping operandMap;
191 
192   auto bodyBuilder = OpBuilder::atBlockTerminator(loopChunk.getBody());
193   for (const auto &it : llvm::drop_begin(opGroupQueue, offset)) {
194     uint64_t shift = it.first;
195     auto ops = it.second;
196     // All 'same shift' operations get added with their operands being
197     // remapped to results of cloned operations, and the IV they use remapped.
198     // Generate the remapping if the shift is not zero: remappedIV = newIV -
199     // shift.
200     if (!srcIV.use_empty() && shift != 0) {
201       auto ivRemap = bodyBuilder.create<AffineApplyOp>(
202           srcForOp.getLoc(),
203           bodyBuilder.getSingleDimShiftAffineMap(
204               -static_cast<int64_t>(srcForOp.getStepAsInt() * shift)),
205           loopChunkIV);
206       operandMap.map(srcIV, ivRemap);
207     } else {
208       operandMap.map(srcIV, loopChunkIV);
209     }
210     for (auto *op : ops)
211       bodyBuilder.clone(*op, operandMap);
212   }
213   if (succeeded(promoteIfSingleIteration(loopChunk)))
214     return AffineForOp();
215   return loopChunk;
216 }
217 
218 // The skewing of operations with respect to one another can be used for
219 // example to allow overlap of asynchronous operations (such as DMA
220 // communication) with computation, or just relative shifting of operations
221 // for better register reuse, locality or parallelism. As such, the shifts are
222 // typically expected to be at most of the order of the number of operations.
223 // This method should not be used as a substitute for loop distribution/fission.
224 // This method uses an algorithm that runs in time linear in the number of
225 // operations in the body of the for loop (using the 'sweep line' paradigm).
226 // This method asserts preservation of SSA dominance; checking for that, as
227 // well as for memory-based dependence preservation, rests with the users of
228 // this method.
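// As an illustrative sketch (not taken from a test), skewing the body of
//   affine.for %i = 0 to 8 {
//     "op.a"(%i) : (index) -> ()
//     "op.b"(%i) : (index) -> ()
//   }
// with shifts = [0, 1] conceptually produces a one-iteration prologue over
// [0, 1) containing only "op.a", a steady-state loop over [1, 8) containing
// "op.a"(%i) followed by "op.b"(%i - 1), and a one-iteration epilogue over
// [8, 9) containing the last "op.b"; the single-iteration pieces are then
// promoted into straight-line code.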
229 LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
230                                                 ArrayRef<uint64_t> shifts,
231                                                 bool unrollPrologueEpilogue) {
232   assert(forOp.getBody()->getOperations().size() == shifts.size() &&
233          "too few/many shifts");
234   if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
235     return success();
236 
237   // If the trip counts aren't constant, we would need versioning and
238   // conditional guards (or context information to prevent such versioning). The
239   // better way to pipeline for such loops is to first tile them and extract
240   // constant trip count "full tiles" before applying this.
241   auto mayBeConstTripCount = getConstantTripCount(forOp);
242   if (!mayBeConstTripCount) {
243     LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled"));
244     return success();
245   }
246   uint64_t tripCount = *mayBeConstTripCount;
247 
248   assert(isOpwiseShiftValid(forOp, shifts) &&
249          "shifts will lead to an invalid transformation\n");
250 
251   int64_t step = forOp.getStepAsInt();
252 
253   unsigned numChildOps = shifts.size();
254 
255   // Do a linear time (counting) sort for the shifts.
256   uint64_t maxShift = *llvm::max_element(shifts);
257   if (maxShift >= numChildOps) {
258     // Large shifts are not the typical use case.
259     forOp.emitWarning("not shifting because shifts are unrealistically large");
260     return success();
261   }
262 
263   // An array of operation groups sorted by shift amount; each group has all
264   // operations with the same shift in the order in which they appear in the
265   // body of the 'affine.for' op.
266   std::vector<std::vector<Operation *>> sortedOpGroups(maxShift + 1);
267   unsigned pos = 0;
268   for (auto &op : forOp.getBody()->without_terminator()) {
269     auto shift = shifts[pos++];
270     sortedOpGroups[shift].push_back(&op);
271   }
272 
273   // Unless the shifts have a specific pattern (which actually would be the
274   // common use case), prologue and epilogue are not meaningfully defined.
275   // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
276   // loop generated as the prologue and the last as epilogue and unroll these
277   // fully.
278   AffineForOp prologue, epilogue;
279 
280   // Do a sweep over the sorted shifts while storing open groups in a
281   // vector, and generating loop portions as necessary during the sweep. A block
282   // of operations is paired with its shift.
283   std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> opGroupQueue;
284 
285   auto origLbMap = forOp.getLowerBoundMap();
286   uint64_t lbShift = 0;
287   OpBuilder b(forOp);
288   for (uint64_t d = 0, e = sortedOpGroups.size(); d < e; ++d) {
289     // If nothing is shifted by d, continue.
290     if (sortedOpGroups[d].empty())
291       continue;
292     if (!opGroupQueue.empty()) {
293       assert(d > 0 &&
294              "Queue expected to be empty when the first block is found");
295       // The interval for which the loop needs to be generated here is:
296       // [lbShift, min(lbShift + tripCount, d)) and the body of the
297       // loop needs to have all operations in opQueue in that order.
298       AffineForOp res;
299       if (lbShift + tripCount * step < d * step) {
300         res = generateShiftedLoop(
301             b.getShiftedAffineMap(origLbMap, lbShift),
302             b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
303             opGroupQueue, /*offset=*/0, forOp, b);
304         // Entire loop for the queued op groups generated, empty it.
305         opGroupQueue.clear();
306         lbShift += tripCount * step;
307       } else {
308         res = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
309                                   b.getShiftedAffineMap(origLbMap, d),
310                                   opGroupQueue, /*offset=*/0, forOp, b);
311         lbShift = d * step;
312       }
313 
314       if (res) {
315         // Simplify/canonicalize the affine.for.
316         RewritePatternSet patterns(res.getContext());
317         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
318         GreedyRewriteConfig config;
319         config.strictMode = GreedyRewriteStrictness::ExistingOps;
320         bool erased;
321         (void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
322                                       config, /*changed=*/nullptr, &erased);
323         if (!erased && !prologue)
324           prologue = res;
325         if (!erased)
326           epilogue = res;
327       }
328     } else {
329       // Start of first interval.
330       lbShift = d * step;
331     }
332     // Augment the list of operations that get into the current open interval.
333     opGroupQueue.emplace_back(d, sortedOpGroups[d]);
334   }
335 
336   // Those operation groups left in the queue now need to be processed (FIFO)
337   // and their loops completed.
338   for (unsigned i = 0, e = opGroupQueue.size(); i < e; ++i) {
339     uint64_t ubShift = (opGroupQueue[i].first + tripCount) * step;
340     epilogue = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
341                                    b.getShiftedAffineMap(origLbMap, ubShift),
342                                    opGroupQueue, /*offset=*/i, forOp, b);
343     lbShift = ubShift;
344     if (!prologue)
345       prologue = epilogue;
346   }
347 
348   // Erase the original for op.
349   forOp.erase();
350 
351   if (unrollPrologueEpilogue && prologue)
352     (void)loopUnrollFull(prologue);
353   if (unrollPrologueEpilogue && epilogue && epilogue != prologue)
354     (void)loopUnrollFull(epilogue);
355 
356   return success();
357 }
358 
359 /// Checks whether a loop nest is hyper-rectangular or not.
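/// For example (illustrative), the triangular nest
///   affine.for %i = 0 to 10 {
///     affine.for %j = 0 to #ub(%i) { ... }   // #ub = affine_map<(d0) -> (d0)>
///   }
/// is not hyper-rectangular, since the bound of %j depends on %i; a nest where
/// no loop's bounds involve another loop's IV is.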
360 static LogicalResult
361 checkIfHyperRectangular(MutableArrayRef<AffineForOp> input) {
362   FlatAffineValueConstraints cst;
363   SmallVector<Operation *, 8> ops(input.begin(), input.end());
364   // 0-d or 1-d is trivially hyper-rectangular.
365   if (input.size() <= 1)
366     return success();
367   if (failed(getIndexSet(ops, &cst))) {
368     LLVM_DEBUG(llvm::dbgs() << "Index set computation failed!\n");
369     return failure();
370   }
371   if (!cst.isHyperRectangular(0, input.size())) {
372     LLVM_DEBUG(llvm::dbgs()
373                << "Non-hyperrectangular nests not supported for tiling!\n");
374     return failure();
375   }
376   return success();
377 }
378 
379 /// Check if the input nest is supported for tiling and whether tiling would be
380 /// legal.
381 template <typename t>
382 static LogicalResult performPreTilingChecks(MutableArrayRef<AffineForOp> input,
383                                             ArrayRef<t> tileSizes) {
384   assert(input.size() == tileSizes.size() && "Too few/many tile sizes");
385 
386   if (llvm::any_of(input,
387                    [](AffineForOp op) { return op.getNumResults() > 0; })) {
388     LLVM_DEBUG(llvm::dbgs()
389                << "Cannot tile nest where a loop has yield values\n");
390     return failure();
391   }
392 
393   // Check if the supplied `for` ops are all successively nested.
394   if (!isPerfectlyNested(input)) {
395     LLVM_DEBUG(llvm::dbgs() << "input loops not perfectly nested");
396     return failure();
397   }
398 
399   //  TODO: handle non hyper-rectangular spaces.
400   if (failed(checkIfHyperRectangular(input)))
401     return failure();
402 
403   return success();
404 }
405 
406 /// Move the loop body of AffineForOp 'src' into the specified location in the
407 /// destination's body, ignoring the terminator.
408 static void moveLoopBodyImpl(AffineForOp src, AffineForOp dest,
409                              Block::iterator loc) {
410   auto &ops = src.getBody()->getOperations();
411   dest.getBody()->getOperations().splice(loc, ops, ops.begin(),
412                                          std::prev(ops.end()));
413 }
414 
415 /// Move the loop body of AffineForOp 'src' to the start of the body of
416 /// 'dest'.
417 static void moveLoopBody(AffineForOp src, AffineForOp dest) {
418   moveLoopBodyImpl(src, dest, dest.getBody()->begin());
419 }
420 
421 /// Constructs a tiled loop nest without setting the loop bounds, and moves the
422 /// body of the original loop nest into the tiled loop nest.
423 static void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
424                                    AffineForOp rootAffineForOp, unsigned width,
425                                    MutableArrayRef<AffineForOp> tiledLoops) {
426   Location loc = rootAffineForOp.getLoc();
427 
428   // The outermost among the loops as we add more.
429   Operation *topLoop = rootAffineForOp.getOperation();
430   AffineForOp innermostPointLoop;
431 
432   // Add intra-tile (or point) loops.
433   for (unsigned i = 0; i < width; i++) {
434     OpBuilder b(topLoop);
435     // Loop bounds will be set later.
436     AffineForOp pointLoop = b.create<AffineForOp>(loc, 0, 0);
437     pointLoop.getBody()->getOperations().splice(
438         pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
439         topLoop);
440     tiledLoops[2 * width - 1 - i] = pointLoop;
441     topLoop = pointLoop.getOperation();
442     if (i == 0)
443       innermostPointLoop = pointLoop;
444   }
445 
446   // Add tile space loops.
447   for (unsigned i = width; i < 2 * width; i++) {
448     OpBuilder b(topLoop);
449     // Loop bounds will be set later.
450     AffineForOp tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
451     tileSpaceLoop.getBody()->getOperations().splice(
452         tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
453         topLoop);
454     tiledLoops[2 * width - i - 1] = tileSpaceLoop;
455     topLoop = tileSpaceLoop.getOperation();
456   }
457 
458   // Move the loop body of the original nest to the new one.
459   moveLoopBody(origLoops.back(), innermostPointLoop);
460 }
461 
462 /// Set lower and upper bounds of intra-tile loops for parametric tiling.
463 //  TODO: Handle non-constant lower bounds.
464 static void setIntraTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
465                                          AffineForOp newInterTileLoop,
466                                          AffineForOp newIntraTileLoop,
467                                          Value tileSize) {
468   // The lower bound for the intra-tile loop is represented by the affine map
469   // (%i, %t0) -> ((%i - %origlb) * %t0 + %origlb). Similarly, the upper bound
470   // for the intra-tile loop is represented by the affine map
471   // (%i, %t0) -> ((%i - %origlb) * %t0 + %t0 * %origLoopStep + %origlb), where
472   // %i is the IV of the corresponding inter-tile loop, %t0 is the corresponding
473   // tiling parameter, and %origlb and %origLoopStep are the lower bound and the
474   // step of the original loop.
475 
476   assert(origLoop.hasConstantLowerBound() &&
477          "expected input loops to have constant lower bound.");
478 
479   // Get lower bound of original loop as an affine expression.
480   AffineExpr origLowerBoundExpr;
481   origLowerBoundExpr =
482       b.getAffineConstantExpr(origLoop.getConstantLowerBound());
483 
484   // Add dim operands from original lower/upper bound.
485   SmallVector<Value, 4> lbOperands, ubOperands;
486   AffineBound lb = origLoop.getLowerBound();
487   AffineBound ub = origLoop.getUpperBound();
488   lbOperands.reserve(lb.getNumOperands() + 2);
489   ubOperands.reserve(ub.getNumOperands() + 2);
490   AffineMap origLbMap = lb.getMap();
491   AffineMap origUbMap = ub.getMap();
492   for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
493     lbOperands.push_back(lb.getOperand(j));
494   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
495     ubOperands.push_back(ub.getOperand(j));
496 
497   // Add a new dim operand in lb/ubOperands corresponding to the origLoop
498   // IV.
499   lbOperands.push_back(newInterTileLoop.getInductionVar());
500   ubOperands.push_back(newInterTileLoop.getInductionVar());
501 
502   // Get loop IV as an affine expression for lower/upper bound. Size of
503   // lb/ubOperands is guaranteed to be at least one.
504   AffineExpr lbLoopIvExpr = b.getAffineDimExpr(lbOperands.size() - 1);
505   AffineExpr ubLoopIvExpr = b.getAffineDimExpr(ubOperands.size() - 1);
506 
507   // Add symbol operands from original lower/upper bound.
508   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
509     lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
510   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
511     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
512 
513   // Add a new symbol operand which is the tile size for this loop.
514   lbOperands.push_back(tileSize);
515   ubOperands.push_back(tileSize);
516 
517   SmallVector<AffineExpr, 4> lbBoundExprs;
518   SmallVector<AffineExpr, 4> ubBoundExprs;
519   lbBoundExprs.reserve(origLbMap.getNumResults());
520   ubBoundExprs.reserve(origUbMap.getNumResults());
521 
522   // Get tiling parameter as an affine expression for lb/ub.
523   AffineExpr lbTileParameter = b.getAffineSymbolExpr(origLbMap.getNumSymbols());
524   AffineExpr ubTileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
525 
526   // Insert lb: ((inter-tile loop IV - origlb) * tilingParameter) + origlb.
527   lbBoundExprs.push_back(
528       ((lbLoopIvExpr - origLowerBoundExpr) * lbTileParameter) +
529       origLowerBoundExpr);
530 
531   // Get the origLoopStep as an affine expression.
532   AffineExpr origLoopStep = b.getAffineConstantExpr(origLoop.getStepAsInt());
533 
534   // Insert ub: ((inter-tile loop IV - origlb) * tilingParameter) +
535   // (tilingParameter * origLoopStep) + origlb.
536   ubBoundExprs.push_back(
537       ((ubLoopIvExpr - origLowerBoundExpr) * ubTileParameter) +
538       (ubTileParameter * origLoopStep) + origLowerBoundExpr);
539 
540   ubBoundExprs.append(origUbMap.getResults().begin(),
541                       origUbMap.getResults().end());
542 
543   AffineMap lbMap =
544       AffineMap::get(origLbMap.getNumDims() + 1, origLbMap.getNumSymbols() + 1,
545                      lbBoundExprs, b.getContext());
546   newIntraTileLoop.setLowerBound(lbOperands, lbMap);
547 
548   AffineMap ubMap =
549       AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols() + 1,
550                      ubBoundExprs, b.getContext());
551   newIntraTileLoop.setUpperBound(ubOperands, ubMap);
552 
553   // Original loop step must be preserved.
554   newIntraTileLoop.setStep(origLoop.getStepAsInt());
555 }
556 
557 /// Set lower and upper bounds of inter-tile loops for parametric tiling.
558 //  TODO: Handle non-constant lower bounds.
559 static void setInterTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
560                                          AffineForOp newLoop, Value tileSize) {
561   OperandRange newLbOperands = origLoop.getLowerBoundOperands();
562 
563   // The lower bounds for inter-tile loops are same as the corresponding lower
564   // bounds of original loops.
565   newLoop.setLowerBound(newLbOperands, origLoop.getLowerBoundMap());
566 
567   // The new upper bound map for inter-tile loops, assuming constant lower
568   // bounds, is now originalLowerBound + ceildiv((originalUpperBound -
569   // originalLowerBound), tilingParameter), where the tiling parameter is the
570   // respective tile size for that loop. For example, if the original ubmap was
571   // () -> (1024), the new map will be
572   // ()[s0] -> (lb + (1024 - lb) ceildiv s0), where s0 is the tiling parameter.
573   // Therefore a new symbol operand is inserted in the map and the result
574   // expression is overwritten.
575 
576   assert(origLoop.hasConstantLowerBound() &&
577          "expected input loops to have constant lower bound.");
578 
579   // Get lower bound of original loop as an affine expression.
580   AffineExpr origLowerBoundExpr;
581   origLowerBoundExpr =
582       b.getAffineConstantExpr(origLoop.getConstantLowerBound());
583 
584   // Add dim operands from original upper bound.
585   SmallVector<Value, 4> ubOperands;
586   AffineBound ub = origLoop.getUpperBound();
587   ubOperands.reserve(ub.getNumOperands() + 1);
588   AffineMap origUbMap = ub.getMap();
589   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
590     ubOperands.push_back(ub.getOperand(j));
591 
592   // Add symbol operands from original upper bound.
593   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
594     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
595 
596   // Add a new symbol operand which is the tile size for this loop.
597   ubOperands.push_back(tileSize);
598 
599   // Get tiling parameter as an affine expression.
600   AffineExpr tileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
601 
602   SmallVector<AffineExpr, 4> boundExprs;
603   boundExprs.reserve(origUbMap.getNumResults());
604   int64_t origUpperBound;
605   AffineExpr origUpperBoundExpr;
606 
607   // If upper bound for the original loop is constant, then the constant can
608   // be obtained as an affine expression straight away.
609   if (origLoop.hasConstantUpperBound()) {
610     origUpperBound = origLoop.getConstantUpperBound();
611 
612     // Get original constant upper bound as an affine expression.
613     origUpperBoundExpr = b.getAffineConstantExpr(origUpperBound);
614 
615     // Insert the bound as originalLowerBound + ceildiv((originalUpperBound -
616     // originalLowerBound), tilingParameter).
617     boundExprs.push_back(
618         origLowerBoundExpr +
619         (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
620   } else {
621     // If the upper bound of the original loop is not constant, then two cases
622     // are possible, although their handling is the same. 1.) The result of
623     // ubmap has only one result expression. For e.g.
624     //    affine.for %i = 5 to %ub
625     //
626     // A symbol operand is added which represents the tiling parameter. The
627     // new loop bounds here will be like ()[s0, s1] -> ((s0 - 5) ceildiv s1 + 5)
628     // where 's0' is the original upper bound and 's1' is the tiling
629     // parameter. 2.) When ubMap has more than one result expression. For e.g.
630     //    #map0 = affine_map<()[s0, s1] -> (s0, s1)>
631     //    affine.for %i = 5 to min #map0()[%s0, %s1]
632     //
633     // A symbol operand is added which represents the tiling parameter. The
634     // new loop bounds will be like ()[s0, s1, s2] -> ((s0 - 5) ceildiv s2 + 5,
635     // (s1 - 5) ceildiv s2 + 5), where s2 is the tiling parameter.
636 
637     // Insert the bounds as originalLowerBound + ceildiv((originalUpperBound -
638     // originalLowerBound), tilingParameter).
639     for (AffineExpr origUpperBoundExpr : origUbMap.getResults())
640       boundExprs.push_back(
641           origLowerBoundExpr +
642           (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
643   }
644 
645   AffineMap ubMap =
646       AffineMap::get(origUbMap.getNumDims(), origUbMap.getNumSymbols() + 1,
647                      boundExprs, b.getContext());
648   newLoop.setUpperBound(ubOperands, ubMap);
649 
650   // Original loop step must be preserved.
651   newLoop.setStep(origLoop.getStepAsInt());
652 }
653 
654 /// Constructs and sets new loop bounds after tiling for the case of
655 /// hyper-rectangular index sets, where the bounds of one dimension do not
656 /// depend on other dimensions and tiling parameters are captured from SSA
657 /// values. Bounds of each dimension can thus be treated independently,
658 /// and deriving the new bounds is much simpler and faster than for the case of
659 /// tiling arbitrary polyhedral shapes.
660 static void constructParametricallyTiledIndexSetHyperRect(
661     MutableArrayRef<AffineForOp> origLoops,
662     MutableArrayRef<AffineForOp> newLoops, ArrayRef<Value> tileSizes) {
663   assert(!origLoops.empty() && "expected at least one loop in band");
664   assert(origLoops.size() == tileSizes.size() &&
665          "expected tiling parameter for each loop in band.");
666 
667   OpBuilder b(origLoops[0].getOperation());
668   unsigned width = origLoops.size();
669 
670   // Set bounds for tile space loops.
671   for (unsigned i = 0; i < width; ++i) {
672     setInterTileBoundsParametric(b, origLoops[i], newLoops[i], tileSizes[i]);
673   }
674 
675   // Set bounds for intra-tile loops.
676   for (unsigned i = 0; i < width; ++i) {
677     setIntraTileBoundsParametric(b, origLoops[i], newLoops[i],
678                                  newLoops[i + width], tileSizes[i]);
679   }
680 }
681 
682 /// Constructs and sets new loop bounds after tiling for the case of
683 /// hyper-rectangular index sets, where the bounds of one dimension do not
684 /// depend on other dimensions. Bounds of each dimension can thus be treated
685 /// independently, and deriving the new bounds is much simpler and faster
686 /// than for the case of tiling arbitrary polyhedral shapes.
687 static void
688 constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
689                                 MutableArrayRef<AffineForOp> newLoops,
690                                 ArrayRef<unsigned> tileSizes) {
691   assert(!origLoops.empty());
692   assert(origLoops.size() == tileSizes.size());
693 
694   OpBuilder b(origLoops[0].getOperation());
695   unsigned width = origLoops.size();
696 
697   // Bounds for tile space loops.
698   for (unsigned i = 0; i < width; i++) {
699     OperandRange newLbOperands = origLoops[i].getLowerBoundOperands();
700     OperandRange newUbOperands = origLoops[i].getUpperBoundOperands();
701     newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap());
702     newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap());
703     // If the step size of original loop is x and tileSize is y then after
704     // tiling the tile space loops' step size becomes x*y.
705     newLoops[i].setStep(tileSizes[i] * origLoops[i].getStepAsInt());
706   }
707   // Bounds for intra-tile loops.
708   for (unsigned i = 0; i < width; i++) {
709     int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
710     std::optional<uint64_t> mayBeConstantCount =
711         getConstantTripCount(origLoops[i]);
712     // The lower bound is just the corresponding tile-space loop's IV.
713     AffineMap lbMap = b.getDimIdentityMap();
714     newLoops[width + i].setLowerBound(
715         /*operands=*/newLoops[i].getInductionVar(), lbMap);
716     // The step size of an intra-tile loop is just the original loop's step size.
717     newLoops[width + i].setStep(origLoops[i].getStepAsInt());
718 
719     // Set the upper bound.
720     if (mayBeConstantCount && *mayBeConstantCount < tileSizes[i]) {
721       // Trip count is less than the tile size: upper bound is lower bound +
722       // trip count * stepSize.
723       AffineMap ubMap = b.getSingleDimShiftAffineMap(
724           *mayBeConstantCount * origLoops[i].getStepAsInt());
725       newLoops[width + i].setUpperBound(
726           /*operands=*/newLoops[i].getInductionVar(), ubMap);
727     } else if (largestDiv % tileSizes[i] != 0) {
728       // Intra-tile loop ii goes from i to min(i + tileSize * stepSize, ub_i).
729       // Construct the upper bound map; the operands are the original operands
730       // with 'i' (tile-space loop) appended to it. The new upper bound map is
731       // the original one with an additional expression i + tileSize * stepSize
732       // appended.
733 
734       // Add dim operands from original upper bound.
735       SmallVector<Value, 4> ubOperands;
736       AffineBound ub = origLoops[i].getUpperBound();
737       ubOperands.reserve(ub.getNumOperands() + 1);
738       AffineMap origUbMap = ub.getMap();
739       for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
740         ubOperands.push_back(ub.getOperand(j));
741 
742       // Add dim operand for new loop upper bound.
743       ubOperands.push_back(newLoops[i].getInductionVar());
744 
745       // Add symbol operands from original upper bound.
746       for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
747         ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
748 
749       SmallVector<AffineExpr, 4> boundExprs;
750       boundExprs.reserve(1 + origUbMap.getNumResults());
751       AffineExpr dim = b.getAffineDimExpr(origUbMap.getNumDims());
752       // The new upper bound map is the original one with an additional
753       // expression i + tileSize * stepSize (of original loop) appended.
754       boundExprs.push_back(dim + tileSizes[i] * origLoops[i].getStepAsInt());
755       boundExprs.append(origUbMap.getResults().begin(),
756                         origUbMap.getResults().end());
757       AffineMap ubMap =
758           AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols(),
759                          boundExprs, b.getContext());
760       newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap);
761     } else {
762       // No need of the min expression.
763       AffineExpr dim = b.getAffineDimExpr(0);
764       AffineMap ubMap = AffineMap::get(
765           1, 0, dim + tileSizes[i] * origLoops[i].getStepAsInt());
766       newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap);
767     }
768   }
769 }
770 
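/// Tiles the specified band of perfectly nested loops, creating tile-space
/// loops and intra-tile loops, using constant tile sizes. As a schematic
/// illustration (bounds shown informally; the generated bounds are affine maps
/// applied to the tile-space IVs), tiling
///   affine.for %i = 0 to 256 {
///     affine.for %j = 0 to 512 { ... }
///   }
/// with tile sizes {32, 64} yields a 4-d nest of the form
///   affine.for %ii = 0 to 256 step 32 {
///     affine.for %jj = 0 to 512 step 64 {
///       affine.for %i = %ii to %ii + 32 {
///         affine.for %j = %jj to %jj + 64 { ... }
///       }
///     }
///   }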
771 LogicalResult
772 mlir::affine::tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
773                                   ArrayRef<unsigned> tileSizes,
774                                   SmallVectorImpl<AffineForOp> *tiledNest) {
775   if (input.empty())
776     return success();
777 
778   if (failed(performPreTilingChecks(input, tileSizes)))
779     return failure();
780 
781   MutableArrayRef<AffineForOp> origLoops = input;
782   AffineForOp rootAffineForOp = origLoops[0];
783 
784   // Note that width is at least one since the band isn't empty.
785   unsigned width = input.size();
786   SmallVector<AffineForOp, 6> tiledLoops(2 * width);
787 
788   // Construct a tiled loop nest without setting their bounds. Bounds are
789   // set later.
790   constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
791 
792   SmallVector<Value, 8> origLoopIVs;
793   extractForInductionVars(input, &origLoopIVs);
794 
795   // Set loop bounds for the tiled loop nest.
796   constructTiledIndexSetHyperRect(origLoops, tiledLoops, tileSizes);
797 
798   // Replace original IVs with intra-tile loop IVs.
799   for (unsigned i = 0; i < width; i++)
800     origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
801 
802   // Erase the old loop nest.
803   rootAffineForOp.erase();
804 
805   if (tiledNest)
806     *tiledNest = std::move(tiledLoops);
807 
808   return success();
809 }
810 
811 /// Tiles the specified band of perfectly nested loops creating tile-space
812 /// loops and intra-tile loops, using SSA values as tiling parameters. A band
813 /// is a contiguous set of loops.
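/// As a schematic illustration (bounds shown informally; the generated bounds
/// are affine maps over the tile-space IV and the tile-size symbol), tiling
///   affine.for %i = 0 to 1024 { ... }
/// with an SSA tile size %ts yields
///   affine.for %ii = 0 to (1024 ceildiv %ts) {
///     affine.for %i = (%ii * %ts) to min(%ii * %ts + %ts, 1024) { ... }
///   }
/// where the outer (inter-tile) loop counts tiles and the inner (intra-tile)
/// loop covers one tile.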
814 LogicalResult mlir::affine::tilePerfectlyNestedParametric(
815     MutableArrayRef<AffineForOp> input, ArrayRef<Value> tileSizes,
816     SmallVectorImpl<AffineForOp> *tiledNest) {
817   if (input.empty())
818     return success();
819 
820   if (failed(performPreTilingChecks(input, tileSizes)))
821     return failure();
822 
823   MutableArrayRef<AffineForOp> origLoops = input;
824   AffineForOp rootAffineForOp = origLoops[0];
825   unsigned width = input.size();
826   SmallVector<AffineForOp, 6> tiledLoops(2 * width);
827 
828   // Construct a tiled loop nest without setting their bounds. Bounds are
829   // set later.
830   constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
831 
832   SmallVector<Value, 8> origLoopIVs;
833   extractForInductionVars(input, &origLoopIVs);
834 
835   // Set loop bounds for the tiled loop nest.
836   constructParametricallyTiledIndexSetHyperRect(origLoops, tiledLoops,
837                                                 tileSizes);
838 
839   // Replace original IVs with intra-tile loop IVs.
840   for (unsigned i = 0; i < width; i++)
841     origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
842 
843   // Erase the old loop nest.
844   rootAffineForOp.erase();
845 
846   if (tiledNest)
847     *tiledNest = std::move(tiledLoops);
848 
849   return success();
850 }
851 
852 /// Get the perfectly nested sequence of loops starting at the root of the
853 /// loop nest. A loop is perfectly nested iff the first op in the loop's body
854 /// is another AffineForOp and the second op is a terminator, i.e., each body
855 /// contains exactly one nested loop followed by the terminator.
856 void mlir::affine::getPerfectlyNestedLoops(
857     SmallVectorImpl<AffineForOp> &nestedLoops, AffineForOp root) {
858   for (unsigned i = 0; i < std::numeric_limits<unsigned>::max(); ++i) {
859     nestedLoops.push_back(root);
860     Block &body = root.getRegion().front();
861     if (body.begin() != std::prev(body.end(), 2))
862       return;
863 
864     root = dyn_cast<AffineForOp>(&body.front());
865     if (!root)
866       return;
867   }
868 }
869 
870 /// Identify valid and profitable bands of loops to tile. This is currently just
871 /// a temporary placeholder to test the mechanics of tiled code generation.
872 /// Returns all maximal outermost perfect loop nests to tile.
873 void mlir::affine::getTileableBands(
874     func::FuncOp f, std::vector<SmallVector<AffineForOp, 6>> *bands) {
875   // Get maximal perfect nest of 'affine.for' insts starting from root
876   // (inclusive).
877   for (AffineForOp forOp : f.getOps<AffineForOp>()) {
878     SmallVector<AffineForOp, 6> band;
879     getPerfectlyNestedLoops(band, forOp);
880     bands->push_back(band);
881   }
882 }
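
// A typical driver over the collected bands might look as follows (a
// hypothetical sketch; `func` stands for some func::FuncOp and the uniform
// tile size of 32 is made up):
//   std::vector<SmallVector<AffineForOp, 6>> bands;
//   getTileableBands(func, &bands);
//   for (auto &band : bands) {
//     SmallVector<unsigned, 6> tileSizes(band.size(), 32);
//     (void)tilePerfectlyNested(band, tileSizes, /*tiledNest=*/nullptr);
//   }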
883 
884 /// Unrolls this loop completely.
885 LogicalResult mlir::affine::loopUnrollFull(AffineForOp forOp) {
886   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
887   if (mayBeConstantTripCount.has_value()) {
888     uint64_t tripCount = *mayBeConstantTripCount;
889     if (tripCount == 0)
890       return success();
891     if (tripCount == 1)
892       return promoteIfSingleIteration(forOp);
893     return loopUnrollByFactor(forOp, tripCount);
894   }
895   return failure();
896 }
897 
898 /// Unrolls this loop by the specified factor or by the trip count (if
899 /// constant), whichever is lower.
900 LogicalResult mlir::affine::loopUnrollUpToFactor(AffineForOp forOp,
901                                                  uint64_t unrollFactor) {
902   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
903   if (mayBeConstantTripCount.has_value() &&
904       *mayBeConstantTripCount < unrollFactor)
905     return loopUnrollByFactor(forOp, *mayBeConstantTripCount);
906   return loopUnrollByFactor(forOp, unrollFactor);
907 }
908 
909 /// Generates 'unrollFactor' unrolled copies of the ops in 'loopBodyBlock' (the
910 /// body of the loop being unrolled), with associated induction variable
911 /// 'forOpIV', calling 'ivRemapFn' to remap 'forOpIV' for each unrolled body.
912 /// If specified, annotates the ops in each unrolled iteration using annotateFn.
913 static void generateUnrolledLoop(
914     Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
915     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
916     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
917     ValueRange iterArgs, ValueRange yieldedValues) {
918   // Builder to insert unrolled bodies just before the terminator of the body of
919   // 'forOp'.
920   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
921 
922   constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
923   if (!annotateFn)
924     annotateFn = defaultAnnotateFn;
925 
926   // Keep an iterator to the last non-terminator operation in the original block
927   // so that we know what to clone (since we are doing this in-place).
928   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
929 
930   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
931   SmallVector<Value, 4> lastYielded(yieldedValues);
932 
933   for (unsigned i = 1; i < unrollFactor; i++) {
934     IRMapping operandMap;
935 
936     // Prepare operand map.
937     operandMap.map(iterArgs, lastYielded);
938 
939     // If the induction variable is used, create a remapping to the value for
940     // this unrolled instance.
941     if (!forOpIV.use_empty()) {
942       Value ivUnroll = ivRemapFn(i, forOpIV, builder);
943       operandMap.map(forOpIV, ivUnroll);
944     }
945 
946     // Clone the original body of 'forOp'.
947     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
948       Operation *clonedOp = builder.clone(*it, operandMap);
949       annotateFn(i, clonedOp, builder);
950     }
951 
952     // Update yielded values. If the yielded value is defined outside the
953     // `loopBodyBlock` or if it is a BlockArgument then it won't be cloned, thus
954     // the `lastYielded` value remains unchanged. Else, update the `lastYielded`
955     // value with the clone corresponding to the yielded value.
956     for (unsigned i = 0, e = lastYielded.size(); i < e; i++) {
957       Operation *defOp = yieldedValues[i].getDefiningOp();
958       if (defOp && defOp->getBlock() == loopBodyBlock)
959         lastYielded[i] = operandMap.lookup(yieldedValues[i]);
960     }
961   }
962 
963   // Make sure we annotate the Ops in the original body. We do this last so that
964   // any annotations are not copied into the cloned Ops above.
965   for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
966     annotateFn(0, &*it, builder);
967 
968   // Update operands of the yield statement.
969   loopBodyBlock->getTerminator()->setOperands(lastYielded);
970 }
971 
972 /// Helper to generate cleanup loop for unroll or unroll-and-jam when the trip
973 /// count is not a multiple of `unrollFactor`.
974 static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
975                                                   uint64_t unrollFactor) {
976   // Insert the cleanup loop right after 'forOp'.
977   OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
978   auto cleanupForOp = cast<AffineForOp>(builder.clone(*forOp));
979 
980   // Update uses of `forOp` results. `cleanupForOp` should use `forOp`'s results
981   // and produce results for the original users of `forOp`'s results.
982   auto results = forOp.getResults();
983   auto cleanupResults = cleanupForOp.getResults();
984   auto cleanupIterOperands = cleanupForOp.getInits();
985 
986   for (auto e : llvm::zip(results, cleanupResults, cleanupIterOperands)) {
987     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
988     cleanupForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
989   }
990 
991   AffineMap cleanupMap;
992   SmallVector<Value, 4> cleanupOperands;
993   getCleanupLoopLowerBound(forOp, unrollFactor, cleanupMap, cleanupOperands);
994   if (!cleanupMap)
995     return failure();
996 
997   cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
998   // Promote the loop body up if this has turned into a single iteration loop.
999   (void)promoteIfSingleIteration(cleanupForOp);
1000 
1001   // Adjust upper bound of the original loop; this is the same as the lower
1002   // bound of the cleanup loop.
1003   forOp.setUpperBound(cleanupOperands, cleanupMap);
1004   return success();
1005 }
1006 
1007 /// Unrolls this loop by the specified factor. Returns success if the loop
1008 /// is successfully unrolled.
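/// As an illustrative sketch (not taken from a test), unrolling
///   affine.for %i = 0 to 10 {
///     "op"(%i) : (index) -> ()
///   }
/// by a factor of 4 rewrites the loop to step by 4 over [0, 8) with four
/// copies of the body (using %i, %i + 1, %i + 2 and %i + 3), and generates a
/// cleanup loop over [8, 10) for the remaining iterations.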
1009 LogicalResult mlir::affine::loopUnrollByFactor(
1010     AffineForOp forOp, uint64_t unrollFactor,
1011     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1012     bool cleanUpUnroll) {
1013   assert(unrollFactor > 0 && "unroll factor should be positive");
1014 
1015   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1016   if (unrollFactor == 1) {
1017     if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
1018         failed(promoteIfSingleIteration(forOp)))
1019       return failure();
1020     return success();
1021   }
1022 
1023   // Nothing in the loop body other than the terminator.
1024   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
1025     return success();
1026 
1027   // If the trip count is lower than the unroll factor, no unrolled body.
1028   if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollFactor) {
1029     if (cleanUpUnroll) {
1030       // Unroll the cleanup loop if cleanUpUnroll is specified.
1031       return loopUnrollFull(forOp);
1032     }
1033 
1034     return failure();
1035   }
1036 
1037   // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
1038   if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
1039     // Loops where the lower bound is a max expression or the upper bound is
1040     // a min expression and the unroll factor doesn't divide the trip count
1041     // can't be unrolled since the lower bound of the cleanup loop in such cases
1042     // cannot be expressed as an affine function or a max over affine functions.
1043     if (forOp.getLowerBoundMap().getNumResults() != 1 ||
1044         forOp.getUpperBoundMap().getNumResults() != 1)
1045       return failure();
1046     if (cleanUpUnroll)
1047       // Force unroll including cleanup loop
1048       return loopUnrollFull(forOp);
1049     if (failed(generateCleanupLoopForUnroll(forOp, unrollFactor)))
1050       assert(false && "cleanup loop lower bound map for single result lower "
1051                       "and upper bound maps can always be determined");
1052   }
1053 
1054   ValueRange iterArgs(forOp.getRegionIterArgs());
1055   auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
1056 
1057   // Scale the step of loop being unrolled by unroll factor.
1058   int64_t step = forOp.getStepAsInt();
1059   forOp.setStep(step * unrollFactor);
1060   generateUnrolledLoop(
1061       forOp.getBody(), forOp.getInductionVar(), unrollFactor,
1062       [&](unsigned i, Value iv, OpBuilder b) {
1063         // iv' = iv + i * step
1064         auto d0 = b.getAffineDimExpr(0);
1065         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
1066         return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, iv);
1067       },
1068       /*annotateFn=*/annotateFn,
1069       /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
1070 
1071   // Promote the loop body up if this has turned into a single iteration loop.
1072   (void)promoteIfSingleIteration(forOp);
1073   return success();
1074 }
1075 
1076 LogicalResult mlir::affine::loopUnrollJamUpToFactor(AffineForOp forOp,
1077                                                     uint64_t unrollJamFactor) {
1078   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1079   if (mayBeConstantTripCount.has_value() &&
1080       *mayBeConstantTripCount < unrollJamFactor)
1081     return loopUnrollJamByFactor(forOp, *mayBeConstantTripCount);
1082   return loopUnrollJamByFactor(forOp, unrollJamFactor);
1083 }
1084 
1085 /// Check if all control operands of all loops are defined outside of `forOp`
1086 /// and return false if not.
1087 static bool areInnerBoundsInvariant(AffineForOp forOp) {
1088   auto walkResult = forOp.walk([&](AffineForOp aForOp) {
1089     for (auto controlOperand : aForOp.getControlOperands()) {
1090       if (!forOp.isDefinedOutsideOfLoop(controlOperand))
1091         return WalkResult::interrupt();
1092     }
1093     return WalkResult::advance();
1094   });
1095   return !walkResult.wasInterrupted();
1096 }
1097 
1098 /// Unrolls and jams this loop by the specified factor.
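/// As an illustrative sketch (not taken from a test), unroll-and-jamming
///   affine.for %i = 0 to 8 {
///     affine.for %j = 0 to 16 {
///       "op"(%i, %j) : (index, index) -> ()
///     }
///   }
/// by a factor of 2 steps %i by 2 and jams the two unrolled copies into the
/// single inner loop:
///   affine.for %i = 0 to 8 step 2 {
///     affine.for %j = 0 to 16 {
///       "op"(%i, %j) : (index, index) -> ()
///       "op"(%i + 1, %j) : (index, index) -> ()
///     }
///   }
/// (with %i + 1 materialized via an affine.apply).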
1099 LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
1100                                                   uint64_t unrollJamFactor) {
1101   assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
1102 
1103   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1104   if (unrollJamFactor == 1) {
1105     if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
1106         failed(promoteIfSingleIteration(forOp)))
1107       return failure();
1108     return success();
1109   }
1110 
1111   // Nothing in the loop body other than the terminator.
1112   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
1113     return success();
1114 
1115   // If the trip count is lower than the unroll jam factor, no unroll jam.
1116   if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollJamFactor) {
1117     LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n");
1118     return failure();
1119   }
1120 
1121   // If any control operand of any inner loop of `forOp` is defined within
1122   // `forOp`, no unroll jam.
1123   if (!areInnerBoundsInvariant(forOp))
1124     return failure();
1125 
1126   // Gather all sub-blocks to jam upon the loop being unrolled.
1127   JamBlockGatherer<AffineForOp> jbg;
1128   jbg.walk(forOp);
1129   auto &subBlocks = jbg.subBlocks;
1130 
1131   // Collect loops with iter_args.
1132   SmallVector<AffineForOp, 4> loopsWithIterArgs;
1133   forOp.walk([&](AffineForOp aForOp) {
1134     if (aForOp.getNumIterOperands() > 0)
1135       loopsWithIterArgs.push_back(aForOp);
1136   });
1137 
1138   // Get supported reductions to be used for creating reduction ops at the end.
1139   SmallVector<LoopReduction> reductions;
1140   if (forOp.getNumIterOperands() > 0)
1141     getSupportedReductions(forOp, reductions);
1142 
1143   // Generate the cleanup loop if trip count isn't a multiple of
1144   // unrollJamFactor.
1145   if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) {
1146     // Loops where the lower bound is a max expression or the upper bound is
1147     // a min expression and the unroll factor doesn't divide the trip count
1148     // can't be unrolled since the lower bound of the cleanup loop in such cases
1149     // cannot be expressed as an affine function or a max over affine functions.
1150     if (forOp.getLowerBoundMap().getNumResults() != 1 ||
1151         forOp.getUpperBoundMap().getNumResults() != 1)
1152       return failure();
1153     if (failed(generateCleanupLoopForUnroll(forOp, unrollJamFactor)))
1154       assert(false && "cleanup loop lower bound map for single result lower "
1155                       "and upper bound maps can always be determined");
1156   }
1157 
1158   // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
1159   // iteration. There are (`unrollJamFactor` - 1) iterations.
1160   SmallVector<IRMapping, 4> operandMaps(unrollJamFactor - 1);
1161 
1162   // For any loop with iter_args, replace it with a new loop that has
1163   // `unrollJamFactor` copies of its iterOperands, iter_args and yield
1164   // operands.
1165   SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
1166   IRRewriter rewriter(forOp.getContext());
1167   for (AffineForOp oldForOp : loopsWithIterArgs) {
1168     SmallVector<Value> dupIterOperands, dupYieldOperands;
1169     ValueRange oldIterOperands = oldForOp.getInits();
1170     ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
1171     ValueRange oldYieldOperands =
1172         cast<AffineYieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
1173     // Get additional iterOperands, iterArgs, and yield operands. We will
1174     // fix iterOperands and yield operands after cloning of sub-blocks.
1175     for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1176       dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
1177       dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
1178     }
1179     // Create a new loop with additional iterOperands, iter_args and yield
1180     // operands. This new loop will take the loop body of the original loop.
1181     bool forOpReplaced = oldForOp == forOp;
1182     AffineForOp newForOp =
1183         cast<AffineForOp>(*oldForOp.replaceWithAdditionalYields(
1184             rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
1185             [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
1186               return dupYieldOperands;
1187             }));
1188     newLoopsWithIterArgs.push_back(newForOp);
1189     // `forOp` has been replaced with a new loop.
1190     if (forOpReplaced)
1191       forOp = newForOp;
1192     // Update `operandMaps` for `newForOp` iterArgs and results.
1193     ValueRange newIterArgs = newForOp.getRegionIterArgs();
1194     unsigned oldNumIterArgs = oldIterArgs.size();
1195     ValueRange newResults = newForOp.getResults();
1196     unsigned oldNumResults = newResults.size() / unrollJamFactor;
1197     assert(oldNumIterArgs == oldNumResults &&
1198            "oldNumIterArgs must be the same as oldNumResults");
1199     for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1200       for (unsigned j = 0; j < oldNumIterArgs; ++j) {
1201         // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
1202         // results. Update `operandMaps[i - 1]` to map old iterArgs and results
1203         // to those in the `i`th new set.
1204         operandMaps[i - 1].map(newIterArgs[j],
1205                                newIterArgs[i * oldNumIterArgs + j]);
1206         operandMaps[i - 1].map(newResults[j],
1207                                newResults[i * oldNumResults + j]);
1208       }
1209     }
1210   }
1211 
1212   // Scale the step of loop being unroll-jammed by the unroll-jam factor.
1213   int64_t step = forOp.getStepAsInt();
1214   forOp.setStep(step * unrollJamFactor);
1215 
1216   auto forOpIV = forOp.getInductionVar();
1217   // Unroll and jam (appends unrollJamFactor - 1 additional copies).
1218   for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1219     for (auto &subBlock : subBlocks) {
1220       // Builder to insert unroll-jammed bodies. Insert right at the end of
1221       // the sub-block.
1222       OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
1223 
1224       // If the induction variable is used, create a remapping to the value for
1225       // this unrolled instance.
1226       if (!forOpIV.use_empty()) {
1227         // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
1228         auto d0 = builder.getAffineDimExpr(0);
1229         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
1230         auto ivUnroll =
1231             builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
1232         operandMaps[i - 1].map(forOpIV, ivUnroll);
1233       }
1234       // Clone the sub-block being unroll-jammed.
1235       for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
1236         builder.clone(*it, operandMaps[i - 1]);
1237     }
1238     // Fix iterOperands and yield op operands of newly created loops.
1239     for (auto newForOp : newLoopsWithIterArgs) {
1240       unsigned oldNumIterOperands =
1241           newForOp.getNumIterOperands() / unrollJamFactor;
1242       unsigned numControlOperands = newForOp.getNumControlOperands();
1243       auto yieldOp = cast<AffineYieldOp>(newForOp.getBody()->getTerminator());
1244       unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
1245       assert(oldNumIterOperands == oldNumYieldOperands &&
1246              "oldNumIterOperands must be the same as oldNumYieldOperands");
1247       for (unsigned j = 0; j < oldNumIterOperands; ++j) {
1248         // The `i`th duplication of an old iterOperand or yield op operand
1249         // needs to be replaced with a mapped value from `operandMaps[i - 1]`
1250         // if such a mapped value exists.
1251         newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
1252                             operandMaps[i - 1].lookupOrDefault(
1253                                 newForOp.getOperand(numControlOperands + j)));
1254         yieldOp.setOperand(
1255             i * oldNumYieldOperands + j,
1256             operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
1257       }
1258     }
1259   }
1260   if (forOp.getNumResults() > 0) {
1261     // Create reduction ops to combine every `unrollJamFactor` related results
1262     // into one value. For example, for %0:2 = affine.for ... and addf, we add
1263     // %1 = arith.addf %0#0, %0#1, and replace the following uses of %0#0 with
1264     // %1.
1265     rewriter.setInsertionPointAfter(forOp);
1266     auto loc = forOp.getLoc();
1267     unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
1268     for (LoopReduction &reduction : reductions) {
1269       unsigned pos = reduction.iterArgPosition;
1270       Value lhs = forOp.getResult(pos);
1271       Value rhs;
1272       SmallPtrSet<Operation *, 4> newOps;
1273       for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1274         rhs = forOp.getResult(i * oldNumResults + pos);
1275         // Create ops based on reduction type.
1276         lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
1277         if (!lhs)
1278           return failure();
1279         Operation *op = lhs.getDefiningOp();
1280         assert(op && "Reduction op should have been created");
1281         newOps.insert(op);
1282       }
1283       // Replace all uses except those in newly created reduction ops.
1284       forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
1285     }
1286   }
1287 
1288   // Promote the loop body up if this has turned into a single iteration loop.
1289   (void)promoteIfSingleIteration(forOp);
1290   return success();
1291 }
1292 
1293 /// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
1294 /// nested within 'forOpA' as the only non-terminator operation in its block.
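///
/// For example (illustrative IR; SSA names are placeholders), interchanging
///
/// ```mlir
///   affine.for %i = 0 to 10 {
///     affine.for %j = 0 to 20 {
///       "test.use"(%i, %j) : (index, index) -> ()
///     }
///   }
/// ```
///
/// makes the %j loop the outer loop and the %i loop the inner one, leaving the
/// body and all uses of %i and %j intact.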
1295 void mlir::affine::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
1296   assert(&*forOpA.getBody()->begin() == forOpB.getOperation());
1297   auto &forOpABody = forOpA.getBody()->getOperations();
1298   auto &forOpBBody = forOpB.getBody()->getOperations();
1299 
1300   // 1) Splice forOpA's non-terminator operations (which is just forOpB) just
1301   // before forOpA (in forOpA's parent's block); this should leave forOpA's
1302   // body containing only the terminator.
1303   forOpA->getBlock()->getOperations().splice(Block::iterator(forOpA),
1304                                              forOpABody, forOpABody.begin(),
1305                                              std::prev(forOpABody.end()));
1306   // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's
1307   // body (this leaves forOpB's body containing only the terminator).
1308   forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(),
1309                     std::prev(forOpBBody.end()));
1310   // 3) Splice forOpA into the beginning of forOpB's body.
1311   forOpBBody.splice(forOpBBody.begin(), forOpA->getBlock()->getOperations(),
1312                     Block::iterator(forOpA));
1313 }
1314 
1315 // Checks each dependence component against the permutation to see if the
1316 // desired loop interchange would violate dependences by making the
1317 // dependence component lexicographically negative.
1318 static bool checkLoopInterchangeDependences(
1319     const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec,
1320     ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
1321   // Invert permutation map.
1322   unsigned maxLoopDepth = loops.size();
1323   SmallVector<unsigned, 4> loopPermMapInv;
1324   loopPermMapInv.resize(maxLoopDepth);
1325   for (unsigned i = 0; i < maxLoopDepth; ++i)
1326     loopPermMapInv[loopPermMap[i]] = i;
1327 
1328   // Check each dependence component against the permutation to see if the
1329   // desired loop interchange permutation would make the dependence vectors
1330   // lexicographically negative.
1331   // Example 1: [-1, 1][0, 0]
1332   // Example 2: [0, 0][-1, 1]
1333   for (const auto &depComps : depCompsVec) {
1334     assert(depComps.size() >= maxLoopDepth);
1335     // Check if the first non-zero dependence component is positive.
1336     // This iterates through loops in the desired order.
1337     for (unsigned j = 0; j < maxLoopDepth; ++j) {
1338       unsigned permIndex = loopPermMapInv[j];
1339       assert(depComps[permIndex].lb);
1340       int64_t depCompLb = *depComps[permIndex].lb;
1341       if (depCompLb > 0)
1342         break;
1343       if (depCompLb < 0)
1344         return false;
1345     }
1346   }
1347   return true;
1348 }
1349 
1350 /// Checks if the loop interchange permutation 'loopPermMap' of the perfectly
1351 /// nested sequence of loops in 'loops' would violate dependences.
1352 bool mlir::affine::isValidLoopInterchangePermutation(
1353     ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
1354   // Gather dependence components for dependences between all ops in loop nest
1355   // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
1356   assert(loopPermMap.size() == loops.size());
1357   unsigned maxLoopDepth = loops.size();
1358   std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
1359   getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
1360   return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
1361 }
1362 
1363 /// Returns true if `loops` is a perfectly nested loop nest, where loops appear
1364 /// in it from outermost to innermost.
1365 bool LLVM_ATTRIBUTE_UNUSED
1366 mlir::affine::isPerfectlyNested(ArrayRef<AffineForOp> loops) {
1367   assert(!loops.empty() && "no loops provided");
1368 
1369   // We already know that the block can't be empty.
1370   auto hasTwoElements = [](Block *block) {
1371     auto secondOpIt = std::next(block->begin());
1372     return secondOpIt != block->end() && &*secondOpIt == &block->back();
1373   };
1374 
1375   auto enclosingLoop = loops.front();
1376   for (auto loop : loops.drop_front()) {
1377     auto parentForOp = dyn_cast<AffineForOp>(loop->getParentOp());
1378     // parentForOp's body should be just this loop and the terminator.
1379     if (parentForOp != enclosingLoop || !hasTwoElements(parentForOp.getBody()))
1380       return false;
1381     enclosingLoop = loop;
1382   }
1383   return true;
1384 }
1385 
1386 // input[i] should move from position i -> permMap[i]. Returns the position in
1387 // `input` that becomes the new outermost loop.
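//
// For example (illustrative), with input = {%i, %j, %k} listed outermost to
// innermost and permMap = {2, 0, 1}, the %j loop becomes outermost, %k the
// middle loop, and %i the innermost loop; the returned value is 1, the
// position of %j in `input`.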
1388 unsigned mlir::affine::permuteLoops(ArrayRef<AffineForOp> input,
1389                                     ArrayRef<unsigned> permMap) {
1390   assert(input.size() == permMap.size() && "invalid permutation map size");
1391   // Check whether the permutation spec is valid. This is a small vector - we'll
1392   // just sort and check if it's iota.
1393   SmallVector<unsigned, 4> checkPermMap(permMap);
1394   llvm::sort(checkPermMap);
1395   if (llvm::any_of(llvm::enumerate(checkPermMap),
1396                    [](const auto &en) { return en.value() != en.index(); }))
1397     assert(false && "invalid permutation map");
1398 
1399   // Nothing to do.
1400   if (input.size() < 2)
1401     return 0;
1402 
1403   assert(isPerfectlyNested(input) && "input not perfectly nested");
1404 
1405   // Compute the inverse mapping, invPermMap: since input[i] goes to position
1406   // permMap[i], position i of the permuted nest is at input[invPermMap[i]].
1407   SmallVector<std::pair<unsigned, unsigned>, 4> invPermMap;
1408   for (unsigned i = 0, e = input.size(); i < e; ++i)
1409     invPermMap.push_back({permMap[i], i});
1410   llvm::sort(invPermMap);
1411 
1412   // Move the innermost loop body to the loop that would be the innermost in the
1413   // permuted nest (only if the innermost loop is going to change).
1414   if (permMap.back() != input.size() - 1) {
1415     Block *destBody = ((AffineForOp)input[invPermMap.back().second]).getBody();
1416     Block *srcBody = ((AffineForOp)input.back()).getBody();
1417     destBody->getOperations().splice(destBody->begin(),
1418                                      srcBody->getOperations(), srcBody->begin(),
1419                                      std::prev(srcBody->end()));
1420   }
1421 
1422   // We'll move each loop in `input` in the reverse order so that its body is
1423   // empty when we are moving it; this incurs zero copies and no erasing.
1424   for (int i = input.size() - 1; i >= 0; --i) {
1425     // If this has to become the outermost loop after permutation, add it to the
1426     // parent block of the original root.
1427     if (permMap[i] == 0) {
1428       // If the root remains the same, nothing to do.
1429       if (i == 0)
1430         continue;
1431       // Make input[i] the new outermost loop moving it into parentBlock.
1432       auto *parentBlock = input[0]->getBlock();
1433       parentBlock->getOperations().splice(Block::iterator(input[0]),
1434                                           input[i]->getBlock()->getOperations(),
1435                                           Block::iterator(input[i]));
1436       continue;
1437     }
1438 
1439     // If the parent in the permuted order is the same as in the original,
1440     // nothing to do.
1441     unsigned parentPosInInput = invPermMap[permMap[i] - 1].second;
1442     if (i > 0 && static_cast<unsigned>(i - 1) == parentPosInInput)
1443       continue;
1444 
1445     // Move input[i] to its surrounding loop in the transformed nest.
1446     auto *destBody = ((AffineForOp)input[parentPosInInput]).getBody();
1447     destBody->getOperations().splice(destBody->begin(),
1448                                      input[i]->getBlock()->getOperations(),
1449                                      Block::iterator(input[i]));
1450   }
1451 
1452   return invPermMap[0].second;
1453 }
1454 
1455 // Sinks all sequential loops to the innermost levels (while preserving
1456 // relative order among them) and moves all parallel loops to the
1457 // outermost (while again preserving relative order among them).
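//
// For example (illustrative), for a nest (%i, %j, %k) in which only the %i
// loop carries a dependence, the permutation computed below is (%j, %k, %i):
// the parallel loops %j and %k move outermost in their original relative
// order and the sequential loop %i is sunk innermost, provided the resulting
// permutation does not violate dependences.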
1458 AffineForOp mlir::affine::sinkSequentialLoops(AffineForOp forOp) {
1459   SmallVector<AffineForOp, 4> loops;
1460   getPerfectlyNestedLoops(loops, forOp);
1461   if (loops.size() < 2)
1462     return forOp;
1463 
1464   // Gather dependence components for dependences between all ops in loop nest
1465   // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
1466   unsigned maxLoopDepth = loops.size();
1467   std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
1468   getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
1469 
1470   // Mark loops as either parallel or sequential.
1471   SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
1472   for (auto &depComps : depCompsVec) {
1473     assert(depComps.size() >= maxLoopDepth);
1474     for (unsigned j = 0; j < maxLoopDepth; ++j) {
1475       DependenceComponent &depComp = depComps[j];
1476       assert(depComp.lb.has_value() && depComp.ub.has_value());
1477       if (*depComp.lb != 0 || *depComp.ub != 0)
1478         isParallelLoop[j] = false;
1479     }
1480   }
1481 
1482   unsigned numParallelLoops = llvm::count(isParallelLoop, true);
1483 
1484   // Compute permutation of loops that sinks sequential loops (and thus raises
1485   // parallel loops) while preserving relative order.
1486   SmallVector<unsigned, 4> loopPermMap(maxLoopDepth);
1487   unsigned nextSequentialLoop = numParallelLoops;
1488   unsigned nextParallelLoop = 0;
1489   for (unsigned i = 0; i < maxLoopDepth; ++i) {
1490     if (isParallelLoop[i]) {
1491       loopPermMap[i] = nextParallelLoop++;
1492     } else {
1493       loopPermMap[i] = nextSequentialLoop++;
1494     }
1495   }
1496 
1497   // Check if permutation 'loopPermMap' would violate dependences.
1498   if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
1499     return forOp;
1500   // Perform loop interchange according to permutation 'loopPermMap'.
1501   unsigned loopNestRootIndex = permuteLoops(loops, loopPermMap);
1502   return loops[loopNestRootIndex];
1503 }
1504 
1505 // Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
1506 // lower (resp. upper) loop bound. When called for both the lower and upper
1507 // bounds, the resulting IR resembles:
1508 //
1509 // ```mlir
1510 //    affine.for %i = max (`iv`, ...) to min (`iv` + `offset`, ...) {
1511 //      ...
1512 //    }
1513 // ```
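//
// For instance (illustrative), with *map = (d0)[s0] -> (d0 + s0), iv = %iv and
// offset = 32, the map becomes (d0, d1)[s0] -> (d0 + s0, d1 + 32), and %iv is
// inserted as the operand for the new dimension d1 (before canonicalization).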
1514 static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
1515                                 SmallVector<Value, 4> *operands,
1516                                 int64_t offset = 0) {
1517   auto bounds = llvm::to_vector<4>(map->getResults());
1518   bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
1519   operands->insert(operands->begin() + map->getNumDims(), iv);
1520   *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds,
1521                         b.getContext());
1522   canonicalizeMapAndOperands(map, operands);
1523 }
1524 
1525 // Stripmines `forOp` by `factor` and sinks it under each of the `targets`.
1526 // Stripmine-sink is a primitive building block for generalized tiling of
1527 // imperfectly nested loops.
1528 // This transformation is purely mechanical and does not check legality,
1529 // profitability or even structural correctness. It is the user's
1530 // responsibility to specify `targets` that are dominated by `forOp`.
1531 // Returns the new AffineForOps, one per target, nested immediately under
1532 // each of the `targets`.
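//
// For example (illustrative; ignoring map canonicalization), stripmining
// `affine.for %i = 0 to 128` by a factor of 32 rescales its step to 32 and
// creates, immediately inside each target, an inner loop roughly of the form
// `affine.for %ii = max(0, %i) to min(128, %i + 32)` with the original step;
// that inner loop takes over the target's body, and uses of %i in it are
// replaced with %ii.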
1533 static SmallVector<AffineForOp, 8>
1534 stripmineSink(AffineForOp forOp, uint64_t factor,
1535               ArrayRef<AffineForOp> targets) {
1536   auto originalStep = forOp.getStepAsInt();
1537   auto scaledStep = originalStep * factor;
1538   forOp.setStep(scaledStep);
1539 
1540   OpBuilder b(forOp->getBlock(), std::next(Block::iterator(forOp)));
1541 
1542   // Lower-bound map creation.
1543   auto lbMap = forOp.getLowerBoundMap();
1544   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1545   augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);
1546 
1547   // Upper-bound map creation.
1548   auto ubMap = forOp.getUpperBoundMap();
1549   SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1550   augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
1551                       /*offset=*/scaledStep);
1552 
1553   auto iv = forOp.getInductionVar();
1554   SmallVector<AffineForOp, 8> innerLoops;
1555   for (auto t : targets) {
1556     // Insert newForOp before the terminator of `t`.
1557     auto b = OpBuilder::atBlockTerminator(t.getBody());
1558     auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
1559                                           ubOperands, ubMap, originalStep);
1560     auto begin = t.getBody()->begin();
1561     // Skip the terminator and `newForOp`, which is just before the terminator.
1562     auto nOps = t.getBody()->getOperations().size() - 2;
1563     newForOp.getBody()->getOperations().splice(
1564         newForOp.getBody()->getOperations().begin(),
1565         t.getBody()->getOperations(), begin, std::next(begin, nOps));
1566     replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1567                                newForOp.getRegion());
1568     innerLoops.push_back(newForOp);
1569   }
1570 
1571   return innerLoops;
1572 }
1573 
1574 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1575 // Returns the new AffineForOp, nested immediately under `target`.
1576 template <typename SizeType>
1577 static AffineForOp stripmineSink(AffineForOp forOp, SizeType factor,
1578                                  AffineForOp target) {
1579   // TODO: Use cheap structural assertions that targets are nested under
1580   // forOp and that targets are not nested under each other when DominanceInfo
1581   // exposes the capability. It seems overkill to construct a whole function
1582   // dominance tree at this point.
1583   auto res = stripmineSink(forOp, factor, ArrayRef<AffineForOp>(target));
1584   assert(res.size() == 1 && "Expected 1 inner forOp");
1585   return res[0];
1586 }
1587 
1588 SmallVector<SmallVector<AffineForOp, 8>, 8>
1589 mlir::affine::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes,
1590                    ArrayRef<AffineForOp> targets) {
1591   SmallVector<SmallVector<AffineForOp, 8>, 8> res;
1592   SmallVector<AffineForOp, 8> currentTargets(targets);
1593   for (auto it : llvm::zip(forOps, sizes)) {
1594     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1595     res.push_back(step);
1596     currentTargets = step;
1597   }
1598   return res;
1599 }
1600 
1601 SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
1602                                                ArrayRef<uint64_t> sizes,
1603                                                AffineForOp target) {
1604   SmallVector<AffineForOp, 8> res;
1605   for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
1606     assert(loops.size() == 1);
1607     res.push_back(loops[0]);
1608   }
1609   return res;
1610 }
1611 
1612 LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
1613   if (loops.size() < 2)
1614     return success();
1615 
1616   AffineForOp innermost = loops.back();
1617   AffineForOp outermost = loops.front();
1618   AffineBound ub = outermost.getUpperBound();
1619   AffineMap origUbMap = ub.getMap();
1620   Location loc = outermost.getLoc();
1621   OpBuilder builder(outermost);
1622   for (AffineForOp loop : loops) {
1623     // We only work on normalized loops.
1624     if (loop.getStepAsInt() != 1 || !loop.hasConstantLowerBound() ||
1625         loop.getConstantLowerBound() != 0)
1626       return failure();
1627   }
1628   SmallVector<Value, 4> upperBoundSymbols;
1629   SmallVector<Value, 4> ubOperands(ub.getOperands().begin(),
1630                                    ub.getOperands().end());
1631 
1632   // 1. Store the upper bound of the outermost loop in a variable.
1633   Value prev;
1634   if (!llvm::hasSingleElement(origUbMap.getResults()))
1635     prev = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
1636   else
1637     prev = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
1638   upperBoundSymbols.push_back(prev);
1639 
1640   // 2. Emit code computing the upper bound of the coalesced loop as the
1641   // product of the number of iterations of all loops.
1642   for (AffineForOp loop : loops.drop_front()) {
1643     ub = loop.getUpperBound();
1644     origUbMap = ub.getMap();
1645     ubOperands = ub.getOperands();
1646     Value upperBound;
1647     // If the upper bound map has more than one result, take their minimum.
1648     if (!llvm::hasSingleElement(origUbMap.getResults()))
1649       upperBound = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
1650     else
1651       upperBound = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
1652     upperBoundSymbols.push_back(upperBound);
1653     SmallVector<Value, 4> operands;
1654     operands.push_back(prev);
1655     operands.push_back(upperBound);
1656     // Maintain running product of loop upper bounds.
1657     prev = builder.create<AffineApplyOp>(
1658         loc,
1659         AffineMap::get(/*dimCount=*/1,
1660                        /*symbolCount=*/1,
1661                        builder.getAffineDimExpr(0) *
1662                            builder.getAffineSymbolExpr(0)),
1663         operands);
1664   }
1665   // Set upper bound of the coalesced loop.
1666   AffineMap newUbMap = AffineMap::get(
1667       /*dimCount=*/0,
1668       /*symbolCount=*/1, builder.getAffineSymbolExpr(0), builder.getContext());
1669   outermost.setUpperBound(prev, newUbMap);
1670 
1671   builder.setInsertionPointToStart(outermost.getBody());
1672 
1673   // 3. Remap induction variables. For each original loop, the value of its
1674   // induction variable is obtained by dividing the induction variable of the
1675   // linearized loop by the product of the ranges of the loops nested inside
1676   // it, and then taking the result modulo this loop's own range (which
1677   // removes the contribution of the outer loops):
1678   //   iv_i = floordiv(iv_linear, product-of-ranges-of-nested-loops) mod range_i.
1679   // Compute these iteratively from the innermost loop by maintaining a
1680   // "running quotient" of divisions by the ranges.
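  // For example (illustrative), coalescing two normalized loops with upper
  // bounds 4 (outer) and 5 (inner) produces a single loop with upper bound 20;
  // for its induction variable %iv, the outer IV is recovered as
  // %iv floordiv 5 and the inner IV as %iv mod 5.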
1681   Value previous = outermost.getInductionVar();
1682   for (unsigned idx = loops.size(); idx > 0; --idx) {
1683     if (idx != loops.size()) {
1684       SmallVector<Value, 4> operands;
1685       operands.push_back(previous);
1686       operands.push_back(upperBoundSymbols[idx]);
1687       previous = builder.create<AffineApplyOp>(
1688           loc,
1689           AffineMap::get(
1690               /*dimCount=*/1, /*symbolCount=*/1,
1691               builder.getAffineDimExpr(0).floorDiv(
1692                   builder.getAffineSymbolExpr(0))),
1693           operands);
1694     }
1695     // Modified value of the induction variables of the nested loops after
1696     // coalescing.
1697     Value inductionVariable;
1698     if (idx == 1) {
1699       inductionVariable = previous;
1700     } else {
1701       SmallVector<Value, 4> applyOperands;
1702       applyOperands.push_back(previous);
1703       applyOperands.push_back(upperBoundSymbols[idx - 1]);
1704       inductionVariable = builder.create<AffineApplyOp>(
1705           loc,
1706           AffineMap::get(
1707               /*dimCount=*/1, /*symbolCount=*/1,
1708               builder.getAffineDimExpr(0) % builder.getAffineSymbolExpr(0)),
1709           applyOperands);
1710     }
1711     replaceAllUsesInRegionWith(loops[idx - 1].getInductionVar(),
1712                                inductionVariable, loops.back().getRegion());
1713   }
1714 
1715   // 4. Move the operations from the innermost loop to just above the second-
1716   // outermost loop; then delete the extra terminator and the second-outermost loop.
1717   AffineForOp secondOutermostLoop = loops[1];
1718   innermost.getBody()->back().erase();
1719   outermost.getBody()->getOperations().splice(
1720       Block::iterator(secondOutermostLoop.getOperation()),
1721       innermost.getBody()->getOperations());
1722   secondOutermostLoop.erase();
1723   return success();
1724 }
1725 
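// Maps `forOp` for execution on a multi-dimensional processor grid by
// linearizing the processor IDs. For example (illustrative), with
// processorId = (%bx, %tx) and numProcessors = (%nbx, %ntx), the loop's lower
// bound becomes lb + (%bx * %ntx + %tx) * step and its step becomes
// step * %nbx * %ntx, so each processor executes its own slice of the
// original iteration space.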
1726 void mlir::affine::mapLoopToProcessorIds(scf::ForOp forOp,
1727                                          ArrayRef<Value> processorId,
1728                                          ArrayRef<Value> numProcessors) {
1729   assert(processorId.size() == numProcessors.size());
1730   if (processorId.empty())
1731     return;
1732 
1733   OpBuilder b(forOp);
1734   Location loc(forOp.getLoc());
1735   AffineExpr lhs, rhs;
1736   bindSymbols(forOp.getContext(), lhs, rhs);
1737   auto mulMap = AffineMap::get(0, 2, lhs * rhs);
1738   auto addMap = AffineMap::get(0, 2, lhs + rhs);
1739 
1740   Value linearIndex = processorId.front();
1741   for (unsigned i = 1, e = processorId.size(); i < e; ++i) {
1742     auto mulApplyOp = b.create<AffineApplyOp>(
1743         loc, mulMap, ValueRange{linearIndex, numProcessors[i]});
1744     linearIndex = b.create<AffineApplyOp>(
1745         loc, addMap, ValueRange{mulApplyOp, processorId[i]});
1746   }
1747 
1748   auto mulApplyOp = b.create<AffineApplyOp>(
1749       loc, mulMap, ValueRange{linearIndex, forOp.getStep()});
1750   Value lb = b.create<AffineApplyOp>(
1751       loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()});
1752   forOp.setLowerBound(lb);
1753 
1754   Value step = forOp.getStep();
1755   for (auto numProcs : numProcessors)
1756     step = b.create<AffineApplyOp>(loc, mulMap, ValueRange{numProcs, step});
1757   forOp.setStep(step);
1758 }
1759 
1760 /// Given a memref region, determine the lowest depth at which transfers can be
1761 /// placed for it, and return the corresponding block, start and end positions
1762 /// in the block for placing incoming (read) and outgoing (write) copies
1763 /// respectively. The lowest depth depends on whether the region being accessed
1764 /// is hoistable with respect to one or more immediately surrounding loops.
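///
/// For example (illustrative), if the block range being copied for sits in the
/// body of a nest (%i, %j, %k) and the region depends only on %i, the copies
/// can be hoisted out of the %j and %k loops: the copy-in is placed just
/// before the %j loop and the copy-out just after it, in the %j loop's parent
/// block.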
1765 static void
1766 findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
1767                              Block::iterator &begin, Block::iterator &end,
1768                              Block **copyPlacementBlock,
1769                              Block::iterator *copyInPlacementStart,
1770                              Block::iterator *copyOutPlacementStart) {
1771   const auto *cst = region.getConstraints();
1772   SmallVector<Value, 4> symbols;
1773   cst->getValues(cst->getNumDimVars(), cst->getNumDimAndSymbolVars(), &symbols);
1774 
1775   SmallVector<AffineForOp, 4> enclosingFors;
1776   getAffineForIVs(*block.begin(), &enclosingFors);
1777   // Walk up loop parents till we find an IV on which this region is
1778   // symbolic/variant.
1779   auto it = enclosingFors.rbegin();
1780   for (auto e = enclosingFors.rend(); it != e; ++it) {
1781     // TODO: also need to check this for region symbols that aren't loop IVs,
1782     // i.e., whether we are within their respective defs' dominance scope.
1783     if (llvm::is_contained(symbols, it->getInductionVar()))
1784       break;
1785   }
1786 
1787   if (it != enclosingFors.rbegin()) {
1788     auto lastInvariantIV = *std::prev(it);
1789     *copyInPlacementStart = Block::iterator(lastInvariantIV.getOperation());
1790     *copyOutPlacementStart = std::next(*copyInPlacementStart);
1791     *copyPlacementBlock = lastInvariantIV->getBlock();
1792   } else {
1793     *copyInPlacementStart = begin;
1794     *copyOutPlacementStart = end;
1795     *copyPlacementBlock = &block;
1796   }
1797 }
1798 
1799 // Info comprising stride and number of elements transferred every stride.
1800 struct StrideInfo {
1801   int64_t stride;
1802   int64_t numEltPerStride;
1803 };
1804 
1805 /// Returns striding information for a copy/transfer of this region with
1806 /// potentially multiple striding levels from outermost to innermost. For an
1807 /// n-dimensional region, there can be at most n-1 levels of striding
1808 /// successively nested.
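///
/// For example (illustrative), a 32x32 region of a row-major 128x128 memref
/// copied into a 32x32 buffer has a single level of striding: 32 contiguous
/// elements are transferred every 128 elements, i.e.,
/// {stride = 128, numEltPerStride = 32}.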
1809 //  TODO: make this work with non-identity layout maps.
1810 static void getMultiLevelStrides(const MemRefRegion &region,
1811                                  ArrayRef<int64_t> bufferShape,
1812                                  SmallVectorImpl<StrideInfo> *strideInfos) {
1813   if (bufferShape.size() <= 1)
1814     return;
1815 
1816   int64_t numEltPerStride = 1;
1817   int64_t stride = 1;
1818   for (int d = bufferShape.size() - 1; d >= 1; d--) {
1819     int64_t dimSize = cast<MemRefType>(region.memref.getType()).getDimSize(d);
1820     stride *= dimSize;
1821     numEltPerStride *= bufferShape[d];
1822     // A stride is needed only if the region has a shorter extent than the
1823     // memref along the dimension *and* has an extent greater than one along the
1824     // next major dimension.
1825     if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) {
1826       strideInfos->push_back({stride, numEltPerStride});
1827     }
1828   }
1829 }
1830 
1831 /// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and
1832 /// returns the outermost AffineForOp of the copy loop nest. `lbMaps` and
1833 /// `ubMaps` along with `lbOperands` and `ubOperands` hold the lower and upper
1834 /// bound information for the copy loop nest. `fastBufOffsets` contain the
1835 /// expressions to be subtracted out from the respective copy loop iterators in
1836 /// order to index the fast buffer. If `copyOut' is true, generates a copy-out;
1837 /// order to index the fast buffer. If `isCopyOut' is true, generates a copy-out;
1838 /// inserted.
1839 //
1840 /// The copy-in nest is generated as follows as an example for a 2-d region:
1841 /// for x = ...
1842 ///   for y = ...
1843 ///     fast_buf[x - offset_x][y - offset_y] = memref[x][y]
1844 ///
1845 static AffineForOp
1846 generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
1847                       ArrayRef<AffineMap> lbMaps, ArrayRef<Value> lbOperands,
1848                       ArrayRef<AffineMap> ubMaps, ArrayRef<Value> ubOperands,
1849                       ArrayRef<AffineExpr> fastBufOffsets, bool isCopyOut,
1850                       OpBuilder b) {
1851   assert(llvm::all_of(lbMaps, [&](AffineMap lbMap) {
1852     return lbMap.getNumInputs() == lbOperands.size();
1853   }));
1854   assert(llvm::all_of(ubMaps, [&](AffineMap ubMap) {
1855     return ubMap.getNumInputs() == ubOperands.size();
1856   }));
1857 
1858   unsigned rank = cast<MemRefType>(memref.getType()).getRank();
1859   assert(lbMaps.size() == rank && "wrong number of lb maps");
1860   assert(ubMaps.size() == rank && "wrong number of ub maps");
1861 
1862   SmallVector<Value, 4> memIndices;
1863   SmallVector<AffineExpr, 4> fastBufExprs;
1864   SmallVector<Value, 4> fastBufMapOperands;
1865   AffineForOp copyNestRoot;
1866   SmallVector<AffineApplyOp, 4> mayBeDeadApplys;
1867   for (unsigned d = 0; d < rank; ++d) {
1868     auto forOp = createCanonicalizedAffineForOp(b, loc, lbOperands, lbMaps[d],
1869                                                 ubOperands, ubMaps[d]);
1870     if (d == 0)
1871       copyNestRoot = forOp;
1872 
1873     b = OpBuilder::atBlockTerminator(forOp.getBody());
1874 
1875     auto fastBufOffsetMap =
1876         AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]);
1877     auto offset = b.create<AffineApplyOp>(loc, fastBufOffsetMap, lbOperands);
1878 
1879     // Construct the subscript for the fast memref being copied into/from:
1880     // x - offset_x.
1881     fastBufExprs.push_back(b.getAffineDimExpr(2 * d + 1) -
1882                            b.getAffineDimExpr(2 * d));
1883     fastBufMapOperands.push_back(offset);
1884     fastBufMapOperands.push_back(forOp.getInductionVar());
1885     mayBeDeadApplys.push_back(offset);
1886 
1887     // Subscript for the slow memref being copied.
1888     memIndices.push_back(forOp.getInductionVar());
1889   }
1890 
1891   auto fastBufMap =
1892       AffineMap::get(2 * rank, /*symbolCount=*/0, fastBufExprs, b.getContext());
1893   fullyComposeAffineMapAndOperands(&fastBufMap, &fastBufMapOperands);
1894   fastBufMap = simplifyAffineMap(fastBufMap);
1895   canonicalizeMapAndOperands(&fastBufMap, &fastBufMapOperands);
1896 
1897   // Drop any dead affine.applys.
1898   for (auto applyOp : mayBeDeadApplys)
1899     if (applyOp.use_empty())
1900       applyOp.erase();
1901 
1902   if (!isCopyOut) {
1903     // Copy in.
1904     auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
1905     b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufMap,
1906                             fastBufMapOperands);
1907     return copyNestRoot;
1908   }
1909 
1910   // Copy out.
1911   auto load =
1912       b.create<AffineLoadOp>(loc, fastMemRef, fastBufMap, fastBufMapOperands);
1913   b.create<AffineStoreOp>(loc, load, memref, memIndices);
1914   return copyNestRoot;
1915 }
1916 
1917 static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
1918 emitRemarkForBlock(Block &block) {
1919   return block.getParentOp()->emitRemark();
1920 }
1921 
1922 /// Creates a buffer in the faster memory space for the specified memref region;
1923 /// generates a copy from the lower memory space to this one, and replaces all
1924 /// loads/stores in the block range [`begin', `end') of `block' with loads/stores
1925 /// from that buffer. Returns failure if copies could not be generated due to
1926 /// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart`
1927 /// in copyPlacementBlock specify the insertion points where the incoming copies
1928 /// and outgoing copies, respectively, should be inserted (the insertion happens
1929 /// right before the insertion point). Since `begin` can itself be invalidated
1930 /// due to the memref rewriting done from this method, the output argument
1931 /// `nBegin` is set to its replacement (set to `begin` if no invalidation
1932 /// happens). Since outgoing copies could have been inserted at `end`, the
1933 /// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
1934 /// size of the fast buffer allocated.
1935 static LogicalResult generateCopy(
1936     const MemRefRegion &region, Block *block, Block::iterator begin,
1937     Block::iterator end, Block *copyPlacementBlock,
1938     Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
1939     const AffineCopyOptions &copyOptions, DenseMap<Value, Value> &fastBufferMap,
1940     DenseSet<Operation *> &copyNests, uint64_t *sizeInBytes,
1941     Block::iterator *nBegin, Block::iterator *nEnd) {
1942   *nBegin = begin;
1943   *nEnd = end;
1944 
1945   auto f = begin->getParentOfType<FunctionOpInterface>();
1946   OpBuilder topBuilder(f.getFunctionBody());
1947   Value zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
1948 
1949   *sizeInBytes = 0;
1950 
1951   if (begin == end)
1952     return success();
1953 
1954   // Check whether the copy-out point is at the end of the block where we are
1955   // doing explicit copying.
1956   bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
1957 
1958   // Copies for read regions are going to be inserted at 'begin'.
1959   OpBuilder prologue(copyPlacementBlock, copyInPlacementStart);
1960   // Copies for write regions are going to be inserted at 'end'.
1961   OpBuilder epilogue(copyPlacementBlock, copyOutPlacementStart);
1962   OpBuilder &b = region.isWrite() ? epilogue : prologue;
1963 
1964   // Builder to create constants at the top level.
1965   auto func =
1966       copyPlacementBlock->getParent()->getParentOfType<FunctionOpInterface>();
1967   OpBuilder top(func.getFunctionBody());
1968 
1969   auto loc = region.loc;
1970   auto memref = region.memref;
1971   auto memRefType = cast<MemRefType>(memref.getType());
1972 
1973   if (!memRefType.getLayout().isIdentity()) {
1974     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
1975     return failure();
1976   }
1977 
1978   // Indices to use for the copying.
1979   // Indices for the original memref being copied from/to.
1980   SmallVector<Value, 4> memIndices;
1981   // Indices for the faster buffer being copied into/from.
1982   SmallVector<Value, 4> bufIndices;
1983 
1984   unsigned rank = memRefType.getRank();
1985   SmallVector<int64_t, 4> fastBufferShape;
1986 
1987   // Compute the extents of the buffer.
1988   std::vector<SmallVector<int64_t, 4>> lbs;
1989   SmallVector<int64_t, 8> lbDivisors;
1990   lbs.reserve(rank);
1991   std::optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(
1992       &fastBufferShape, &lbs, &lbDivisors);
1993   if (!numElements) {
1994     LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
1995     return failure();
1996   }
1997 
1998   if (*numElements == 0) {
1999     LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
2000     return success();
2001   }
2002 
2003   SmallVector<AffineMap, 4> lbMaps(rank), ubMaps(rank);
2004   for (unsigned i = 0; i < rank; ++i)
2005     region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]);
2006 
2007   const FlatAffineValueConstraints *cst = region.getConstraints();
2008   // 'regionSymbols' holds values that this memory region is symbolic/parametric
2009   // on; these typically include loop IVs surrounding the level at which the
2010   // copy generation is being done or other valid symbols in MLIR.
2011   SmallVector<Value, 8> regionSymbols;
2012   cst->getValues(rank, cst->getNumVars(), &regionSymbols);
2013 
2014   // Construct the index expressions for the fast memory buffer. The index
2015   // expression for a particular dimension of the fast buffer is obtained by
2016   // subtracting out the lower bound on the original memref's data region
2017   // along the corresponding dimension.
2018 
2019   // Index start offsets for faster memory buffer relative to the original.
2020   SmallVector<AffineExpr, 4> fastBufOffsets;
2021   fastBufOffsets.reserve(rank);
2022   for (unsigned d = 0; d < rank; d++) {
2023     assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
2024 
2025     AffineExpr offset = top.getAffineConstantExpr(0);
2026     for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++)
2027       offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
2028     assert(lbDivisors[d] > 0);
2029     offset =
2030         (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
2031 
2032     // Set copy start location for this dimension in the lower memory space
2033     // memref.
2034     if (auto caf = dyn_cast<AffineConstantExpr>(offset)) {
2035       auto indexVal = caf.getValue();
2036       if (indexVal == 0) {
2037         memIndices.push_back(zeroIndex);
2038       } else {
2039         memIndices.push_back(
2040             top.create<arith::ConstantIndexOp>(loc, indexVal).getResult());
2041       }
2042     } else {
2043       // The coordinate for the start location is just the lower bound along the
2044       // corresponding dimension on the memory region (stored in 'offset').
2045       auto map = AffineMap::get(
2046           cst->getNumDimVars() + cst->getNumSymbolVars() - rank, 0, offset);
2047       memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols));
2048     }
2049     // The fast buffer is copied into at location zero; addressing is relative.
2050     bufIndices.push_back(zeroIndex);
2051 
2052     // Record the offsets since they are needed to remap the memory accesses of
2053     // the original memref further below.
2054     fastBufOffsets.push_back(offset);
2055   }
2056 
2057   // The faster memory space buffer.
2058   Value fastMemRef;
2059 
2060   // Check if a buffer was already created.
2061   bool existingBuf = fastBufferMap.count(memref) > 0;
2062   if (!existingBuf) {
2063     AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank);
2064     auto fastMemRefType =
2065         MemRefType::get(fastBufferShape, memRefType.getElementType(),
2066                         fastBufferLayout, copyOptions.fastMemorySpace);
2067 
2068     // Create the fast memory space buffer just before the 'affine.for'
2069     // operation.
2070     fastMemRef =
2071         prologue.create<memref::AllocOp>(loc, fastMemRefType).getResult();
2072     // Record it.
2073     fastBufferMap[memref] = fastMemRef;
2074     // fastMemRefType is a constant shaped memref.
2075     auto maySizeInBytes = getIntOrFloatMemRefSizeInBytes(fastMemRefType);
2076     // We don't account for things of unknown size.
2077     *sizeInBytes = maySizeInBytes.value_or(0);
2078 
2079     LLVM_DEBUG(emitRemarkForBlock(*block)
2080                << "Creating fast buffer of type " << fastMemRefType
2081                << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
2082                << " KiB\n");
2083   } else {
2084     // Reuse the one already created.
2085     fastMemRef = fastBufferMap[memref];
2086   }
2087 
2088   auto numElementsSSA = top.create<arith::ConstantIndexOp>(loc, *numElements);
2089 
2090   Value dmaStride;
2091   Value numEltPerDmaStride;
2092   if (copyOptions.generateDma) {
2093     SmallVector<StrideInfo, 4> dmaStrideInfos;
2094     getMultiLevelStrides(region, fastBufferShape, &dmaStrideInfos);
2095 
2096     // TODO: use all stride levels once DmaStartOp is extended for
2097     // multi-level strides.
2098     if (dmaStrideInfos.size() > 1) {
2099       LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
2100       return failure();
2101     }
2102 
2103     if (!dmaStrideInfos.empty()) {
2104       dmaStride =
2105           top.create<arith::ConstantIndexOp>(loc, dmaStrideInfos[0].stride);
2106       numEltPerDmaStride = top.create<arith::ConstantIndexOp>(
2107           loc, dmaStrideInfos[0].numEltPerStride);
2108     }
2109   }
2110 
2111   // Record the last operation where we want the memref replacement to end. We
2112   // later do the memref replacement only in [begin, postDomFilter] so
2113   // that the uses of the original memref in the data movement code don't
2114   // themselves get replaced.
2115   auto postDomFilter = std::prev(end);
2116 
2117   // Create fully composed affine maps for each memref.
2118   auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
2119   fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
2120   auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size());
2121   fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices);
2122 
2123   if (!copyOptions.generateDma) {
2124     // Point-wise copy generation.
2125     auto copyNest =
2126         generatePointWiseCopy(loc, memref, fastMemRef, lbMaps,
2127                               /*lbOperands=*/regionSymbols, ubMaps,
2128                               /*ubOperands=*/regionSymbols, fastBufOffsets,
2129                               /*isCopyOut=*/region.isWrite(), b);
2130 
2131     // Record this so that we can skip it when generating further copies.
2132     copyNests.insert(copyNest);
2133 
2134     // Since new ops are being appended (for copy-outs), adjust the end to
2135     // mark end of block range being processed if necessary.
2136     if (region.isWrite() && isCopyOutAtEndOfBlock)
2137       *nEnd = Block::iterator(copyNest.getOperation());
2138   } else {
2139     // DMA generation.
2140     // Create a tag (single element 1-d memref) for the DMA.
2141     auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
2142                                          copyOptions.tagMemorySpace);
2143     auto tagMemRef = prologue.create<memref::AllocOp>(loc, tagMemRefType);
2144 
2145     SmallVector<Value, 4> tagIndices({zeroIndex});
2146     auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
2147     fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
2148     if (!region.isWrite()) {
2149       // DMA non-blocking read from original buffer to fast buffer.
2150       b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
2151                                  fastMemRef, bufAffineMap, bufIndices,
2152                                  tagMemRef, tagAffineMap, tagIndices,
2153                                  numElementsSSA, dmaStride, numEltPerDmaStride);
2154     } else {
2155       // DMA non-blocking write from fast buffer to the original memref.
2156       auto op = b.create<AffineDmaStartOp>(
2157           loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
2158           memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
2159           dmaStride, numEltPerDmaStride);
2160       // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
2161       // end to mark end of block range being processed.
2162       if (isCopyOutAtEndOfBlock)
2163         *nEnd = Block::iterator(op.getOperation());
2164     }
2165 
2166     // Matching DMA wait to block on completion; tag always has a 0 index.
2167     b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
2168                               numElementsSSA);
2169 
2170     // Generate dealloc for the tag.
2171     auto tagDeallocOp = epilogue.create<memref::DeallocOp>(loc, tagMemRef);
2172     if (*nEnd == end && isCopyOutAtEndOfBlock)
2173       // Since new ops are being appended (for outgoing DMAs), adjust the end to
2174       // mark end of range of the original.
2175       *nEnd = Block::iterator(tagDeallocOp.getOperation());
2176   }
2177 
2178   // Generate dealloc for the buffer.
2179   if (!existingBuf) {
2180     auto bufDeallocOp = epilogue.create<memref::DeallocOp>(loc, fastMemRef);
2181     // When generating pointwise copies, `nEnd' has to be set to deallocOp on
2182     // the fast buffer (since it marks the new end insertion point).
2183     if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
2184       *nEnd = Block::iterator(bufDeallocOp.getOperation());
2185   }
2186 
2187   // Replace all uses of the old memref with the faster one while remapping
2188   // access indices (subtracting out lower bound offsets for each dimension).
2189   // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
2190   // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
2191   // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
2192   // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
2193   // d2, d3 correspond to the original indices (%i, %j).
2194   SmallVector<AffineExpr, 4> remapExprs;
2195   remapExprs.reserve(rank);
2196   for (unsigned i = 0; i < rank; i++) {
2197     // The starting operands of indexRemap will be regionSymbols (the symbols on
2198     // which the memref region is parametric); then those corresponding to
2199     // the memref's original indices follow.
2200     auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
2201     remapExprs.push_back(dimExpr - fastBufOffsets[i]);
2202   }
2203   auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs,
2204                                    b.getContext());
2205 
2206   // Record the begin since it may be invalidated by memref replacement.
2207   Block::iterator prevOfBegin;
2208   bool isBeginAtStartOfBlock = (begin == block->begin());
2209   if (!isBeginAtStartOfBlock)
2210     prevOfBegin = std::prev(begin);
2211 
2212   // *Only* those uses within the range [begin, end) of 'block' are replaced.
2213   (void)replaceAllMemRefUsesWith(memref, fastMemRef,
2214                                  /*extraIndices=*/{}, indexRemap,
2215                                  /*extraOperands=*/regionSymbols,
2216                                  /*symbolOperands=*/{},
2217                                  /*domOpFilter=*/&*begin,
2218                                  /*postDomOpFilter=*/&*postDomFilter);
2219 
2220   *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
2221 
2222   return success();
2223 }
2224 
2225 /// Construct the memref region to just include the entire memref. Returns false
2226 /// for dynamically shaped memrefs for now. `numParamLoopIVs` is the number of
2227 /// enclosing loop IVs of `op` (starting from the outermost) that the region
2228 /// is parametric on.
2229 static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs,
2230                                   MemRefRegion *region) {
2231   unsigned rank;
2232   if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
2233     rank = loadOp.getMemRefType().getRank();
2234     region->memref = loadOp.getMemRef();
2235     region->setWrite(false);
2236   } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
2237     rank = storeOp.getMemRefType().getRank();
2238     region->memref = storeOp.getMemRef();
2239     region->setWrite(true);
2240   } else {
2241     assert(false && "expected load or store op");
2242     return false;
2243   }
2244   auto memRefType = cast<MemRefType>(region->memref.getType());
2245   if (!memRefType.hasStaticShape())
2246     return false;
2247 
2248   auto *regionCst = region->getConstraints();
2249 
2250   // Just get the first `numParamLoopIVs` IVs, which the memref region is
2251   // parametric on.
2252   SmallVector<AffineForOp, 4> ivs;
2253   getAffineForIVs(*op, &ivs);
2254   ivs.resize(numParamLoopIVs);
2255   SmallVector<Value, 4> symbols;
2256   extractForInductionVars(ivs, &symbols);
2257   *regionCst = FlatAffineValueConstraints(rank, numParamLoopIVs, 0);
2258   regionCst->setValues(rank, rank + numParamLoopIVs, symbols);
2259 
2260   // Memref dim sizes provide the bounds.
2261   for (unsigned d = 0; d < rank; d++) {
2262     auto dimSize = memRefType.getDimSize(d);
2263     assert(dimSize > 0 && "filtered dynamic shapes above");
2264     regionCst->addBound(BoundType::LB, d, 0);
2265     regionCst->addBound(BoundType::UB, d, dimSize - 1);
2266   }
2267   return true;
2268 }
2269 
2270 LogicalResult
2271 mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
2272                                      const AffineCopyOptions &copyOptions,
2273                                      std::optional<Value> filterMemRef,
2274                                      DenseSet<Operation *> &copyNests) {
2275   if (begin == end)
2276     return success();
2277 
2278   assert(begin->getBlock() == std::prev(end)->getBlock() &&
2279          "Inconsistent block begin/end args");
2280   assert(end != end->getBlock()->end() && "end can't be the block terminator");
2281 
2282   Block *block = begin->getBlock();
2283 
2284   // Copies will be generated for this depth, i.e., symbolic in all loops
2285   // surrounding this block range.
2286   unsigned copyDepth = getNestingDepth(&*begin);
2287 
2288   LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
2289                           << "\n");
2290   LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
2291   LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
2292 
2293   // List of memory regions to copy for. We need a map vector to have a
2294   // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
2295   // since the allocs, for example, are identical except for the SSA id.
2296   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
2297   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
2298 
2299   // Map from original memref's to the fast buffers that their accesses are
2300   // replaced with.
2301   DenseMap<Value, Value> fastBufferMap;
2302 
2303   // To check for errors when walking the block.
2304   bool error = false;
2305 
2306   // Walk this range of operations to gather all memory regions.
2307   block->walk(begin, end, [&](Operation *opInst) {
2308     Value memref;
2309     MemRefType memrefType;
2310     // Gather regions to allocate to buffers in faster memory space.
2311     if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
2312       memref = loadOp.getMemRef();
2313       memrefType = loadOp.getMemRefType();
2314     } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
2315       memref = storeOp.getMemRef();
2316       memrefType = storeOp.getMemRefType();
2317     }
2318     // Neither a load nor a store op.
2319     if (!memref)
2320       return;
2321 
2322     auto memorySpaceAttr =
2323         dyn_cast_or_null<IntegerAttr>(memrefType.getMemorySpace());
2324     if ((filterMemRef.has_value() && filterMemRef != memref) ||
2325         (memorySpaceAttr &&
2326          memrefType.getMemorySpaceAsInt() != copyOptions.slowMemorySpace))
2327       return;
2328 
2329     // Compute the MemRefRegion accessed.
2330     auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
2331     if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr,
2332                                /*addMemRefDimBounds=*/false))) {
2333       LLVM_DEBUG(llvm::dbgs()
2334                  << "Error obtaining memory region: semi-affine maps?\n");
2335       LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
2336       if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
2337         LLVM_DEBUG(
2338             opInst->emitError("non-constant memref sizes not yet supported"));
2339         error = true;
2340         return;
2341       }
2342     }
2343 
2344     // Each memref has a single buffer associated with it irrespective of how
2345     // many loads and stores happen on it.
2346     // TODO: in the future, when regions don't intersect and satisfy
2347     // other properties (based on load/store regions), we could consider
2348     // multiple buffers per memref.
2349 
2350     // Add to the appropriate region if it's not already in it, or take a
2351     // bounding box union with the existing one if it's already in there.
2352     // Note that a memref may have both read and write regions - so update the
2353     // region in the other list if one exists (write in case of read and vice
2354     // versa) since there is a single bounding box for a memref across all reads
2355     // and writes that happen on it.
2356 
2357     // Attempts to update; returns true if 'region' exists in targetRegions.
2358     auto updateRegion =
2359         [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
2360                 &targetRegions) {
2361           const auto *const it = targetRegions.find(region->memref);
2362           if (it == targetRegions.end())
2363             return false;
2364 
2365           // Perform a union with the existing region.
2366           if (failed(it->second->unionBoundingBox(*region))) {
2367             LLVM_DEBUG(llvm::dbgs()
2368                        << "Memory region bounding box failed; "
2369                           "over-approximating to the entire memref\n");
2370             // If the union fails, we will overapproximate.
2371             if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
2372               LLVM_DEBUG(opInst->emitError(
2373                   "non-constant memref sizes not yet supported"));
2374               error = true;
2375               return true;
2376             }
2377             it->second->getConstraints()->clearAndCopyFrom(
2378                 *region->getConstraints());
2379           } else {
2380             // Union was computed and stored in 'it->second': copy to 'region'.
2381             region->getConstraints()->clearAndCopyFrom(
2382                 *it->second->getConstraints());
2383           }
2384           return true;
2385         };
2386 
2387     bool existsInRead = updateRegion(readRegions);
2388     if (error)
2389       return;
2390     bool existsInWrite = updateRegion(writeRegions);
2391     if (error)
2392       return;
2393 
2394     // Finally add it to the region list.
2395     if (region->isWrite() && !existsInWrite) {
2396       writeRegions[region->memref] = std::move(region);
2397     } else if (!region->isWrite() && !existsInRead) {
2398       readRegions[region->memref] = std::move(region);
2399     }
2400   });
2401 
2402   if (error) {
2403     LLVM_DEBUG(begin->emitError(
2404         "copy generation failed for one or more memrefs in this block\n"));
2405     return failure();
2406   }
2407 
2408   uint64_t totalCopyBuffersSizeInBytes = 0;
2409   bool ret = true;
2410   auto processRegions =
2411       [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
2412               &regions) {
2413         for (const auto &regionEntry : regions) {
2414           // For each region, hoist copy in/out past all hoistable
2415           // 'affine.for's.
2416           Block::iterator copyInPlacementStart, copyOutPlacementStart;
2417           Block *copyPlacementBlock;
2418           findHighestBlockForPlacement(
2419               *regionEntry.second, *block, begin, end, &copyPlacementBlock,
2420               &copyInPlacementStart, &copyOutPlacementStart);
2421 
2422           uint64_t sizeInBytes;
2423           Block::iterator nBegin, nEnd;
2424           LogicalResult iRet = generateCopy(
2425               *regionEntry.second, block, begin, end, copyPlacementBlock,
2426               copyInPlacementStart, copyOutPlacementStart, copyOptions,
2427               fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd);
2428           if (succeeded(iRet)) {
2429             // begin/end could have been invalidated and need to be updated.
2430             begin = nBegin;
2431             end = nEnd;
2432             totalCopyBuffersSizeInBytes += sizeInBytes;
2433           }
2434           ret &= succeeded(iRet);
2435         }
2436       };
2437   processRegions(readRegions);
2438   processRegions(writeRegions);
2439 
2440   if (!ret) {
2441     LLVM_DEBUG(begin->emitError(
2442         "copy generation failed for one or more memrefs in this block\n"));
2443     return failure();
2444   }
2445 
2446   // If the range starts with an 'affine.for', a remark is emitted on it
       // below; for a general range of operations, a note will be emitted at
       // the caller.
2447   AffineForOp forOp;
2448   if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) {
2449     LLVM_DEBUG(forOp.emitRemark()
2450                << llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024)
2451                << " KiB of copy buffers in fast memory space for this block");
2452   }
2453 
2454   if (totalCopyBuffersSizeInBytes > copyOptions.fastMemCapacityBytes) {
2455     block->getParentOp()->emitWarning(
2456         "total size of all copy buffers for this block exceeds fast memory "
2457         "capacity");
2458   }
2459 
2460   return success();
2461 }
2462 
2463 // A convenience version of affineDataCopyGenerate for all ops in the body of
2464 // an AffineForOp.
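// A rough usage sketch (option values hypothetical):
//   AffineCopyOptions options = {/*generateDma=*/false, /*slowMemorySpace=*/0,
//                                /*fastMemorySpace=*/2, /*tagMemorySpace=*/0,
//                                /*fastMemCapacityBytes=*/32 * 1024};
//   DenseSet<Operation *> copyNests;
//   (void)affineDataCopyGenerate(forOp, options, /*filterMemRef=*/std::nullopt,
//                                copyNests);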
2465 LogicalResult mlir::affine::affineDataCopyGenerate(
2466     AffineForOp forOp, const AffineCopyOptions &copyOptions,
2467     std::optional<Value> filterMemRef, DenseSet<Operation *> &copyNests) {
2468   return affineDataCopyGenerate(forOp.getBody()->begin(),
2469                                 std::prev(forOp.getBody()->end()), copyOptions,
2470                                 filterMemRef, copyNests);
2471 }
2472 
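// A rough usage sketch for the routine below, assuming `loadOp` is an
// affine.load whose region is representable (names hypothetical):
//   MemRefRegion region(loadOp.getLoc());
//   if (succeeded(region.compute(loadOp, /*loopDepth=*/0))) {
//     CopyGenerateResult res;
//     if (succeeded(generateCopyForMemRegion(region, loadOp, copyOptions, res)))
//       ...  // res.alloc / res.copyNest hold the fast buffer and copy nest.
//   }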
2473 LogicalResult mlir::affine::generateCopyForMemRegion(
2474     const MemRefRegion &memrefRegion, Operation *analyzedOp,
2475     const AffineCopyOptions &copyOptions, CopyGenerateResult &result) {
2476   Block *block = analyzedOp->getBlock();
2477   auto begin = analyzedOp->getIterator();
2478   auto end = std::next(begin);
2479   DenseMap<Value, Value> fastBufferMap;
2480   DenseSet<Operation *> copyNests;
2481 
2482   auto err = generateCopy(memrefRegion, block, begin, end, block, begin, end,
2483                           copyOptions, fastBufferMap, copyNests,
2484                           &result.sizeInBytes, &begin, &end);
2485   if (failed(err))
2486     return err;
2487 
2488   auto en = fastBufferMap.find(memrefRegion.memref);
2489   // In some cases (empty loops), no copy generation would have happened.
2490   if (en == fastBufferMap.end())
2491     return failure();
2492   result.alloc = en->second.getDefiningOp();
2493   assert(result.alloc && "fast buffer expected to be locally allocated");
2494   assert(copyNests.size() <= 1 && "At most one copy nest is expected.");
2495   result.copyNest = copyNests.empty() ? nullptr : *copyNests.begin();
2496   return success();
2497 }
2498 
2499 /// Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'.
2500 static void
2501 gatherLoopsInBlock(Block *block, unsigned currLoopDepth,
2502                    std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
2503   // Add a new empty level to the output if it doesn't already exist.
2504   assert(currLoopDepth <= depthToLoops.size() && "Unexpected currLoopDepth");
2505   if (currLoopDepth == depthToLoops.size())
2506     depthToLoops.emplace_back();
2507 
2508   for (auto &op : *block) {
2509     if (auto forOp = dyn_cast<AffineForOp>(op)) {
2510       depthToLoops[currLoopDepth].push_back(forOp);
2511       gatherLoopsInBlock(forOp.getBody(), currLoopDepth + 1, depthToLoops);
2512     }
2513   }
2514 }
2515 
2516 /// Gathers all AffineForOps in 'func.func' grouped by loop depth.
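/// For instance (loop names illustrative), for a function body containing
///   affine.for %i ... {
///     affine.for %j ... { ... }
///   }
///   affine.for %k ... { ... }
/// 'depthToLoops' would become {{%i loop, %k loop}, {%j loop}}.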
2517 void mlir::affine::gatherLoops(
2518     func::FuncOp func, std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
2519   for (auto &block : func)
2520     gatherLoopsInBlock(&block, /*currLoopDepth=*/0, depthToLoops);
2521 
2522   // Remove last loop level from output since it's empty.
2523   if (!depthToLoops.empty()) {
2524     assert(depthToLoops.back().empty() && "Last loop level is not empty?");
2525     depthToLoops.pop_back();
2526   }
2527 }
2528 
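// For instance (maps and operands hypothetical): with lbMap = ()[s0] -> (s0)
// and ubMap = ()[s0] -> (s0 + 4, s0 + 4) fed the same symbol operand, the
// bounds are fully composed and canonicalized, and the duplicate upper bound
// expression is dropped before the 'affine.for' is created.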
2529 AffineForOp mlir::affine::createCanonicalizedAffineForOp(
2530     OpBuilder b, Location loc, ValueRange lbOperands, AffineMap lbMap,
2531     ValueRange ubOperands, AffineMap ubMap, int64_t step) {
2532   SmallVector<Value, 4> lowerOperands(lbOperands);
2533   SmallVector<Value, 4> upperOperands(ubOperands);
2534 
2535   fullyComposeAffineMapAndOperands(&lbMap, &lowerOperands);
2536   canonicalizeMapAndOperands(&lbMap, &lowerOperands);
2537   lbMap = removeDuplicateExprs(lbMap);
2538   fullyComposeAffineMapAndOperands(&ubMap, &upperOperands);
2539   canonicalizeMapAndOperands(&ubMap, &upperOperands);
2540   ubMap = removeDuplicateExprs(ubMap);
2541 
2542   return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap,
2543                                step);
2544 }
2545 
2546 /// Creates an AffineIfOp that encodes the conditional to choose between
2547 /// the constant trip count version and an unknown trip count version of this
2548 /// nest of loops. This is used to separate partial and full tiles if `loops`
2549 /// has the intra-tile loops. The affine.if op is inserted at the builder
2550 /// insertion point of `b`.
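/// For instance (maps hypothetical), for a point loop
///   affine.for %ii = max #lb(%i) to min #ub(%i)[%N]
/// whose full-tile <lb, ub> pair has a constant difference, the generated
/// guard effectively checks that the full-tile lower bound is >= every other
/// lower bound and the full-tile upper bound is <= every other upper bound,
/// i.e., that the tile lies entirely within the iteration space.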
2551 static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
2552                                             OpBuilder b) {
2553   if (loops.empty())
2554     return nullptr;
2555 
2556   auto *context = loops[0].getContext();
2557 
2558   FlatAffineValueConstraints cst;
2559   SmallVector<Operation *, 8> ops;
2560   llvm::append_range(ops, loops);
2561   (void)getIndexSet(ops, &cst);
2562 
2563   // Remove constraints that are independent of these loop IVs.
2564   cst.removeIndependentConstraints(/*pos=*/0, /*num=*/loops.size());
2565 
2566   // Construct the constraint set representing the guard for full tiles. The
2567   // lower bound (and upper bound) corresponding to the full tile should be
2568   // larger (and resp. smaller) than any other lower (or upper bound).
2569   SmallVector<int64_t, 8> fullTileLb, fullTileUb;
2570   for (auto loop : loops) {
2571     (void)loop;
2572     // TODO: Straightforward to generalize to a non-unit stride.
2573     assert(loop.getStepAsInt() == 1 && "point loop step expected to be one");
2574     // Mark everything as symbols for the purpose of finding a constant diff pair.
2575     cst.setDimSymbolSeparation(/*newSymbolCount=*/cst.getNumDimAndSymbolVars() -
2576                                1);
2577     unsigned fullTileLbPos, fullTileUbPos;
2578     if (!cst.getConstantBoundOnDimSize(0, /*lb=*/nullptr,
2579                                        /*boundFloorDivisor=*/nullptr,
2580                                        /*ub=*/nullptr, &fullTileLbPos,
2581                                        &fullTileUbPos)) {
2582       LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
2583       return nullptr;
2584     }
2585 
2586     SmallVector<unsigned, 4> lbIndices, ubIndices;
2587     cst.getLowerAndUpperBoundIndices(/*pos=*/0, &lbIndices, &ubIndices);
2588 
2589     auto fLb = cst.getInequality(fullTileLbPos);
2590     auto fUb = cst.getInequality(fullTileUbPos);
2591     fullTileLb.assign(fLb.begin(), fLb.end());
2592     fullTileUb.assign(fUb.begin(), fUb.end());
2593 
2594     // Full tile lower bound should be >= any other lower bound.
2595     for (auto lbIndex : lbIndices)
2596       for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
2597         cst.atIneq(lbIndex, i) = fullTileLb[i] - cst.atIneq(lbIndex, i);
2598 
2599     // Full tile upper bound should be <= any other upper bound.
2600     for (auto ubIndex : ubIndices)
2601       for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
2602         cst.atIneq(ubIndex, i) -= fullTileUb[i];
2603 
2604     cst.removeVar(0);
2605   }
2606 
2607   // The previous step leads to all zeros at the full tile lb and ub positions
2608   // themselves; remove those and any other duplicates / trivial redundancies.
2609   cst.removeTrivialRedundancy();
2610 
2611   // Conservatively turn everything into dims, since we earlier turned all
2612   // trailing ids past the point loop IV into symbols. Some of these could be
2613   // outer loop IVs; we'll canonicalize anyway.
2614   cst.setDimSymbolSeparation(0);
2615 
2616   IntegerSet ifCondSet = cst.getAsIntegerSet(context);
2617   // ifCondSet can be null if cst was empty -- this can happen if all loops
2618   // in the nest have constant trip counts.
2619   if (!ifCondSet)
2620     return nullptr;
2621 
2622   SmallVector<Value, 4> setOperands;
2623   cst.getValues(0, cst.getNumDimAndSymbolVars(), &setOperands);
2624   canonicalizeSetAndOperands(&ifCondSet, &setOperands);
2625   return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands,
2626                               /*withElseRegion=*/true);
2627 }
2628 
2629 /// Create the full tile loop nest (along with its body).
2630 static LogicalResult
2631 createFullTiles(MutableArrayRef<AffineForOp> inputNest,
2632                 SmallVectorImpl<AffineForOp> &fullTileLoops, OpBuilder b) {
2633   fullTileLoops.reserve(inputNest.size());
2634 
2635   // For each loop in the original nest identify a lower/upper bound pair such
2636   // that their difference is a constant.
2637   FlatAffineValueConstraints cst;
2638   for (auto loop : inputNest) {
2639     // TODO: straightforward to generalize to a non-unit stride.
2640     if (loop.getStepAsInt() != 1) {
2641       LLVM_DEBUG(llvm::dbgs()
2642                  << "[tile separation] non-unit stride not implemented\n");
2643       return failure();
2644     }
2645     SmallVector<Operation *, 1> loopOp{loop.getOperation()};
2646     (void)getIndexSet(loopOp, &cst);
2647     // Mark everything other than this loop's IV as a symbol to get a
2648     // <lb, ub> pair with a constant difference.
2649     cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - 1);
2650     unsigned lbPos, ubPos;
2651     if (!cst.getConstantBoundOnDimSize(/*pos=*/0, /*lb=*/nullptr,
2652                                        /*boundFloorDivisor=*/nullptr,
2653                                        /*ub=*/nullptr, &lbPos, &ubPos) ||
2654         lbPos == ubPos) {
2655       LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
2656                                  "equalities not yet handled\n");
2657       return failure();
2658     }
2659 
2660     // Set all variables as dimensions uniformly since some of those marked as
2661     // symbols above could be outer loop IVs (corresponding tile space IVs).
2662     cst.setDimSymbolSeparation(/*newSymbolCount=*/0);
2663 
2664     AffineValueMap lbVmap, ubVmap;
2665     cst.getIneqAsAffineValueMap(/*pos=*/0, lbPos, lbVmap, b.getContext());
2666     cst.getIneqAsAffineValueMap(/*pos=*/0, ubPos, ubVmap, b.getContext());
2667     AffineForOp fullTileLoop = createCanonicalizedAffineForOp(
2668         b, loop.getLoc(), lbVmap.getOperands(), lbVmap.getAffineMap(),
2669         ubVmap.getOperands(), ubVmap.getAffineMap());
2670     b = OpBuilder::atBlockTerminator(fullTileLoop.getBody());
2671     fullTileLoops.push_back(fullTileLoop);
2672   }
2673 
2674   // Add the body for the full tile loop nest.
2675   IRMapping operandMap;
2676   for (const auto &loopEn : llvm::enumerate(inputNest))
2677     operandMap.map(loopEn.value().getInductionVar(),
2678                    fullTileLoops[loopEn.index()].getInductionVar());
2679   b = OpBuilder::atBlockTerminator(fullTileLoops.back().getBody());
2680   for (auto &op : inputNest.back().getBody()->without_terminator())
2681     b.clone(op, operandMap);
2682   return success();
2683 }
2684 
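// A rough sketch of the effect on a point-loop nest (maps and names
// hypothetical): a point loop
//   affine.for %ii = #lb(%i) to min #ub(%i)[%N] { ... }
// is versioned as
//   affine.if #is_full_tile(%i)[%N] {
//     affine.for %ii = ... { ... }   // full tile, constant trip count
//   } else {
//     affine.for %ii = #lb(%i) to min #ub(%i)[%N] { ... }   // partial tile
//   }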
2685 LogicalResult
2686 mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
2687                                 SmallVectorImpl<AffineForOp> *fullTileNest) {
2688   if (inputNest.empty())
2689     return success();
2690 
2691   auto firstLoop = inputNest[0];
2692 
2693   // Each successive for op has to be nested inside the previous one.
2694   auto prevLoop = firstLoop;
2695   for (auto loop : inputNest.drop_front(1)) {
2696     assert(loop->getParentOp() == prevLoop && "input not contiguously nested");
2697     prevLoop = loop;
2698   }
2699 
2700   // Create the full tile loop nest.
2701   SmallVector<AffineForOp, 4> fullTileLoops;
2702   OpBuilder b(firstLoop);
2703   if (failed(createFullTiles(inputNest, fullTileLoops, b))) {
2704     if (!fullTileLoops.empty())
2705       fullTileLoops.front().erase();
2706     return failure();
2707   }
2708 
2709   // Create and insert the version select right before the root of the nest.
2710   b = OpBuilder(firstLoop);
2711   AffineIfOp ifOp = createSeparationCondition(inputNest, b);
2712   if (!ifOp) {
2713     fullTileLoops.front().erase();
2714     LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating "
2715                                "separation condition\n");
2716     return failure();
2717   }
2718 
2719   // Move the full tile into the then block.
2720   Block *thenBlock = ifOp.getThenBlock();
2721   AffineForOp outermostFullTileLoop = fullTileLoops[0];
2722   thenBlock->getOperations().splice(
2723       std::prev(thenBlock->end()),
2724       outermostFullTileLoop->getBlock()->getOperations(),
2725       Block::iterator(outermostFullTileLoop));
2726 
2727   // Move the partial tile into the else block. The partial tile is the same as
2728   // the original loop nest.
2729   Block *elseBlock = ifOp.getElseBlock();
2730   elseBlock->getOperations().splice(std::prev(elseBlock->end()),
2731                                     firstLoop->getBlock()->getOperations(),
2732                                     Block::iterator(firstLoop));
2733 
2734   if (fullTileNest)
2735     *fullTileNest = std::move(fullTileLoops);
2736 
2737   return success();
2738 }
2739 
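// Illustrative example (constant bounds assumed): a perfectly nested band such
// as
//   affine.for %i = 0 to 32 {
//     affine.for %j = 0 to 64 { ... }
//   }
// can be coalesced into a single loop over the product trip count, with the
// original IVs recovered via floordiv/mod affine.apply ops:
//   affine.for %k = 0 to 2048 {
//     %i = affine.apply affine_map<(d0) -> (d0 floordiv 64)>(%k)
//     %j = affine.apply affine_map<(d0) -> (d0 mod 64)>(%k)
//     ...
//   }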
2740 LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
2741   LogicalResult result(failure());
2742   SmallVector<AffineForOp> loops;
2743   getPerfectlyNestedLoops(loops, op);
2744   if (loops.size() <= 1)
2745     return success();
2746 
2747   // Look for a band of loops that can be coalesced, i.e., perfectly nested
2748   // loops whose bounds are defined above some loop in the nest.
2749   // 1. For each loop, find the outermost loop in the nest above which its
2750   // bound operands are defined.
2751   SmallVector<unsigned> operandsDefinedAbove(loops.size());
2752   for (unsigned i = 0, e = loops.size(); i < e; ++i) {
2753     operandsDefinedAbove[i] = i;
2754     for (unsigned j = 0; j < i; ++j) {
2755       if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
2756         operandsDefinedAbove[i] = j;
2757         break;
2758       }
2759     }
2760   }
2761 
2762   // 2. Identify bands of loops such that the operands of all of them are
2763   // defined above the first loop in the band.  Traverse the nest bottom-up
2764   // so that modifications don't invalidate the inner loops.
2765   for (unsigned end = loops.size(); end > 0; --end) {
2766     unsigned start = 0;
2767     for (; start < end - 1; ++start) {
2768       auto maxPos =
2769           *std::max_element(std::next(operandsDefinedAbove.begin(), start),
2770                             std::next(operandsDefinedAbove.begin(), end));
2771       if (maxPos > start)
2772         continue;
2773       assert(maxPos == start &&
2774              "expected loop bounds to be known at the start of the band");
2775       auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
2776       if (succeeded(coalesceLoops(band)))
2777         result = success();
2778       break;
2779     }
2780     // If a band was found and transformed, keep looking at the loops above
2781     // the outermost transformed loop.
2782     if (start != end - 1)
2783       end = start + 1;
2784   }
2785   return result;
2786 }
2787 
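// For instance (IR hypothetical), given
//   affine.for %i = 0 to 8 {
//     affine.for %j = 0 to 8 {
//       %v = memref.load %A[%i, %j] : memref<8x8xf32>
//     }
//   }
// with %A defined above both loops, the load's %A operand has two enclosing
// invariant loops, its %i operand has one, and its %j operand has none.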
2788 int64_t mlir::affine::numEnclosingInvariantLoops(OpOperand &operand) {
2789   int64_t count = 0;
2790   Operation *currentOp = operand.getOwner();
2791   while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) {
2792     if (!loopOp.isDefinedOutsideOfLoop(operand.get()))
2793       break;
2794     currentOp = loopOp;
2795     count++;
2796   }
2797   return count;
2798 }
2799