1 //===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous transformation utilities for the Affine
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/Utils.h"
15 
16 #include "mlir/Dialect/Affine/Analysis/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/Arith/Utils/Utils.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/IR/AffineExprVisitor.h"
25 #include "mlir/IR/Dominance.h"
26 #include "mlir/IR/IRMapping.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/IR/IntegerSet.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "llvm/Support/LogicalResult.h"
31 #include <optional>
32 
33 #define DEBUG_TYPE "affine-utils"
34 
35 using namespace mlir;
36 using namespace affine;
37 using namespace presburger;
38 
39 namespace {
40 /// Visits an affine expression recursively and builds the sequence of
41 /// operations that corresponds to it. Visitation functions return a Value for
42 /// the expression subtree they visited or `nullptr` on error.
43 class AffineApplyExpander
44     : public AffineExprVisitor<AffineApplyExpander, Value> {
45 public:
46   /// This internal class expects arguments to be non-null; checks must be
47   /// performed at the call site.
48   AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
49                       ValueRange symbolValues, Location loc)
50       : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
51         loc(loc) {}
52 
53   template <typename OpTy>
54   Value buildBinaryExpr(AffineBinaryOpExpr expr,
55                         arith::IntegerOverflowFlags overflowFlags =
56                             arith::IntegerOverflowFlags::none) {
57     auto lhs = visit(expr.getLHS());
58     auto rhs = visit(expr.getRHS());
59     if (!lhs || !rhs)
60       return nullptr;
61     auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags);
62     return op.getResult();
63   }
64 
65   Value visitAddExpr(AffineBinaryOpExpr expr) {
66     return buildBinaryExpr<arith::AddIOp>(expr);
67   }
68 
69   Value visitMulExpr(AffineBinaryOpExpr expr) {
70     return buildBinaryExpr<arith::MulIOp>(expr,
71                                           arith::IntegerOverflowFlags::nsw);
72   }
73 
74   /// Euclidean modulo operation: negative RHS is not allowed.
75   /// The remainder of the Euclidean integer division is always non-negative.
76   ///
77   /// Implemented as
78   ///
79   ///     a mod b =
80   ///         let remainder = srem a, b;
81   ///             negative = a < 0 in
82   ///         select negative, remainder + b, remainder.
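  ///
  /// Worked example (illustrative values only): for a = -7 and b = 3, `srem`
  /// yields -1, which is negative, so the select produces -1 + 3 = 2; for
  /// a = 7, `srem` yields 1, which is already the Euclidean remainder.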
83   Value visitModExpr(AffineBinaryOpExpr expr) {
84     if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
85       if (rhsConst.getValue() <= 0) {
86         emitError(loc, "modulo by non-positive value is not supported");
87         return nullptr;
88       }
89     }
90 
91     auto lhs = visit(expr.getLHS());
92     auto rhs = visit(expr.getRHS());
93     assert(lhs && rhs && "unexpected affine expr lowering failure");
94 
95     Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
96     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
97     Value isRemainderNegative = builder.create<arith::CmpIOp>(
98         loc, arith::CmpIPredicate::slt, remainder, zeroCst);
99     Value correctedRemainder =
100         builder.create<arith::AddIOp>(loc, remainder, rhs);
101     Value result = builder.create<arith::SelectOp>(
102         loc, isRemainderNegative, correctedRemainder, remainder);
103     return result;
104   }
105 
106   /// Floor division operation (rounds towards negative infinity).
107   ///
108   /// For positive divisors, it can be implemented without branching and with a
109   /// single division operation as
110   ///
111   ///        a floordiv b =
112   ///            let negative = a < 0 in
113   ///            let absolute = negative ? -a - 1 : a in
114   ///            let quotient = absolute / b in
115   ///                negative ? -quotient - 1 : quotient
116   ///
117   /// Note: this lowering does not use arith.floordivsi because the lowering of
118   /// that to arith.divsi (see populateCeilFloorDivExpandOpsPatterns) generates
119   /// not one but two arith.divsi. That could be changed to one divsi, but one
120   /// way or another, going through arith.floordivsi will result in more complex
121   /// IR because arith.floordivsi is more general than affine floordiv in that
122   /// it supports negative RHS.
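  ///
  /// Worked example (illustrative values only): for a = -7 and b = 3,
  /// `negative` is true, the dividend is -(-7) - 1 = 6, the quotient is
  /// 6 / 3 = 2, and the result is -2 - 1 = -3, i.e. floor(-7 / 3).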
123   Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
124     if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
125       if (rhsConst.getValue() <= 0) {
126         emitError(loc, "division by non-positive value is not supported");
127         return nullptr;
128       }
129     }
130     auto lhs = visit(expr.getLHS());
131     auto rhs = visit(expr.getRHS());
132     assert(lhs && rhs && "unexpected affine expr lowering failure");
133 
134     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
135     Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
136     Value negative = builder.create<arith::CmpIOp>(
137         loc, arith::CmpIPredicate::slt, lhs, zeroCst);
138     Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
139     Value dividend =
140         builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
141     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
142     Value correctedQuotient =
143         builder.create<arith::SubIOp>(loc, noneCst, quotient);
144     Value result = builder.create<arith::SelectOp>(loc, negative,
145                                                    correctedQuotient, quotient);
146     return result;
147   }
148 
149   /// Ceiling division operation (rounds towards positive infinity).
150   ///
151   /// For positive divisors, it can be implemented without branching and with a
152   /// single division operation as
153   ///
154   ///     a ceildiv b =
155   ///         let negative = a <= 0 in
156   ///         let absolute = negative ? -a : a - 1 in
157   ///         let quotient = absolute / b in
158   ///             negative ? -quotient : quotient + 1
159   ///
160   /// Note: not using arith.ceildivsi for the same reason as explained in the
161   /// visitFloorDivExpr comment.
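  ///
  /// Worked example (illustrative values only): for a = 7 and b = 3,
  /// `negative` is false, the dividend is 7 - 1 = 6, the quotient is
  /// 6 / 3 = 2, and the result is 2 + 1 = 3, i.e. ceil(7 / 3); for a = -7,
  /// the dividend is 7, the quotient is 2, and the result is -2.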
162   Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
163     if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
164       if (rhsConst.getValue() <= 0) {
165         emitError(loc, "division by non-positive value is not supported");
166         return nullptr;
167       }
168     }
169     auto lhs = visit(expr.getLHS());
170     auto rhs = visit(expr.getRHS());
171     assert(lhs && rhs && "unexpected affine expr lowering failure");
172 
173     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
174     Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
175     Value nonPositive = builder.create<arith::CmpIOp>(
176         loc, arith::CmpIPredicate::sle, lhs, zeroCst);
177     Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
178     Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
179     Value dividend =
180         builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
181     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
182     Value negatedQuotient =
183         builder.create<arith::SubIOp>(loc, zeroCst, quotient);
184     Value incrementedQuotient =
185         builder.create<arith::AddIOp>(loc, quotient, oneCst);
186     Value result = builder.create<arith::SelectOp>(
187         loc, nonPositive, negatedQuotient, incrementedQuotient);
188     return result;
189   }
190 
191   Value visitConstantExpr(AffineConstantExpr expr) {
192     auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
193     return op.getResult();
194   }
195 
196   Value visitDimExpr(AffineDimExpr expr) {
197     assert(expr.getPosition() < dimValues.size() &&
198            "affine dim position out of range");
199     return dimValues[expr.getPosition()];
200   }
201 
202   Value visitSymbolExpr(AffineSymbolExpr expr) {
203     assert(expr.getPosition() < symbolValues.size() &&
204            "affine symbol position out of range");
205     return symbolValues[expr.getPosition()];
206   }
207 
208 private:
209   OpBuilder &builder;
210   ValueRange dimValues;
211   ValueRange symbolValues;
212 
213   Location loc;
214 };
215 } // namespace
216 
217 /// Create a sequence of operations that implement the `expr` applied to the
218 /// given dimension and symbol values.
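///
/// For illustration (hypothetical input, not taken from any caller): expanding
/// the expression `d0 mod 2 + s0` emits the remainder-correction sequence from
/// `visitModExpr` followed by an `arith.addi` that adds the value bound to
/// `s0`.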
219 mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc,
220                                            AffineExpr expr,
221                                            ValueRange dimValues,
222                                            ValueRange symbolValues) {
223   return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
224 }
225 
226 /// Create a sequence of operations that implement the `affineMap` applied to
227 /// the given `operands` (as if it were an AffineApplyOp).
228 std::optional<SmallVector<Value, 8>>
229 mlir::affine::expandAffineMap(OpBuilder &builder, Location loc,
230                               AffineMap affineMap, ValueRange operands) {
231   auto numDims = affineMap.getNumDims();
232   auto expanded = llvm::to_vector<8>(
233       llvm::map_range(affineMap.getResults(),
234                       [numDims, &builder, loc, operands](AffineExpr expr) {
235                         return expandAffineExpr(builder, loc, expr,
236                                                 operands.take_front(numDims),
237                                                 operands.drop_front(numDims));
238                       }));
239   if (llvm::all_of(expanded, [](Value v) { return v; }))
240     return expanded;
241   return std::nullopt;
242 }
243 
244 /// Promotes the `then` or the `else` block of `ifOp` (depending on whether
245 /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
246 /// the rest of the op.
247 static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
248   if (elseBlock)
249     assert(ifOp.hasElse() && "else block expected");
250 
251   Block *destBlock = ifOp->getBlock();
252   Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
253   destBlock->getOperations().splice(
254       Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
255       std::prev(srcBlock->end()));
256   ifOp.erase();
257 }
258 
259 /// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
260 /// on. The `ifOp` could be hoisted and placed right before such an operation.
261 /// This method assumes that the ifOp has been canonicalized (to be correct and
262 /// effective).
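/// For illustration (hypothetical nest): if an `affine.if` nested inside
/// `affine.for %i { affine.for %j { ... } }` only uses %i in its condition, it
/// is invariant on the %j loop but not on the %i loop, so this returns the
/// inner (%j) affine.for and the `if` can be hoisted just above it.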
263 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
264   // Walk up the parents past all for ops that this conditional is invariant on.
265   auto ifOperands = ifOp.getOperands();
266   auto *res = ifOp.getOperation();
267   while (!isa<func::FuncOp>(res->getParentOp())) {
268     auto *parentOp = res->getParentOp();
269     if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
270       if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
271         break;
272     } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
273       if (llvm::any_of(parallelOp.getIVs(), [&](Value iv) {
274             return llvm::is_contained(ifOperands, iv); }))
275         break;
276     } else if (!isa<AffineIfOp>(parentOp)) {
277       // Won't walk up past anything other than affine.for/parallel/if ops.
278       break;
279     }
280     // You can always hoist up past any affine.if ops.
281     res = parentOp;
282   }
283   return res;
284 }
285 
286 /// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
287 /// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
288 /// otherwise the same `ifOp`.
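/// For illustration (hypothetical IR), hoisting
///   affine.for %i ... { affine.if #set(%n) { S1 } else { S2 } }
/// over the loop produces
///   affine.if #set(%n) { affine.for %i ... { S1 } }
///   else { affine.for %i ... { S2 } }
/// i.e. the loop is cloned and each copy keeps only one branch of the `if`.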
289 static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
290   // No hoisting to do.
291   if (hoistOverOp == ifOp)
292     return ifOp;
293 
294   // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
295   // the else block. Then drop the else block of the original 'if' in the 'then'
296   // branch while promoting its then block, and analogously drop the 'then'
297   // block of the original 'if' from the 'else' branch while promoting its else
298   // block.
299   IRMapping operandMap;
300   OpBuilder b(hoistOverOp);
301   auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
302                                           ifOp.getOperands(),
303                                           /*elseBlock=*/true);
304 
305   // Create a clone of hoistOverOp to use for the else branch of the hoisted
306   // conditional. The else block may get optimized away if empty.
307   Operation *hoistOverOpClone = nullptr;
308   // We use this unique name to identify/find `ifOp`'s clone in the else
309   // version.
310   StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
311   operandMap.clear();
312   b.setInsertionPointAfter(hoistOverOp);
313   // We'll set an attribute to identify this op in a clone of this sub-tree.
314   ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
315   hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
316 
317   // Promote the 'then' block of the original affine.if in the then version.
318   promoteIfBlock(ifOp, /*elseBlock=*/false);
319 
320   // Move the then version to the hoisted if op's 'then' block.
321   auto *thenBlock = hoistedIfOp.getThenBlock();
322   thenBlock->getOperations().splice(thenBlock->begin(),
323                                     hoistOverOp->getBlock()->getOperations(),
324                                     Block::iterator(hoistOverOp));
325 
326   // Find the clone of the original affine.if op in the else version.
327   AffineIfOp ifCloneInElse;
328   hoistOverOpClone->walk([&](AffineIfOp ifClone) {
329     if (!ifClone->getAttr(idForIfOp))
330       return WalkResult::advance();
331     ifCloneInElse = ifClone;
332     return WalkResult::interrupt();
333   });
334   assert(ifCloneInElse && "if op clone should exist");
335   // For the else block, promote the else block of the original 'if' if it had
336   // one; otherwise, the op itself is to be erased.
337   if (!ifCloneInElse.hasElse())
338     ifCloneInElse.erase();
339   else
340     promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
341 
342   // Move the else version into the else block of the hoisted if op.
343   auto *elseBlock = hoistedIfOp.getElseBlock();
344   elseBlock->getOperations().splice(
345       elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
346       Block::iterator(hoistOverOpClone));
347 
348   return hoistedIfOp;
349 }
350 
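// For illustration (hypothetical IR): a single-op reduction loop such as
//   %r = affine.for %i = 0 to 128 iter_args(%acc = %init) -> (f32) {
//     %v = ...
//     %s = arith.addf %acc, %v
//     affine.yield %s
//   }
// becomes an affine.parallel with an "addf" reduction; the arith.addf is moved
// out of the body and combines %init with the affine.parallel result.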
351 LogicalResult
352 mlir::affine::affineParallelize(AffineForOp forOp,
353                                 ArrayRef<LoopReduction> parallelReductions,
354                                 AffineParallelOp *resOp) {
355   // Fail early if there are iter arguments that are not reductions.
356   unsigned numReductions = parallelReductions.size();
357   if (numReductions != forOp.getNumIterOperands())
358     return failure();
359 
360   Location loc = forOp.getLoc();
361   OpBuilder outsideBuilder(forOp);
362   AffineMap lowerBoundMap = forOp.getLowerBoundMap();
363   ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
364   AffineMap upperBoundMap = forOp.getUpperBoundMap();
365   ValueRange upperBoundOperands = forOp.getUpperBoundOperands();
366 
367   // Create an empty 1-D affine.parallel op.
368   auto reducedValues = llvm::to_vector<4>(llvm::map_range(
369       parallelReductions, [](const LoopReduction &red) { return red.value; }));
370   auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
371       parallelReductions, [](const LoopReduction &red) { return red.kind; }));
372   AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
373       loc, ValueRange(reducedValues).getTypes(), reductionKinds,
374       llvm::ArrayRef(lowerBoundMap), lowerBoundOperands,
375       llvm::ArrayRef(upperBoundMap), upperBoundOperands,
376       llvm::ArrayRef(forOp.getStepAsInt()));
377   // Steal the body of the old affine for op.
378   newPloop.getRegion().takeBody(forOp.getRegion());
379   Operation *yieldOp = &newPloop.getBody()->back();
380 
381   // Handle the initial values of reductions because the parallel loop always
382   // starts from the neutral value.
383   SmallVector<Value> newResults;
384   newResults.reserve(numReductions);
385   for (unsigned i = 0; i < numReductions; ++i) {
386     Value init = forOp.getInits()[i];
387     // This works because we are only handling single-op reductions at the
388     // moment. A switch on reduction kind or a mechanism to collect operations
389     // participating in the reduction will be necessary for multi-op reductions.
390     Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
391     assert(reductionOp && "yielded value is expected to be produced by an op");
392     outsideBuilder.getInsertionBlock()->getOperations().splice(
393         outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
394         reductionOp);
395     reductionOp->setOperands({init, newPloop->getResult(i)});
396     forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
397   }
398 
399   // Update the loop terminator to yield reduced values bypassing the reduction
400   // operation itself (now moved outside of the loop) and erase the block
401   // arguments that correspond to reductions. Note that the loop always has one
402   // "main" induction variable when coming from a non-parallel for.
403   unsigned numIVs = 1;
404   yieldOp->setOperands(reducedValues);
405   newPloop.getBody()->eraseArguments(numIVs, numReductions);
406 
407   forOp.erase();
408   if (resOp)
409     *resOp = newPloop;
410   return success();
411 }
412 
413 // Returns success if any hoisting happened.
414 LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
415   // Bail out early if the ifOp returns a result.  TODO: Consider how to
416   // properly support this case.
417   if (ifOp.getNumResults() != 0)
418     return failure();
419 
420   // Apply canonicalization patterns and folding - this is necessary for the
421   // hoisting check to be correct (operands should be composed), and to be more
422   // effective (no unused operands). Since the pattern rewriter's folding is
423   // entangled with application of patterns, we may fold/end up erasing the op,
424   // in which case we return with `folded` being set.
425   RewritePatternSet patterns(ifOp.getContext());
426   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
427   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
428   GreedyRewriteConfig config;
429   config.strictMode = GreedyRewriteStrictness::ExistingOps;
430   bool erased;
431   (void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
432                                 /*changed=*/nullptr, &erased);
433   if (erased) {
434     if (folded)
435       *folded = true;
436     return failure();
437   }
438   if (folded)
439     *folded = false;
440 
441   // The folding above should have ensured this, but the affine.if's
442   // canonicalization is missing composition of affine.applys into it.
443   assert(llvm::all_of(ifOp.getOperands(),
444                       [](Value v) {
445                         return isTopLevelValue(v) || isAffineForInductionVar(v);
446                       }) &&
447          "operands not composed");
448 
449   // We are going to hoist as high as possible.
450   // TODO: this could be customized in the future.
451   auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
452 
453   AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
454   // Nothing to hoist over.
455   if (hoistedIfOp == ifOp)
456     return failure();
457 
458   // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
459   // a sequence of affine.fors that are all perfectly nested).
460   (void)applyPatternsGreedily(
461       hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
462       frozenPatterns);
463 
464   return success();
465 }
466 
467 // Returns `e` with `dim` replaced by `min` on positive paths, `max` otherwise.
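// For illustration (hypothetical expression): substituting into `d0 * -2` on a
// positive path yields `max * -2`, because the negative constant coefficient
// flips the path and the largest value of d0 minimizes the product.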
468 AffineExpr mlir::affine::substWithMin(AffineExpr e, AffineExpr dim,
469                                       AffineExpr min, AffineExpr max,
470                                       bool positivePath) {
471   if (e == dim)
472     return positivePath ? min : max;
473   if (auto bin = dyn_cast<AffineBinaryOpExpr>(e)) {
474     AffineExpr lhs = bin.getLHS();
475     AffineExpr rhs = bin.getRHS();
476     if (bin.getKind() == mlir::AffineExprKind::Add)
477       return substWithMin(lhs, dim, min, max, positivePath) +
478              substWithMin(rhs, dim, min, max, positivePath);
479 
480     auto c1 = dyn_cast<AffineConstantExpr>(bin.getLHS());
481     auto c2 = dyn_cast<AffineConstantExpr>(bin.getRHS());
482     if (c1 && c1.getValue() < 0)
483       return getAffineBinaryOpExpr(
484           bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
485     if (c2 && c2.getValue() < 0)
486       return getAffineBinaryOpExpr(
487           bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
488     return getAffineBinaryOpExpr(
489         bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
490         substWithMin(rhs, dim, min, max, positivePath));
491   }
492   return e;
493 }
494 
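// For illustration (hypothetical IR): `affine.parallel (%i) = (2) to (18)
// step (4)` becomes `affine.parallel (%i) = (0) to (4)` with an affine.apply
// in the body recomputing the original value as 2 + %i * 4.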
495 void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
496   // Loops with min/max in bounds are not normalized at the moment.
497   if (op.hasMinMaxBounds())
498     return;
499 
500   AffineMap lbMap = op.getLowerBoundsMap();
501   SmallVector<int64_t, 8> steps = op.getSteps();
502   // No need to do any work if the parallel op is already normalized.
503   bool isAlreadyNormalized =
504       llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
505         int64_t step = std::get<0>(tuple);
506         auto lbExpr = dyn_cast<AffineConstantExpr>(std::get<1>(tuple));
507         return lbExpr && lbExpr.getValue() == 0 && step == 1;
508       });
509   if (isAlreadyNormalized)
510     return;
511 
512   AffineValueMap ranges;
513   AffineValueMap::difference(op.getUpperBoundsValueMap(),
514                              op.getLowerBoundsValueMap(), &ranges);
515   auto builder = OpBuilder::atBlockBegin(op.getBody());
516   auto zeroExpr = builder.getAffineConstantExpr(0);
517   SmallVector<AffineExpr, 8> lbExprs;
518   SmallVector<AffineExpr, 8> ubExprs;
519   for (unsigned i = 0, e = steps.size(); i < e; ++i) {
520     int64_t step = steps[i];
521 
522     // Adjust the lower bound to be 0.
523     lbExprs.push_back(zeroExpr);
524 
525     // Adjust the upper bound expression: 'range ceildiv step'.
526     AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
527     ubExprs.push_back(ubExpr);
528 
529     // Adjust the corresponding IV: 'lb + i * step'.
530     BlockArgument iv = op.getBody()->getArgument(i);
531     AffineExpr lbExpr = lbMap.getResult(i);
532     unsigned nDims = lbMap.getNumDims();
533     auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
534     auto map = AffineMap::get(/*dimCount=*/nDims + 1,
535                               /*symbolCount=*/lbMap.getNumSymbols(), expr);
536 
537     // Use an 'affine.apply' op that will be simplified later in subsequent
538     // canonicalizations.
539     OperandRange lbOperands = op.getLowerBoundsOperands();
540     OperandRange dimOperands = lbOperands.take_front(nDims);
541     OperandRange symbolOperands = lbOperands.drop_front(nDims);
542     SmallVector<Value, 8> applyOperands{dimOperands};
543     applyOperands.push_back(iv);
544     applyOperands.append(symbolOperands.begin(), symbolOperands.end());
545     auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
546     iv.replaceAllUsesExcept(apply, apply);
547   }
548 
549   SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
550   op.setSteps(newSteps);
551   auto newLowerMap = AffineMap::get(
552       /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
553   op.setLowerBounds({}, newLowerMap);
554   auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
555                                     ubExprs, op.getContext());
556   op.setUpperBounds(ranges.getOperands(), newUpperMap);
557 }
558 
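// For illustration (hypothetical IR): `affine.for %i = %lb to %ub step 4`
// becomes `affine.for %i = 0 to ceildiv(%ub - %lb, 4)` with an affine.apply
// in the body reconstructing the original IV as %lb + 4 * %i.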
559 LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
560                                                bool promoteSingleIter) {
561   if (promoteSingleIter && succeeded(promoteIfSingleIteration(op)))
562     return success();
563 
564   // Check if the for op is already normalized.
565   if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
566       (op.getStep() == 1))
567     return success();
568 
569   // Check if the lower bound has a single result only. Loops with a max lower
570   // bound can't be normalized without additional support like
571   // affine.execute_region's. If the lower bound does not have a single result
572   // then skip this op.
573   if (op.getLowerBoundMap().getNumResults() != 1)
574     return failure();
575 
576   Location loc = op.getLoc();
577   OpBuilder opBuilder(op);
578   int64_t origLoopStep = op.getStepAsInt();
579 
580   // Construct the new upper bound value map.
581   AffineMap oldLbMap = op.getLowerBoundMap();
582   // The upper bound can have multiple results. To use
583   // AffineValueMap::difference, we need to have the same number of results in
584   // both lower and upper bound maps. So, we just create a value map for the
585   // lower bound with the only available lower bound result repeated to pad up
586   // to the number of upper bound results.
587   SmallVector<AffineExpr> lbExprs(op.getUpperBoundMap().getNumResults(),
588                                   op.getLowerBoundMap().getResult(0));
589   AffineValueMap lbMap(oldLbMap, op.getLowerBoundOperands());
590   AffineMap paddedLbMap =
591       AffineMap::get(oldLbMap.getNumDims(), oldLbMap.getNumSymbols(), lbExprs,
592                      op.getContext());
593   AffineValueMap paddedLbValueMap(paddedLbMap, op.getLowerBoundOperands());
594   AffineValueMap ubValueMap(op.getUpperBoundMap(), op.getUpperBoundOperands());
595   AffineValueMap newUbValueMap;
596   // Compute the `upper bound - lower bound`.
597   AffineValueMap::difference(ubValueMap, paddedLbValueMap, &newUbValueMap);
598   (void)newUbValueMap.canonicalize();
599 
600   // Scale down the upper bound value map by the loop step.
601   unsigned numResult = newUbValueMap.getNumResults();
602   SmallVector<AffineExpr> scaleDownExprs(numResult);
603   for (unsigned i = 0; i < numResult; ++i)
604     scaleDownExprs[i] = opBuilder.getAffineDimExpr(i).ceilDiv(origLoopStep);
605   // `scaleDownMap` is (d0, ..., d_n) -> (d0 ceildiv step, ..., d_n ceildiv
606   // step), where `n` is the number of results in the upper bound map.
607   AffineMap scaleDownMap =
608       AffineMap::get(numResult, 0, scaleDownExprs, op.getContext());
609   AffineMap newUbMap = scaleDownMap.compose(newUbValueMap.getAffineMap());
610 
611   // Set the newly created upper bound map and operands.
612   op.setUpperBound(newUbValueMap.getOperands(), newUbMap);
613   op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
614   op.setStep(1);
615 
616   // Calculate the value of the original loop IV in terms of the new,
617   // normalized one.
618   opBuilder.setInsertionPointToStart(op.getBody());
619   // Construct an affine.apply op mapping the new IV to the old IV.
620   AffineMap scaleIvMap =
621       AffineMap::get(1, 0, -opBuilder.getAffineDimExpr(0) * origLoopStep);
622   AffineValueMap scaleIvValueMap(scaleIvMap, ValueRange{op.getInductionVar()});
623   AffineValueMap newIvToOldIvMap;
624   AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap);
625   (void)newIvToOldIvMap.canonicalize();
626   auto newIV = opBuilder.create<AffineApplyOp>(
627       loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands());
628   op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
629   return success();
630 }
631 
632 /// Returns true if the memory operation of `destAccess` depends on `srcAccess`
633 /// inside of the innermost common surrounding affine loop between the two
634 /// accesses.
635 static bool mustReachAtInnermost(const MemRefAccess &srcAccess,
636                                  const MemRefAccess &destAccess) {
637   // Affine dependence analysis is possible only if both ops are in the same
638   // AffineScope.
639   if (getAffineScope(srcAccess.opInst) != getAffineScope(destAccess.opInst))
640     return false;
641 
642   unsigned nsLoops =
643       getNumCommonSurroundingLoops(*srcAccess.opInst, *destAccess.opInst);
644   DependenceResult result =
645       checkMemrefAccessDependence(srcAccess, destAccess, nsLoops + 1);
646   return hasDependence(result);
647 }
648 
649 /// Returns true if `srcMemOp` may have an effect on `destMemOp` within the
650 /// scope of the outermost `minSurroundingLoops` loops that surround them.
651 /// `srcMemOp` and `destMemOp` are expected to be affine read/write ops.
652 static bool mayHaveEffect(Operation *srcMemOp, Operation *destMemOp,
653                           unsigned minSurroundingLoops) {
654   MemRefAccess srcAccess(srcMemOp);
655   MemRefAccess destAccess(destMemOp);
656 
657   // Affine dependence analysis here is applicable only if both ops operate on
658   // the same memref and if `srcMemOp` and `destMemOp` are in the same
659   // AffineScope. Also, we can only check if our affine scope is isolated from
660   // above; otherwise, values that the check below cannot analyze can come
661   // from outside of the affine scope.
662   Region *srcScope = getAffineScope(srcMemOp);
663   if (srcAccess.memref == destAccess.memref &&
664       srcScope == getAffineScope(destMemOp)) {
665     unsigned nsLoops = getNumCommonSurroundingLoops(*srcMemOp, *destMemOp);
666     FlatAffineValueConstraints dependenceConstraints;
667     for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
668       DependenceResult result = checkMemrefAccessDependence(
669           srcAccess, destAccess, d, &dependenceConstraints,
670           /*dependenceComponents=*/nullptr);
671       // A dependence failure or the presence of a dependence implies a
672       // side effect.
673       if (!noDependence(result))
674         return true;
675     }
676     // No side effect was seen.
677     return false;
678   }
679   // TODO: Check here if the memrefs alias: there is no side effect if
680   // `srcAccess.memref` and `destAccess.memref` don't alias.
681   return true;
682 }
683 
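// For illustration (hypothetical IR): for an affine.store to %A followed later
// by an affine.load from %A, hasNoInterveningEffect<MemoryEffects::Write>
// returns true only if every operation between them can be shown not to write
// to a location that may alias %A (or, for affine writes, not to reach the
// load).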
684 template <typename EffectType, typename T>
685 bool mlir::affine::hasNoInterveningEffect(
686     Operation *start, T memOp,
687     llvm::function_ref<bool(Value, Value)> mayAlias) {
688   // A boolean representing whether an intervening operation could have impacted
689   // memOp.
690   bool hasSideEffect = false;
691 
692   // Check whether the effect on memOp can be caused by a given operation op.
693   Value memref = memOp.getMemRef();
694   std::function<void(Operation *)> checkOperation = [&](Operation *op) {
695     // If the effect has already been found, exit early.
696     if (hasSideEffect)
697       return;
698 
699     if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
700       SmallVector<MemoryEffects::EffectInstance, 1> effects;
701       memEffect.getEffects(effects);
702 
703       bool opMayHaveEffect = false;
704       for (auto effect : effects) {
705         // If op causes EffectType on a potentially aliasing location for
706         // memOp, mark as having the effect.
707         if (isa<EffectType>(effect.getEffect())) {
708           if (effect.getValue() && effect.getValue() != memref &&
709               !mayAlias(effect.getValue(), memref))
710             continue;
711           opMayHaveEffect = true;
712           break;
713         }
714       }
715 
716       if (!opMayHaveEffect)
717         return;
718 
719       // If the side effect comes from an affine read or write, try to
720       // prove the side effecting `op` cannot reach `memOp`.
721       if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
722         // For ease, let's consider the case that `op` is a store and
723         // we're looking for other potential stores that overwrite memory after
724         // `start`, and before being read in `memOp`. In this case, we only
725         // need to consider other potential stores with depth >
726         // minSurroundingLoops since `start` would overwrite any store with a
727         // smaller number of surrounding loops before.
728         unsigned minSurroundingLoops =
729             getNumCommonSurroundingLoops(*start, *memOp);
730         if (mayHaveEffect(op, memOp, minSurroundingLoops))
731           hasSideEffect = true;
732         return;
733       }
734 
735       // We have an op with a memory effect and we cannot prove if it
736       // intervenes.
737       hasSideEffect = true;
738       return;
739     }
740 
741     if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
742       // Recurse into the regions for this op and check whether the internal
743       // operations may have the side effect `EffectType` on memOp.
744       for (Region &region : op->getRegions())
745         for (Block &block : region)
746           for (Operation &op : block)
747             checkOperation(&op);
748       return;
749     }
750 
751     // Otherwise, conservatively assume generic operations have the effect
752     // on `memOp`.
753     hasSideEffect = true;
754   };
755 
756   // Check all paths from ancestor op `parent` to the operation `to` for the
757   // effect. It is known that `to` must be contained within `parent`.
758   auto until = [&](Operation *parent, Operation *to) {
759     // TODO: check only the paths from `parent` to `to`.
760     // Currently we fall back and check the entire parent op, rather than
761     // just the paths from `parent` to `to`, stopping after reaching `to`.
762     // This is conservatively correct, but could be made more aggressive.
763     assert(parent->isAncestor(to));
764     checkOperation(parent);
765   };
766 
767   // Check all paths from operation `from` to operation `untilOp` for the
768   // given memory effect.
769   std::function<void(Operation *, Operation *)> recur =
770       [&](Operation *from, Operation *untilOp) {
771         assert(
772             from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
773             "Checking for side effect between two operations without a common "
774             "ancestor");
775 
776         // If the operations are in different regions, recursively consider all
777         // paths from `from` to the parent of `to` and all paths from the parent
778         // of `to` to `to`.
779         if (from->getParentRegion() != untilOp->getParentRegion()) {
780           recur(from, untilOp->getParentOp());
781           until(untilOp->getParentOp(), untilOp);
782           return;
783         }
784 
785         // Now, assuming that `from` and `to` exist in the same region, perform
786         // a CFG traversal to check all the relevant operations.
787 
788         // Additional blocks to consider.
789         SmallVector<Block *, 2> todoBlocks;
790         {
791           // First consider the parent block of `from` and check all operations
792           // after `from`.
793           for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
794                iter != end && &*iter != untilOp; ++iter) {
795             checkOperation(&*iter);
796           }
797 
798           // If the parent of `from` doesn't contain `to`, add the successors
799           // to the list of blocks to check.
800           if (untilOp->getBlock() != from->getBlock())
801             for (Block *succ : from->getBlock()->getSuccessors())
802               todoBlocks.push_back(succ);
803         }
804 
805         SmallPtrSet<Block *, 4> done;
806         // Traverse the CFG until hitting `to`.
807         while (!todoBlocks.empty()) {
808           Block *blk = todoBlocks.pop_back_val();
809           if (done.count(blk))
810             continue;
811           done.insert(blk);
812           for (auto &op : *blk) {
813             if (&op == untilOp)
814               break;
815             checkOperation(&op);
816             if (&op == blk->getTerminator())
817               for (Block *succ : blk->getSuccessors())
818                 todoBlocks.push_back(succ);
819           }
820         }
821       };
822   recur(start, memOp);
823   return !hasSideEffect;
824 }
825 
826 /// Attempt to eliminate loadOp by replacing it with a value stored into memory
827 /// which the load is guaranteed to retrieve. This check involves three
828 /// components: 1) the store and load must be to the same location, 2) the
829 /// store must dominate (and therefore must always occur prior to) the load,
830 /// and 3) no other operation may overwrite the memory loaded between the given
831 /// load and store. If such a value exists, the replaced `loadOp` will be added
832 /// to `loadOpsToErase` and its memref will be added to `memrefsToErase`.
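/// For illustration (hypothetical IR): in
///   affine.store %v, %A[%i]
///   ...   // no intervening operation that may write to memory aliasing %A
///   %x = affine.load %A[%i]
/// the load is replaced by %v and added to `loadOpsToErase`, and %A is
/// recorded in `memrefsToErase` for the later dead-memref sweep.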
833 static void forwardStoreToLoad(
834     AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
835     SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo,
836     llvm::function_ref<bool(Value, Value)> mayAlias) {
837 
838   // The store op candidate for forwarding that satisfies all conditions
839   // to replace the load, if any.
840   Operation *lastWriteStoreOp = nullptr;
841 
842   for (auto *user : loadOp.getMemRef().getUsers()) {
843     auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
844     if (!storeOp)
845       continue;
846     MemRefAccess srcAccess(storeOp);
847     MemRefAccess destAccess(loadOp);
848 
849     // 1. Check if the store and the load have mathematically equivalent
850     // affine access functions; this implies that they statically refer to the
851     // same single memref element. As an example this filters out cases like:
852     //     store %A[%i0 + 1]
853     //     load %A[%i0]
854     //     store %A[%M]
855     //     load %A[%N]
856     // Use the AffineValueMap difference based memref access equality checking.
857     if (srcAccess != destAccess)
858       continue;
859 
860     // 2. The store has to dominate the load op to be a candidate.
861     if (!domInfo.dominates(storeOp, loadOp))
862       continue;
863 
864     // 3. The store must reach the load. Access function equivalence only
865     // guarantees this for accesses in the same block. The load could be in a
866     // nested block that is unreachable.
867     if (!mustReachAtInnermost(srcAccess, destAccess))
868       continue;
869 
870     // 4. Ensure there is no intermediate operation which could replace the
871     // value in memory.
872     if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp,
873                                                               mayAlias))
874       continue;
875 
876     // We now have a candidate for forwarding.
877     assert(lastWriteStoreOp == nullptr &&
878            "multiple simultaneous replacement stores");
879     lastWriteStoreOp = storeOp;
880   }
881 
882   if (!lastWriteStoreOp)
883     return;
884 
885   // Perform the actual store to load forwarding.
886   Value storeVal =
887       cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
888   // Check if 2 values have the same shape. This is needed for affine vector
889   // loads and stores.
890   if (storeVal.getType() != loadOp.getValue().getType())
891     return;
892   loadOp.getValue().replaceAllUsesWith(storeVal);
893   // Record the memref for a later sweep to optimize away.
894   memrefsToErase.insert(loadOp.getMemRef());
895   // Record this to erase later.
896   loadOpsToErase.push_back(loadOp);
897 }
898 
899 template bool
900 mlir::affine::hasNoInterveningEffect<mlir::MemoryEffects::Read,
901                                      affine::AffineReadOpInterface>(
902     mlir::Operation *, affine::AffineReadOpInterface,
903     llvm::function_ref<bool(Value, Value)>);
904 
905 // This attempts to find stores which have no impact on the final result.
906 // A writing op writeA will be eliminated if there exists an op writeB such that:
907 // 1) writeA and writeB have mathematically equivalent affine access functions.
908 // 2) writeB postdominates writeA.
909 // 3) There is no potential read between writeA and writeB.
910 static void findUnusedStore(AffineWriteOpInterface writeA,
911                             SmallVectorImpl<Operation *> &opsToErase,
912                             PostDominanceInfo &postDominanceInfo,
913                             llvm::function_ref<bool(Value, Value)> mayAlias) {
914 
915   for (Operation *user : writeA.getMemRef().getUsers()) {
916     // Only consider writing operations.
917     auto writeB = dyn_cast<AffineWriteOpInterface>(user);
918     if (!writeB)
919       continue;
920 
921     // The operations must be distinct.
922     if (writeB == writeA)
923       continue;
924 
925     // Both operations must lie in the same region.
926     if (writeB->getParentRegion() != writeA->getParentRegion())
927       continue;
928 
929     // Both operations must write to the same memory.
930     MemRefAccess srcAccess(writeB);
931     MemRefAccess destAccess(writeA);
932 
933     if (srcAccess != destAccess)
934       continue;
935 
936     // writeB must postdominate writeA.
937     if (!postDominanceInfo.postDominates(writeB, writeA))
938       continue;
939 
940     // There cannot be an operation which reads from memory between
941     // the two writes.
942     if (!affine::hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB,
943                                                              mayAlias))
944       continue;
945 
946     opsToErase.push_back(writeA);
947     break;
948   }
949 }
950 
951 // The load to load forwarding / redundant load elimination is similar to the
952 // store to load forwarding.
953 // loadA will be replaced with loadB if:
954 // 1) loadA and loadB have mathematically equivalent affine access functions.
955 // 2) loadB dominates loadA.
956 // 3) There is no write between loadA and loadB.
957 static void loadCSE(AffineReadOpInterface loadA,
958                     SmallVectorImpl<Operation *> &loadOpsToErase,
959                     DominanceInfo &domInfo,
960                     llvm::function_ref<bool(Value, Value)> mayAlias) {
961   SmallVector<AffineReadOpInterface, 4> loadCandidates;
962   for (auto *user : loadA.getMemRef().getUsers()) {
963     auto loadB = dyn_cast<AffineReadOpInterface>(user);
964     if (!loadB || loadB == loadA)
965       continue;
966 
967     MemRefAccess srcAccess(loadB);
968     MemRefAccess destAccess(loadA);
969 
970     // 1. The accesses should be to the same location.
971     if (srcAccess != destAccess) {
972       continue;
973     }
974 
975     // 2. loadB should dominate loadA.
976     if (!domInfo.dominates(loadB, loadA))
977       continue;
978 
979     // 3. There should not be a write between loadA and loadB.
980     if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(
981             loadB.getOperation(), loadA, mayAlias))
982       continue;
983 
984     // Check if two values have the same shape. This is needed for affine vector
985     // loads.
986     if (loadB.getValue().getType() != loadA.getValue().getType())
987       continue;
988 
989     loadCandidates.push_back(loadB);
990   }
991 
992   // Of the legal load candidates, use the one that dominates all others
993   // to minimize the subsequent need for loadCSE.
994   Value loadB;
995   for (AffineReadOpInterface option : loadCandidates) {
996     if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
997           return depStore == option ||
998                  domInfo.dominates(option.getOperation(),
999                                    depStore.getOperation());
1000         })) {
1001       loadB = option.getValue();
1002       break;
1003     }
1004   }
1005 
1006   if (loadB) {
1007     loadA.getValue().replaceAllUsesWith(loadB);
1008     // Record this to erase later.
1009     loadOpsToErase.push_back(loadA);
1010   }
1011 }
1012 
1013 // The store to load forwarding and load CSE rely on three conditions:
1014 //
1015 // 1) store/load providing a replacement value and load being replaced need to
1016 // have mathematically equivalent affine access functions (checked after full
1017 // composition of load/store operands); this implies that they access the same
1018 // single memref element for all iterations of the common surrounding loop,
1019 //
1020 // 2) the store/load op should dominate the load op,
1021 //
1022 // 3) no operation that may write to memory read by the load being replaced can
1023 // occur after executing the instruction (load or store) providing the
1024 // replacement value and before the load being replaced (thus potentially
1025 // allowing overwriting the memory read by the load).
1026 //
1027 // The above conditions are simple to check and powerful for most cases in
1028 // practice -- they are sufficient, but not necessary, since they don't reason
1029 // about loops that are guaranteed to execute at least once or about multiple
1030 // sources to forward from.
1031 //
1032 // TODO: more forwarding can be done when support for
1033 // loop/conditional live-out SSA values is available.
1034 // TODO: do general dead store elimination for memref's. This pass
1035 // currently eliminates the stores only if no other loads/uses (other
1036 // than dealloc) remain.
1037 //
1038 void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
1039                                        PostDominanceInfo &postDomInfo,
1040                                        AliasAnalysis &aliasAnalysis) {
1041   // Load op's whose results were replaced by those forwarded from stores.
1042   SmallVector<Operation *, 8> opsToErase;
1043 
1044   // A list of memref's that are potentially dead / could be eliminated.
1045   SmallPtrSet<Value, 4> memrefsToErase;
1046 
1047   auto mayAlias = [&](Value val1, Value val2) -> bool {
1048     return !aliasAnalysis.alias(val1, val2).isNo();
1049   };
1050 
1051   // Walk all load's and perform store to load forwarding.
1052   f.walk([&](AffineReadOpInterface loadOp) {
1053     forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo, mayAlias);
1054   });
1055   for (auto *op : opsToErase)
1056     op->erase();
1057   opsToErase.clear();
1058 
1059   // Walk all store's and perform unused store elimination
1060   f.walk([&](AffineWriteOpInterface storeOp) {
1061     findUnusedStore(storeOp, opsToErase, postDomInfo, mayAlias);
1062   });
1063   for (auto *op : opsToErase)
1064     op->erase();
1065   opsToErase.clear();
1066 
1067   // Check if the store fwd'ed memrefs are now left with only stores and
1068   // deallocs and can thus be completely deleted. Note: the canonicalize pass
1069   // should be able to do this as well, but we'll do it here since we collected
1070   // these anyway.
1071   for (auto memref : memrefsToErase) {
1072     // If the memref hasn't been locally alloc'ed, skip.
1073     Operation *defOp = memref.getDefiningOp();
1074     if (!defOp || !hasSingleEffect<MemoryEffects::Allocate>(defOp, memref))
1075       // TODO: if the memref was returned by a 'call' operation, we
1076       // could still erase it if the call had no side-effects.
1077       continue;
1078     if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
1079           return !isa<AffineWriteOpInterface>(ownerOp) &&
1080                  !hasSingleEffect<MemoryEffects::Free>(ownerOp, memref);
1081         }))
1082       continue;
1083 
1084     // Erase all stores, the dealloc, and the alloc on the memref.
1085     for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
1086       user->erase();
1087     defOp->erase();
1088   }
1089 
1090   // To eliminate as many loads as possible, run load CSE after eliminating
1091   // stores. Otherwise, some stores are wrongly seen as having an intervening
1092   // effect.
1093   f.walk([&](AffineReadOpInterface loadOp) {
1094     loadCSE(loadOp, opsToErase, domInfo, mayAlias);
1095   });
1096   for (auto *op : opsToErase)
1097     op->erase();
1098 }
1099 
1100 // Private helper function to transform memref.load with reduced rank.
1101 // This function will modify the indices of the memref.load to match the
1102 // newMemRef.
1103 static LogicalResult transformMemRefLoadWithReducedRank(
1104     Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
1105     ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
1106     ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
1107   unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
1108   unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
1109   unsigned oldMapNumInputs = oldMemRefRank;
1110   SmallVector<Value, 4> oldMapOperands(
1111       op->operand_begin() + memRefOperandPos + 1,
1112       op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
1113   SmallVector<Value, 4> oldMemRefOperands;
1114   oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
1115   SmallVector<Value, 4> remapOperands;
1116   remapOperands.reserve(extraOperands.size() + oldMemRefRank +
1117                         symbolOperands.size());
1118   remapOperands.append(extraOperands.begin(), extraOperands.end());
1119   remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
1120   remapOperands.append(symbolOperands.begin(), symbolOperands.end());
1121 
1122   SmallVector<Value, 4> remapOutputs;
1123   remapOutputs.reserve(oldMemRefRank);
1124   SmallVector<Value, 4> affineApplyOps;
1125 
1126   OpBuilder builder(op);
1127 
1128   if (indexRemap &&
1129       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
1130     // Remapped indices.
1131     for (auto resultExpr : indexRemap.getResults()) {
1132       auto singleResMap = AffineMap::get(
1133           indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
1134       auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1135                                                 remapOperands);
1136       remapOutputs.push_back(afOp);
1137       affineApplyOps.push_back(afOp);
1138     }
1139   } else {
1140     // No remapping specified.
1141     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
1142   }
1143 
1144   SmallVector<Value, 4> newMapOperands;
1145   newMapOperands.reserve(newMemRefRank);
1146 
1147   // Prepend 'extraIndices' in 'newMapOperands'.
1148   for (Value extraIndex : extraIndices) {
1149     assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
1150            "invalid memory op index");
1151     newMapOperands.push_back(extraIndex);
1152   }
1153 
1154   // Append 'remapOutputs' to 'newMapOperands'.
1155   newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
1156 
1157   // Create new fully composed AffineMap for new op to be created.
1158   assert(newMapOperands.size() == newMemRefRank);
1159 
1160   OperationState state(op->getLoc(), op->getName());
1161   // Construct the new operation using this memref.
1162   state.operands.reserve(newMapOperands.size() + extraIndices.size());
1163   state.operands.push_back(newMemRef);
1164 
1165   // Insert the new memref map operands.
1166   state.operands.append(newMapOperands.begin(), newMapOperands.end());
1167 
1168   state.types.reserve(op->getNumResults());
1169   for (auto result : op->getResults())
1170     state.types.push_back(result.getType());
1171 
1172   // Copy over the attributes from the old operation to the new operation.
1173   for (auto namedAttr : op->getAttrs()) {
1174     state.attributes.push_back(namedAttr);
1175   }
1176 
1177   // Create the new operation.
1178   auto *repOp = builder.create(state);
1179   op->replaceAllUsesWith(repOp);
1180   op->erase();
1181 
1182   return success();
1183 }
1184 // Perform the replacement in `op`.
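// For illustration (hypothetical arguments): to replace uses of a
// memref<64xf32> `oldMemRef` with a memref<4x16xf32> `newMemRef`, one could
// pass `indexRemap = (d0) -> (d0 floordiv 16, d0 mod 16)` with empty
// `extraIndices`, `extraOperands`, and `symbolOperands`; the single old access
// index is then rewritten into the two new ones via affine.apply ops.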
1185 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1186     Value oldMemRef, Value newMemRef, Operation *op,
1187     ArrayRef<Value> extraIndices, AffineMap indexRemap,
1188     ArrayRef<Value> extraOperands, ArrayRef<Value> symbolOperands,
1189     bool allowNonDereferencingOps) {
1190   unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
1191   (void)newMemRefRank; // unused in opt mode
1192   unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
1193   (void)oldMemRefRank; // unused in opt mode
1194   if (indexRemap) {
1195     assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1196            "symbolic operand count mismatch");
1197     assert(indexRemap.getNumInputs() ==
1198            extraOperands.size() + oldMemRefRank + symbolOperands.size());
1199     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1200   } else {
1201     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1202   }
1203 
1204   // Assert same elemental type.
1205   assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
1206          cast<MemRefType>(newMemRef.getType()).getElementType());
1207 
1208   SmallVector<unsigned, 2> usePositions;
1209   for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
1210     if (opEntry.value() == oldMemRef)
1211       usePositions.push_back(opEntry.index());
1212   }
1213 
1214   // If memref doesn't appear, nothing to do.
1215   if (usePositions.empty())
1216     return success();
1217 
1218   if (usePositions.size() > 1) {
1219     // TODO: extend it for this case when needed (rare).
1220     assert(false && "multiple dereferencing uses in a single op not supported");
1221     return failure();
1222   }
1223 
1224   unsigned memRefOperandPos = usePositions.front();
1225 
1226   OpBuilder builder(op);
1227   // The following checks if op dereferences the memref and performs the access
1228   // index rewrites.
1229   auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1230   if (!affMapAccInterface) {
1231     if (!allowNonDereferencingOps) {
1232       // Failure: memref used in a non-dereferencing context (potentially
1233       // escapes); no replacement in these cases unless allowNonDereferencingOps
1234       // is set.
1235       return failure();
1236     }
1237 
1238     // Check if it is a memref.load
1239     auto memrefLoad = dyn_cast<memref::LoadOp>(op);
1240     bool isReductionLike =
1241         indexRemap.getNumResults() < indexRemap.getNumInputs();
1242     if (!memrefLoad || !isReductionLike) {
1243       op->setOperand(memRefOperandPos, newMemRef);
1244       return success();
1245     }
1246 
1247     return transformMemRefLoadWithReducedRank(
1248         op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
1249         symbolOperands, indexRemap);
1250   }
1251   // Perform index rewrites for the dereferencing op and then replace the op
1252   NamedAttribute oldMapAttrPair =
1253       affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
1254   AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
1255   unsigned oldMapNumInputs = oldMap.getNumInputs();
1256   SmallVector<Value, 4> oldMapOperands(
1257       op->operand_begin() + memRefOperandPos + 1,
1258       op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
1259 
1260   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
1261   SmallVector<Value, 4> oldMemRefOperands;
1262   SmallVector<Value, 4> affineApplyOps;
1263   oldMemRefOperands.reserve(oldMemRefRank);
1264   if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
1265     for (auto resultExpr : oldMap.getResults()) {
1266       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
1267                                          oldMap.getNumSymbols(), resultExpr);
1268       auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1269                                                 oldMapOperands);
1270       oldMemRefOperands.push_back(afOp);
1271       affineApplyOps.push_back(afOp);
1272     }
1273   } else {
1274     oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
1275   }
1276 
1277   // Construct new indices as a remap of the old ones if a remapping has been
1278   // provided. The indices of a memref come right after it, i.e.,
1279   // at position memRefOperandPos + 1.
1280   SmallVector<Value, 4> remapOperands;
1281   remapOperands.reserve(extraOperands.size() + oldMemRefRank +
1282                         symbolOperands.size());
1283   remapOperands.append(extraOperands.begin(), extraOperands.end());
1284   remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
1285   remapOperands.append(symbolOperands.begin(), symbolOperands.end());
1286 
1287   SmallVector<Value, 4> remapOutputs;
1288   remapOutputs.reserve(oldMemRefRank);
1289 
1290   if (indexRemap &&
1291       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
1292     // Remapped indices.
1293     for (auto resultExpr : indexRemap.getResults()) {
1294       auto singleResMap = AffineMap::get(
1295           indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
1296       auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1297                                                 remapOperands);
1298       remapOutputs.push_back(afOp);
1299       affineApplyOps.push_back(afOp);
1300     }
1301   } else {
1302     // No remapping specified.
1303     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
1304   }
1305 
1306   SmallVector<Value, 4> newMapOperands;
1307   newMapOperands.reserve(newMemRefRank);
1308 
1309   // Prepend 'extraIndices' in 'newMapOperands'.
1310   for (Value extraIndex : extraIndices) {
1311     assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
1312            "invalid memory op index");
1313     newMapOperands.push_back(extraIndex);
1314   }
1315 
1316   // Append 'remapOutputs' to 'newMapOperands'.
1317   newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
1318 
1319   // Create new fully composed AffineMap for new op to be created.
1320   assert(newMapOperands.size() == newMemRefRank);
1321   auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
1322   fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
1323   newMap = simplifyAffineMap(newMap);
1324   canonicalizeMapAndOperands(&newMap, &newMapOperands);
1325   // Remove any affine.apply's that became dead as a result of composition.
1326   for (Value value : affineApplyOps)
1327     if (value.use_empty())
1328       value.getDefiningOp()->erase();
1329 
1330   OperationState state(op->getLoc(), op->getName());
1331   // Construct the new operation using this memref.
1332   state.operands.reserve(op->getNumOperands() + extraIndices.size());
1333   // Insert the non-memref operands.
1334   state.operands.append(op->operand_begin(),
1335                         op->operand_begin() + memRefOperandPos);
1336   // Insert the new memref value.
1337   state.operands.push_back(newMemRef);
1338 
1339   // Insert the new memref map operands.
1340   state.operands.append(newMapOperands.begin(), newMapOperands.end());
1341 
1342   // Insert the remaining operands unmodified.
1343   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
1344                             oldMapNumInputs,
1345                         op->operand_end());
1346 
1347   // Result types don't change. Both memrefs are of the same elemental type.
1348   state.types.reserve(op->getNumResults());
1349   for (auto result : op->getResults())
1350     state.types.push_back(result.getType());
1351 
1352   // Add attribute for 'newMap', other Attributes do not change.
1353   auto newMapAttr = AffineMapAttr::get(newMap);
1354   for (auto namedAttr : op->getAttrs()) {
1355     if (namedAttr.getName() == oldMapAttrPair.getName())
1356       state.attributes.push_back({namedAttr.getName(), newMapAttr});
1357     else
1358       state.attributes.push_back(namedAttr);
1359   }
1360 
1361   // Create the new operation.
1362   auto *repOp = builder.create(state);
1363   op->replaceAllUsesWith(repOp);
1364   op->erase();
1365 
1366   return success();
1367 }
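
// A minimal usage sketch of the single-op entry point above; `loadOp`,
// `oldMemRef`, and `newMemRef` are placeholder values, splitting a 1-d access
// on a memref<64xf32> into a (tile, offset) access on a rank-2 replacement
// such as memref<2x32xf32>:
//
//   MLIRContext *ctx = loadOp->getContext();
//   AffineExpr d0 = getAffineDimExpr(0, ctx);
//   AffineMap indexRemap =
//       AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
//                      {d0.floorDiv(32), d0 % 32}, ctx);
//   (void)replaceAllMemRefUsesWith(oldMemRef, newMemRef, /*op=*/loadOp,
//                                  /*extraIndices=*/{}, indexRemap,
//                                  /*extraOperands=*/{}, /*symbolOperands=*/{},
//                                  /*allowNonDereferencingOps=*/false);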
1368 
1369 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1370     Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
1371     AffineMap indexRemap, ArrayRef<Value> extraOperands,
1372     ArrayRef<Value> symbolOperands, Operation *domOpFilter,
1373     Operation *postDomOpFilter, bool allowNonDereferencingOps,
1374     bool replaceInDeallocOp) {
1375   unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
1376   (void)newMemRefRank; // unused in opt mode
1377   unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
1378   (void)oldMemRefRank;
1379   if (indexRemap) {
1380     assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1381            "symbol operand count mismatch");
1382     assert(indexRemap.getNumInputs() ==
1383            extraOperands.size() + oldMemRefRank + symbolOperands.size());
1384     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1385   } else {
1386     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1387   }
1388 
1389   // Assert same elemental type.
1390   assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
1391          cast<MemRefType>(newMemRef.getType()).getElementType());
1392 
1393   std::unique_ptr<DominanceInfo> domInfo;
1394   std::unique_ptr<PostDominanceInfo> postDomInfo;
1395   if (domOpFilter)
1396     domInfo = std::make_unique<DominanceInfo>(
1397         domOpFilter->getParentOfType<FunctionOpInterface>());
1398 
1399   if (postDomOpFilter)
1400     postDomInfo = std::make_unique<PostDominanceInfo>(
1401         postDomOpFilter->getParentOfType<FunctionOpInterface>());
1402 
1403   // Walk all uses of old memref; collect ops to perform replacement. We use a
1404   // DenseSet since an operation could potentially have multiple uses of a
1405   // memref (although rare), and the replacement later is going to erase ops.
1406   DenseSet<Operation *> opsToReplace;
1407   for (auto *op : oldMemRef.getUsers()) {
1408     // Skip this use if it's not dominated by domOpFilter.
1409     if (domOpFilter && !domInfo->dominates(domOpFilter, op))
1410       continue;
1411 
1412     // Skip this use if it's not post-dominated by postDomOpFilter.
1413     if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
1414       continue;
1415 
1416     // Skip dealloc's - no replacement is necessary, and a memref replacement
1417     // at other uses doesn't hurt these dealloc's.
1418     if (hasSingleEffect<MemoryEffects::Free>(op, oldMemRef) &&
1419         !replaceInDeallocOp)
1420       continue;
1421 
1422     // Check if the memref was used in a non-dereferencing context. It is fine
1423     // for the memref to be used in a non-dereferencing way outside of the
1424     // region where this replacement is happening.
1425     if (!isa<AffineMapAccessInterface>(*op)) {
1426       if (!allowNonDereferencingOps) {
1427         LLVM_DEBUG(llvm::dbgs()
1428                    << "Memref replacement failed: non-dereferencing memref op: \n"
1429                    << *op << '\n');
1430         return failure();
1431       }
1432       // Non-dereferencing ops with the MemRefsNormalizable trait are
1433       // supported for replacement.
1434       if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
1435         LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
1436                                    "memrefs normalizable trait: \n"
1437                                 << *op << '\n');
1438         return failure();
1439       }
1440     }
1441 
1442   // We'll first collect and then replace --- since replacement erases the op
1443   // that has the use, and that op could be postDomOpFilter or domOpFilter itself!
1444     opsToReplace.insert(op);
1445   }
1446 
1447   for (auto *op : opsToReplace) {
1448     if (failed(replaceAllMemRefUsesWith(
1449             oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
1450             symbolOperands, allowNonDereferencingOps)))
1451       llvm_unreachable("memref replacement guaranteed to succeed here");
1452   }
1453 
1454   return success();
1455 }
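
// For illustration, with no dominance filters and an `indexRemap` of
// (d0) -> (d0 floordiv 32, d0 mod 32), every dereferencing use is rewritten as:
//
//   Before: affine.load %old[%i] : memref<64xf32>
//   After:  affine.load %new[%i floordiv 32, %i mod 32] : memref<2x32xf32>
//
// Non-dereferencing uses abort the replacement unless
// `allowNonDereferencingOps` is set and the op has the MemRefsNormalizable
// trait; deallocs are left alone unless `replaceInDeallocOp` is set.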
1456 
1457 /// Given an operation, inserts one or more single result affine
1458 /// apply operations, results of which are exclusively used by this
1459 /// operation. The operands of these newly created affine apply ops are
1460 /// guaranteed to be loop iterators or terminal symbols of a function.
1461 ///
1462 /// Before
1463 ///
1464 /// affine.for %i = 0 to #map(%N)
1465 ///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1466 ///   "send"(%idx, %A, ...)
1467 ///   "compute"(%idx)
1468 ///
1469 /// After
1470 ///
1471 /// affine.for %i = 0 to #map(%N)
1472 ///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1473 ///   "send"(%idx, %A, ...)
1474 ///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
1475 ///   "compute"(%idx_)
1476 ///
1477 /// This allows applying different transformations on send and compute (e.g.,
1478 /// different shifts/delays).
1479 ///
1480 /// Leaves `sliceOps` empty either if none of opInst's operands were the result
1481 /// of an affine.apply (and thus there was no affine computation slice to
1482 /// create), or if all the affine.apply ops supplying operands to this opInst
1483 /// had no uses other than this opInst; otherwise returns the affine.apply
1484 /// operations created in the output argument `sliceOps`.
1485 void mlir::affine::createAffineComputationSlice(
1486     Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
1487   // Collect all operands that are results of affine apply ops.
1488   SmallVector<Value, 4> subOperands;
1489   subOperands.reserve(opInst->getNumOperands());
1490   for (auto operand : opInst->getOperands())
1491     if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
1492       subOperands.push_back(operand);
1493 
1494   // Gather sequence of AffineApplyOps reachable from 'subOperands'.
1495   SmallVector<Operation *, 4> affineApplyOps;
1496   getReachableAffineApplyOps(subOperands, affineApplyOps);
1497   // Skip transforming if there are no affine maps to compose.
1498   if (affineApplyOps.empty())
1499     return;
1500 
1501   // Check if all uses of the affine apply ops lie only in this op, in which
1502   // case there would be nothing to do.
1503   bool localized = true;
1504   for (auto *op : affineApplyOps) {
1505     for (auto result : op->getResults()) {
1506       for (auto *user : result.getUsers()) {
1507         if (user != opInst) {
1508           localized = false;
1509           break;
1510         }
1511       }
1512     }
1513   }
1514   if (localized)
1515     return;
1516 
1517   OpBuilder builder(opInst);
1518   SmallVector<Value, 4> composedOpOperands(subOperands);
1519   auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
1520   fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);
1521 
1522   // Create an affine.apply for each of the map results.
1523   sliceOps->reserve(composedMap.getNumResults());
1524   for (auto resultExpr : composedMap.getResults()) {
1525     auto singleResMap = AffineMap::get(composedMap.getNumDims(),
1526                                        composedMap.getNumSymbols(), resultExpr);
1527     sliceOps->push_back(builder.create<AffineApplyOp>(
1528         opInst->getLoc(), singleResMap, composedOpOperands));
1529   }
1530 
1531   // Construct the new operands that include the results from the composed
1532   // affine apply op above instead of existing ones (subOperands). So, they
1533   // differ from opInst's operands only for those operands in 'subOperands', for
1534   // which they will be replaced by the corresponding one from 'sliceOps'.
1535   SmallVector<Value, 4> newOperands(opInst->getOperands());
1536   for (Value &operand : newOperands) {
1537     // Replace the subOperands from among the new operands.
1538     unsigned j, f;
1539     for (j = 0, f = subOperands.size(); j < f; j++) {
1540       if (operand == subOperands[j])
1541         break;
1542     }
1543     if (j < subOperands.size())
1544       operand = (*sliceOps)[j];
1545   }
1546   for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++)
1547     opInst->setOperand(idx, newOperands[idx]);
1548 }
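
// A possible call pattern for the slice utility above (the op named here is
// hypothetical):
//
//   SmallVector<AffineApplyOp, 4> sliceOps;
//   createAffineComputationSlice(sendOp, &sliceOps);
//   // If `sliceOps` is non-empty, `sendOp` now consumes its own affine.apply
//   // results, so shifting/delaying them does not affect other users of the
//   // original index values.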
1549 
1550 /// Enum to set patterns of affine expr in tiled-layout map.
1551 /// TileFloorDiv: <dim expr> div <tile size>
1552 /// TileMod: <dim expr> mod <tile size>
1553 /// TileNone: None of the above
1554 /// Example:
1555 /// #tiled_2d_128x256 = affine_map<(d0, d1)
1556 ///            -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
1557 /// "d0 div 128" and "d1 div 256" ==> TileFloorDiv
1558 /// "d0 mod 128" and "d1 mod 256" ==> TileMod
1559 enum TileExprPattern { TileFloorDiv, TileMod, TileNone };
1560 
1561 /// Check if `map` is a tiled layout. In a tiled layout, specific k dimensions
1562 /// that are floordiv'ed by their respective tile sizes also appear in a mod
1563 /// with the same tile sizes, and no other expression involves those k
1564 /// dimensions. This function stores, in `tileSizePos`, a vector of tuples
1565 /// holding the tile-size AffineExpr and the positions of the corresponding
1566 /// `floordiv` and `mod`. If `map` is not a tiled layout, it is left empty.
1567 static LogicalResult getTileSizePos(
1568     AffineMap map,
1569     SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
1570   // Create `floordivExprs`, a vector of tuples holding the LHS and RHS of each
1571   // `floordiv` and its position in the `map` results.
1572   // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
1573   //                -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
1574   // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
1575   SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
1576   unsigned pos = 0;
1577   for (AffineExpr expr : map.getResults()) {
1578     if (expr.getKind() == AffineExprKind::FloorDiv) {
1579       AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
1580       if (isa<AffineConstantExpr>(binaryExpr.getRHS()))
1581         floordivExprs.emplace_back(
1582             std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
1583     }
1584     pos++;
1585   }
1586   // Not a tiled layout if `floordivExprs` is empty.
1587   if (floordivExprs.empty()) {
1588     tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
1589     return success();
1590   }
1591 
1592   // Check if LHS of `floordiv` is used in LHS of `mod`. If not used, `map` is
1593   // not tiled layout.
1594   for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
1595     AffineExpr floordivExprLHS = std::get<0>(fexpr);
1596     AffineExpr floordivExprRHS = std::get<1>(fexpr);
1597     unsigned floordivPos = std::get<2>(fexpr);
1598 
1599     // Walk each affine expr of the `map` results except `fexpr`, and check if
1600     // the LHS and RHS of `fexpr` are used as the LHS and RHS of a `mod`. If the
1601     // LHS of `fexpr` appears in any other expr, the map is not a tiled layout:
1602     //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
1603     //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
1604     //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod
1605     //   256)>
1606     bool found = false;
1607     pos = 0;
1608     for (AffineExpr expr : map.getResults()) {
1609       bool notTiled = false;
1610       if (pos != floordivPos) {
1611         expr.walk([&](AffineExpr e) {
1612           if (e == floordivExprLHS) {
1613             if (expr.getKind() == AffineExprKind::Mod) {
1614               AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
1615               // If LHS and RHS of `mod` are the same as those of the `floordiv`.
1616               if (floordivExprLHS == binaryExpr.getLHS() &&
1617                   floordivExprRHS == binaryExpr.getRHS()) {
1618                 // Save tile size (RHS of `mod`), and position of `floordiv` and
1619                 // `mod` if same expr with `mod` is not found yet.
1620                 if (!found) {
1621                   tileSizePos.emplace_back(
1622                       std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
1623                   found = true;
1624                 } else {
1625                   // Non tiled layout: there are multiple `mod`s with the same LHS.
1626                   // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1627                   // mod 256, d2 mod 256)>
1628                   notTiled = true;
1629                 }
1630               } else {
1631                 // Non tiled layout: RHS of `mod` differs from that of the `floordiv`.
1632                 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1633                 // mod 128)>
1634                 notTiled = true;
1635               }
1636             } else {
1637               // Non tiled layout: LHS is the same, but not `mod`.
1638               // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1639               // floordiv 256)>
1640               notTiled = true;
1641             }
1642           }
1643         });
1644       }
1645       if (notTiled) {
1646         tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
1647         return success();
1648       }
1649       pos++;
1650     }
1651   }
1652   return success();
1653 }
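
// Example of what the helper above computes (for illustration):
//   #tiled = affine_map<(d0, d1) -> (d0 floordiv 128, d1 floordiv 256,
//                                    d0 mod 128, d1 mod 256)>
// yields `tileSizePos` = {(128, /*floordiv pos=*/0, /*mod pos=*/2),
//                         (256, /*floordiv pos=*/1, /*mod pos=*/3)},
// while a non-tiled layout leaves `tileSizePos` empty.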
1654 
1655 /// Check if dimension `dim` of a memrefType with layout map `layoutMap` becomes
1656 /// dynamic after normalization. Map results that involve dynamic input
1657 /// dimensions become dynamic dimensions. Returns true if `dim` is such a
1658 /// dynamic dimension.
1659 ///
1660 /// Example:
1661 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1662 ///
1663 /// If d1 is dynamic dimension, 2nd and 3rd dimension of map output are dynamic.
1664 /// memref<4x?xf32, #map0>  ==>  memref<4x?x?xf32>
1665 static bool
1666 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
1667                              SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
1668   AffineExpr expr = layoutMap.getResults()[dim];
1669   // Check if affine expr of the dimension includes dynamic dimension of input
1670   // memrefType.
1671   MLIRContext *context = layoutMap.getContext();
1672   return expr
1673       .walk([&](AffineExpr e) {
1674         if (isa<AffineDimExpr>(e) &&
1675             llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
1676               return e == getAffineDimExpr(dim, context);
1677             }))
1678           return WalkResult::interrupt();
1679         return WalkResult::advance();
1680       })
1681       .wasInterrupted();
1682 }
1683 
1684 /// Create affine expr to calculate dimension size for a tiled-layout map.
1685 static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
1686                                                   TileExprPattern pat) {
1687   // Create map output for the patterns.
1688   // "floordiv <tile size>" ==> "ceildiv <tile size>"
1689   // "mod <tile size>" ==> "<tile size>"
1690   AffineExpr newMapOutput;
1691   AffineBinaryOpExpr binaryExpr = nullptr;
1692   switch (pat) {
1693   case TileExprPattern::TileMod:
1694     binaryExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
1695     newMapOutput = binaryExpr.getRHS();
1696     break;
1697   case TileExprPattern::TileFloorDiv:
1698     binaryExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
1699     newMapOutput = getAffineBinaryOpExpr(
1700         AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
1701     break;
1702   default:
1703     newMapOutput = oldMapOutput;
1704   }
1705   return newMapOutput;
1706 }
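
// For instance, with a tile size of 32 the rewrite performed above is:
//   TileFloorDiv:  d1 floordiv 32  ==>  d1 ceildiv 32
//   TileMod:       d1 mod 32       ==>  32
//   TileNone:      d0              ==>  d0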
1707 
1708 /// Create new maps to calculate each dimension size of `newMemRefType`, and
1709 /// create `newDynamicSizes` from them by using AffineApplyOp.
1710 ///
1711 /// Steps for normalizing dynamic memrefs for a tiled layout map
1712 /// Example:
1713 ///    #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1714 ///    %0 = dim %arg0, %c1 :memref<4x?xf32>
1715 ///    %1 = alloc(%0) : memref<4x?xf32, #map0>
1716 ///
1717 /// (Before this function)
1718 /// 1. Check if `map`(#map0) is a tiled layout using `getTileSizePos()`. Only a
1719 /// single layout map is supported.
1720 ///
1721 /// 2. Create normalized memrefType using `isNormalizedMemRefDynamicDim()`. It
1722 /// is memref<4x?x?xf32> in the above example.
1723 ///
1724 /// (In this function)
1725 /// 3. Create new maps to calculate each dimension of the normalized memrefType
1726 /// using `createDimSizeExprForTiledLayout()`. In the tiled layout, the
1727 /// dimension size can be calculated by replacing "floordiv <tile size>" with
1728 /// "ceildiv <tile size>" and "mod <tile size>" with "<tile size>".
1729 /// - New map in the above example
1730 ///   #map0 = affine_map<(d0, d1) -> (d0)>
1731 ///   #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
1732 ///   #map2 = affine_map<(d0, d1) -> (32)>
1733 ///
1734 /// 4. Create AffineApplyOp to apply the new maps. The output of AffineApplyOp
1735 /// is used in dynamicSizes of new AllocOp.
1736 ///   %0 = dim %arg0, %c1 : memref<4x?xf32>
1737 ///   %c4 = arith.constant 4 : index
1738 ///   %1 = affine.apply #map1(%c4, %0)
1739 ///   %2 = affine.apply #map2(%c4, %0)
1740 template <typename AllocLikeOp>
1741 static void createNewDynamicSizes(MemRefType oldMemRefType,
1742                                   MemRefType newMemRefType, AffineMap map,
1743                                   AllocLikeOp *allocOp, OpBuilder b,
1744                                   SmallVectorImpl<Value> &newDynamicSizes) {
1745   // Create new input for AffineApplyOp.
1746   SmallVector<Value, 4> inAffineApply;
1747   ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
1748   unsigned dynIdx = 0;
1749   for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
1750     if (oldMemRefShape[d] < 0) {
1751       // Use dynamicSizes of allocOp for dynamic dimension.
1752       inAffineApply.emplace_back(allocOp->getDynamicSizes()[dynIdx]);
1753       dynIdx++;
1754     } else {
1755       // Create ConstantOp for static dimension.
1756       auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
1757       inAffineApply.emplace_back(
1758           b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
1759     }
1760   }
1761 
1762   // Create a new map to calculate each dimension size of the new memref for
1763   // each original map result, but only for dynamic dimensions of `newMemRefType`.
1764   unsigned newDimIdx = 0;
1765   ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
1766   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1767   (void)getTileSizePos(map, tileSizePos);
1768   for (AffineExpr expr : map.getResults()) {
1769     if (newMemRefShape[newDimIdx] < 0) {
1770       // Create new maps to calculate each dimension size of new memref.
1771       enum TileExprPattern pat = TileExprPattern::TileNone;
1772       for (auto pos : tileSizePos) {
1773         if (newDimIdx == std::get<1>(pos))
1774           pat = TileExprPattern::TileFloorDiv;
1775         else if (newDimIdx == std::get<2>(pos))
1776           pat = TileExprPattern::TileMod;
1777       }
1778       AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
1779       AffineMap newMap =
1780           AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
1781       Value affineApp =
1782           b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply);
1783       newDynamicSizes.emplace_back(affineApp);
1784     }
1785     newDimIdx++;
1786   }
1787 }
1788 
1789 // TODO: Currently works for static memrefs with a single layout map.
1790 template <typename AllocLikeOp>
1791 LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
1792   MemRefType memrefType = allocOp->getType();
1793   OpBuilder b(*allocOp);
1794 
1795   // Fetch a new memref type after normalizing the old memref to have an
1796   // identity map layout.
1797   MemRefType newMemRefType = normalizeMemRefType(memrefType);
1798   if (newMemRefType == memrefType)
1799     // Either memrefType already had an identity map or the map couldn't be
1800     // transformed to an identity map.
1801     return failure();
1802 
1803   Value oldMemRef = allocOp->getResult();
1804 
1805   SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
1806   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1807   AllocLikeOp newAlloc;
1808   // Check if `layoutMap` is a tiled layout. Only a single layout map is
1809   // supported for normalizing dynamic memrefs.
1810   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1811   (void)getTileSizePos(layoutMap, tileSizePos);
1812   if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
1813     MemRefType oldMemRefType = cast<MemRefType>(oldMemRef.getType());
1814     SmallVector<Value, 4> newDynamicSizes;
1815     createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
1816                           newDynamicSizes);
1817     // Add the new dynamic sizes in new AllocOp.
1818     newAlloc =
1819         b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
1820                               allocOp->getAlignmentAttr());
1821   } else {
1822     newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
1823                                      allocOp->getAlignmentAttr());
1824   }
1825   // Replace all uses of the old memref.
1826   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
1827                                       /*extraIndices=*/{},
1828                                       /*indexRemap=*/layoutMap,
1829                                       /*extraOperands=*/{},
1830                                       /*symbolOperands=*/symbolOperands,
1831                                       /*domOpFilter=*/nullptr,
1832                                       /*postDomOpFilter=*/nullptr,
1833                                       /*allowNonDereferencingOps=*/true))) {
1834     // If it failed (due to escapes for example), bail out.
1835     newAlloc.erase();
1836     return failure();
1837   }
1838   // Replace any uses of the original alloc op and erase it. All remaining uses
1839   // have to be dealloc's; RAMUW above would've failed otherwise.
1840   assert(llvm::all_of(oldMemRef.getUsers(), [&](Operation *op) {
1841     return hasSingleEffect<MemoryEffects::Free>(op, oldMemRef);
1842   }));
1843   oldMemRef.replaceAllUsesWith(newAlloc);
1844   allocOp->erase();
1845   return success();
1846 }
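
// A sketch of the overall effect of normalizeMemRef on an alloc with a tiled
// layout (illustrative IR; #tile = affine_map<(d0, d1) -> (d0, d1 floordiv 32,
// d1 mod 32)>):
//
//   Before: %a = memref.alloc() : memref<4x64xf32, #tile>
//           %v = affine.load %a[%i, %j] : memref<4x64xf32, #tile>
//   After:  %a = memref.alloc() : memref<4x2x32xf32>
//           %v = affine.load %a[%i, %j floordiv 32, %j mod 32]
//                    : memref<4x2x32xf32>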
1847 
1848 template LogicalResult
1849 mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
1850 template LogicalResult
1851 mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
1852 
1853 MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
1854   unsigned rank = memrefType.getRank();
1855   if (rank == 0)
1856     return memrefType;
1857 
1858   if (memrefType.getLayout().isIdentity()) {
1859     // Either no map is associated with this memref or this memref has
1860     // a trivial (identity) map.
1861     return memrefType;
1862   }
1863   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1864   unsigned numSymbolicOperands = layoutMap.getNumSymbols();
1865 
1866   // We don't do any checks for one-to-one'ness; we assume that it is
1867   // one-to-one.
1868 
1869   // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
1870   // for now.
1871   // TODO: Normalize the other types of dynamic memrefs.
1872   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1873   (void)getTileSizePos(layoutMap, tileSizePos);
1874   if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
1875     return memrefType;
1876 
1877   // We have a single map that is not an identity map. Create a new memref
1878   // with the right shape and an identity layout map.
1879   ArrayRef<int64_t> shape = memrefType.getShape();
1880   // FlatAffineValueConstraints may later use the symbolic operands.
1881   FlatAffineValueConstraints fac(rank, numSymbolicOperands);
1882   SmallVector<unsigned, 4> memrefTypeDynDims;
1883   for (unsigned d = 0; d < rank; ++d) {
1884     // Use constraint system only in static dimensions.
1885     if (shape[d] > 0) {
1886       fac.addBound(BoundType::LB, d, 0);
1887       fac.addBound(BoundType::UB, d, shape[d] - 1);
1888     } else {
1889       memrefTypeDynDims.emplace_back(d);
1890     }
1891   }
1892   // We compose this map with the original index (logical) space to derive
1893   // the upper bounds for the new index space.
1894   unsigned newRank = layoutMap.getNumResults();
1895   if (failed(fac.composeMatchingMap(layoutMap)))
1896     return memrefType;
1897   // TODO: Handle semi-affine maps.
1898   // Project out the old data dimensions.
1899   fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());
1900   SmallVector<int64_t, 4> newShape(newRank);
1901   MLIRContext *context = memrefType.getContext();
1902   for (unsigned d = 0; d < newRank; ++d) {
1903     // Check if this dimension is dynamic.
1904     if (isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
1905       newShape[d] = ShapedType::kDynamic;
1906       continue;
1907     }
1908     // The lower bound for the shape is always zero.
1909     std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
1910     // For a static memref and an affine map with no symbols, this is
1911     // always bounded. However, when we have symbols, we may not be able to
1912     // obtain a constant upper bound. Also, mapping to a negative space is
1913     // invalid for normalization.
1914     if (!ubConst.has_value() || *ubConst < 0) {
1915       LLVM_DEBUG(llvm::dbgs()
1916                  << "can't normalize map due to unknown/invalid upper bound\n");
1917       return memrefType;
1918     }
1919     // The dimension is static; its size is the constant upper bound plus one.
1920     newShape[d] = *ubConst + 1;
1921   }
1922 
1923   // Create the new memref type after trivializing the old layout map.
1924   auto newMemRefType =
1925       MemRefType::Builder(memrefType)
1926           .setShape(newShape)
1927           .setLayout(AffineMapAttr::get(
1928               AffineMap::getMultiDimIdentityMap(newRank, context)));
1929   return newMemRefType;
1930 }
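
// Example type rewrites performed above (for illustration only):
//   memref<64xf32, affine_map<(d0) -> (d0 floordiv 4, d0 mod 4)>>
//       ==> memref<16x4xf32>
//   memref<4x?xf32, affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>>
//       ==> memref<4x?x?xf32>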
1931 
1932 DivModValue mlir::affine::getDivMod(OpBuilder &b, Location loc, Value lhs,
1933                                     Value rhs) {
1934   DivModValue result;
1935   AffineExpr d0, d1;
1936   bindDims(b.getContext(), d0, d1);
1937   result.quotient =
1938       affine::makeComposedAffineApply(b, loc, d0.floorDiv(d1), {lhs, rhs});
1939   result.remainder =
1940       affine::makeComposedAffineApply(b, loc, d0 % d1, {lhs, rhs});
1941   return result;
1942 }
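
// A minimal usage sketch, assuming `b`, `loc`, and index-typed values `idx` and
// `size` are in scope:
//
//   DivModValue dm = getDivMod(b, loc, idx, size);
//   Value tile   = dm.quotient;   // affine.apply of (d0, d1) -> (d0 floordiv d1)
//   Value offset = dm.remainder;  // affine.apply of (d0, d1) -> (d0 mod d1)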
1943 
1944 /// Create an affine map that computes `lhs` * `rhs`, composing in any other
1945 /// affine maps.
1946 static FailureOr<OpFoldResult> composedAffineMultiply(OpBuilder &b,
1947                                                       Location loc,
1948                                                       OpFoldResult lhs,
1949                                                       OpFoldResult rhs) {
1950   AffineExpr s0, s1;
1951   bindSymbols(b.getContext(), s0, s1);
1952   return makeComposedFoldedAffineApply(b, loc, s0 * s1, {lhs, rhs});
1953 }
1954 
1955 FailureOr<SmallVector<Value>>
1956 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
1957                                ArrayRef<Value> basis, bool hasOuterBound) {
1958   if (hasOuterBound)
1959     basis = basis.drop_front();
1960 
1961   // Note: the divisors are backwards due to the scan.
1962   SmallVector<Value> divisors;
1963   OpFoldResult basisProd = b.getIndexAttr(1);
1964   for (OpFoldResult basisElem : llvm::reverse(basis)) {
1965     FailureOr<OpFoldResult> nextProd =
1966         composedAffineMultiply(b, loc, basisElem, basisProd);
1967     if (failed(nextProd))
1968       return failure();
1969     basisProd = *nextProd;
1970     divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, basisProd));
1971   }
1972 
1973   SmallVector<Value> results;
1974   results.reserve(divisors.size() + 1);
1975   Value residual = linearIndex;
1976   for (Value divisor : llvm::reverse(divisors)) {
1977     DivModValue divMod = getDivMod(b, loc, residual, divisor);
1978     results.push_back(divMod.quotient);
1979     residual = divMod.remainder;
1980   }
1981   results.push_back(residual);
1982   return results;
1983 }
1984 
1985 FailureOr<SmallVector<Value>>
1986 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
1987                                ArrayRef<OpFoldResult> basis,
1988                                bool hasOuterBound) {
1989   if (hasOuterBound)
1990     basis = basis.drop_front();
1991 
1992   // Note: the divisors are backwards due to the scan.
1993   SmallVector<Value> divisors;
1994   OpFoldResult basisProd = b.getIndexAttr(1);
1995   for (OpFoldResult basisElem : llvm::reverse(basis)) {
1996     FailureOr<OpFoldResult> nextProd =
1997         composedAffineMultiply(b, loc, basisElem, basisProd);
1998     if (failed(nextProd))
1999       return failure();
2000     basisProd = *nextProd;
2001     divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, basisProd));
2002   }
2003 
2004   SmallVector<Value> results;
2005   results.reserve(divisors.size() + 1);
2006   Value residual = linearIndex;
2007   for (Value divisor : llvm::reverse(divisors)) {
2008     DivModValue divMod = getDivMod(b, loc, residual, divisor);
2009     results.push_back(divMod.quotient);
2010     residual = divMod.remainder;
2011   }
2012   results.push_back(residual);
2013   return results;
2014 }
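
// Worked example for the delinearization above (illustrative): with
// basis = [3, 4, 5] and hasOuterBound = true, the outer bound 3 is dropped and
// the reverse scan builds the divisors 5 and then 20, so a linear index %l is
// decomposed as:
//   %i0 = %l floordiv 20
//   %i1 = (%l mod 20) floordiv 5
//   %i2 = (%l mod 20) mod 5
// i.e. one result per basis element, outermost first.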
2015 
2016 OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
2017                                           ArrayRef<OpFoldResult> basis,
2018                                           ImplicitLocOpBuilder &builder) {
2019   return linearizeIndex(builder, builder.getLoc(), multiIndex, basis);
2020 }
2021 
2022 OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
2023                                           ArrayRef<OpFoldResult> multiIndex,
2024                                           ArrayRef<OpFoldResult> basis) {
2025   assert(multiIndex.size() == basis.size() ||
2026          multiIndex.size() == basis.size() + 1);
2027   SmallVector<AffineExpr> basisAffine;
2028 
2029   // Add a fake initial size in order to make the later index linearization
2030   // computations line up if an outer bound is not provided.
2031   if (multiIndex.size() == basis.size() + 1)
2032     basisAffine.push_back(getAffineConstantExpr(1, builder.getContext()));
2033 
2034   for (size_t i = 0; i < basis.size(); ++i) {
2035     basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
2036   }
2037 
2038   SmallVector<AffineExpr> stridesAffine = computeStrides(basisAffine);
2039   SmallVector<OpFoldResult> strides;
2040   strides.reserve(stridesAffine.size());
2041   llvm::transform(stridesAffine, std::back_inserter(strides),
2042                   [&builder, &basis, loc](AffineExpr strideExpr) {
2043                     return affine::makeComposedFoldedAffineApply(
2044                         builder, loc, strideExpr, basis);
2045                   });
2046 
2047   auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
2048       OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
2049   return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
2050                                                multiIndexAndStrides);
2051 }
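
// Worked example for linearizeIndex (illustrative): with basis = [4, 5] and
// multiIndex = [%i, %j], the computed strides are [5, 1], so the folded result
// is %i * 5 + %j. With an extra leading index and no outer bound, e.g.
// multiIndex = [%k, %i, %j], the fake unit size yields strides [20, 5, 1] and
// the result %k * 20 + %i * 5 + %j.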
2052