xref: /llvm-project/mlir/lib/IR/AffineExpr.cpp (revision 8272b6bd6146aab973ff7018ad642b99fde00904)
1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cmath>
10 #include <cstdint>
11 #include <limits>
12 #include <utility>
13 
14 #include "AffineExprDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineExprVisitor.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/IntegerSet.h"
19 #include "mlir/Support/TypeID.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/MathExtras.h"
22 #include <numeric>
23 #include <optional>
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 using llvm::divideCeilSigned;
29 using llvm::divideFloorSigned;
30 using llvm::divideSignedWouldOverflow;
31 using llvm::mod;
32 
33 MLIRContext *AffineExpr::getContext() const { return expr->context; }
34 
35 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
36 
37 /// Walk all of the AffineExprs in `e` in postorder. This is a private factory
38 /// method to help handle lambda walk functions. Users should use the regular
39 /// (non-static) `walk` method.
40 template <typename WalkRetTy>
41 WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
42                                  function_ref<WalkRetTy(AffineExpr)> callback) {
43   struct AffineExprWalker
44       : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
45     function_ref<WalkRetTy(AffineExpr)> callback;
46 
47     AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
48         : callback(callback) {}
49 
50     WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
51       return callback(expr);
52     }
53     WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
54       return callback(expr);
55     }
56     WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
57     WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
58   };
59 
60   return AffineExprWalker(callback).walkPostOrder(e);
61 }
62 // Explicitly instantiate for the two supported return types.
63 template void mlir::AffineExpr::walk(AffineExpr e,
64                                      function_ref<void(AffineExpr)> callback);
65 template WalkResult
66 mlir::AffineExpr::walk(AffineExpr e,
67                        function_ref<WalkResult(AffineExpr)> callback);
68 
69 // Dispatch affine expression construction based on kind.
70 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
71                                        AffineExpr rhs) {
72   if (kind == AffineExprKind::Add)
73     return lhs + rhs;
74   if (kind == AffineExprKind::Mul)
75     return lhs * rhs;
76   if (kind == AffineExprKind::FloorDiv)
77     return lhs.floorDiv(rhs);
78   if (kind == AffineExprKind::CeilDiv)
79     return lhs.ceilDiv(rhs);
80   if (kind == AffineExprKind::Mod)
81     return lhs % rhs;
82 
83   llvm_unreachable("unknown binary operation on affine expressions");
84 }
85 
86 /// This method substitutes any uses of dimensions and symbols (e.g.
87 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
88 AffineExpr
89 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
90                                   ArrayRef<AffineExpr> symReplacements) const {
91   switch (getKind()) {
92   case AffineExprKind::Constant:
93     return *this;
94   case AffineExprKind::DimId: {
95     unsigned dimId = llvm::cast<AffineDimExpr>(*this).getPosition();
96     if (dimId >= dimReplacements.size())
97       return *this;
98     return dimReplacements[dimId];
99   }
100   case AffineExprKind::SymbolId: {
101     unsigned symId = llvm::cast<AffineSymbolExpr>(*this).getPosition();
102     if (symId >= symReplacements.size())
103       return *this;
104     return symReplacements[symId];
105   }
106   case AffineExprKind::Add:
107   case AffineExprKind::Mul:
108   case AffineExprKind::FloorDiv:
109   case AffineExprKind::CeilDiv:
110   case AffineExprKind::Mod:
111     auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
112     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
113     auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
114     auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
115     if (newLHS == lhs && newRHS == rhs)
116       return *this;
117     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
118   }
119   llvm_unreachable("Unknown AffineExpr");
120 }
121 
122 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
123   return replaceDimsAndSymbols(dimReplacements, {});
124 }
125 
126 AffineExpr
127 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
128   return replaceDimsAndSymbols({}, symReplacements);
129 }
130 
131 /// Replace dims[offset ... numDims)
132 /// by dims[offset + shift ... shift + numDims).
133 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
134                                  unsigned offset) const {
135   SmallVector<AffineExpr, 4> dims;
136   for (unsigned idx = 0; idx < offset; ++idx)
137     dims.push_back(getAffineDimExpr(idx, getContext()));
138   for (unsigned idx = offset; idx < numDims; ++idx)
139     dims.push_back(getAffineDimExpr(idx + shift, getContext()));
140   return replaceDimsAndSymbols(dims, {});
141 }
142 
143 /// Replace symbols[offset ... numSymbols)
144 /// by symbols[offset + shift ... shift + numSymbols).
145 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
146                                     unsigned offset) const {
147   SmallVector<AffineExpr, 4> symbols;
148   for (unsigned idx = 0; idx < offset; ++idx)
149     symbols.push_back(getAffineSymbolExpr(idx, getContext()));
150   for (unsigned idx = offset; idx < numSymbols; ++idx)
151     symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
152   return replaceDimsAndSymbols({}, symbols);
153 }
154 
155 /// Sparse replace method. Return the modified expression tree.
156 AffineExpr
157 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
158   auto it = map.find(*this);
159   if (it != map.end())
160     return it->second;
161   switch (getKind()) {
162   default:
163     return *this;
164   case AffineExprKind::Add:
165   case AffineExprKind::Mul:
166   case AffineExprKind::FloorDiv:
167   case AffineExprKind::CeilDiv:
168   case AffineExprKind::Mod:
169     auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
170     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
171     auto newLHS = lhs.replace(map);
172     auto newRHS = rhs.replace(map);
173     if (newLHS == lhs && newRHS == rhs)
174       return *this;
175     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
176   }
177   llvm_unreachable("Unknown AffineExpr");
178 }
179 
/// Sparse replacement of a single expression: every occurrence of `expr` is
/// replaced by `replacement`. Implemented via the map-based overload with a
/// one-entry map.
AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
  DenseMap<AffineExpr, AffineExpr> singleEntry;
  singleEntry.insert({expr, replacement});
  return replace(singleEntry);
}
186 /// Returns true if this expression is made out of only symbols and
187 /// constants (no dimensional identifiers).
188 bool AffineExpr::isSymbolicOrConstant() const {
189   switch (getKind()) {
190   case AffineExprKind::Constant:
191     return true;
192   case AffineExprKind::DimId:
193     return false;
194   case AffineExprKind::SymbolId:
195     return true;
196 
197   case AffineExprKind::Add:
198   case AffineExprKind::Mul:
199   case AffineExprKind::FloorDiv:
200   case AffineExprKind::CeilDiv:
201   case AffineExprKind::Mod: {
202     auto expr = llvm::cast<AffineBinaryOpExpr>(*this);
203     return expr.getLHS().isSymbolicOrConstant() &&
204            expr.getRHS().isSymbolicOrConstant();
205   }
206   }
207   llvm_unreachable("Unknown AffineExpr");
208 }
209 
210 /// Returns true if this is a pure affine expression, i.e., multiplication,
211 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
212 bool AffineExpr::isPureAffine() const {
213   switch (getKind()) {
214   case AffineExprKind::SymbolId:
215   case AffineExprKind::DimId:
216   case AffineExprKind::Constant:
217     return true;
218   case AffineExprKind::Add: {
219     auto op = llvm::cast<AffineBinaryOpExpr>(*this);
220     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
221   }
222 
223   case AffineExprKind::Mul: {
224     // TODO: Canonicalize the constants in binary operators to the RHS when
225     // possible, allowing this to merge into the next case.
226     auto op = llvm::cast<AffineBinaryOpExpr>(*this);
227     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
228            (llvm::isa<AffineConstantExpr>(op.getLHS()) ||
229             llvm::isa<AffineConstantExpr>(op.getRHS()));
230   }
231   case AffineExprKind::FloorDiv:
232   case AffineExprKind::CeilDiv:
233   case AffineExprKind::Mod: {
234     auto op = llvm::cast<AffineBinaryOpExpr>(*this);
235     return op.getLHS().isPureAffine() &&
236            llvm::isa<AffineConstantExpr>(op.getRHS());
237   }
238   }
239   llvm_unreachable("Unknown AffineExpr");
240 }
241 
242 // Returns the greatest known integral divisor of this affine expression.
243 int64_t AffineExpr::getLargestKnownDivisor() const {
244   AffineBinaryOpExpr binExpr(nullptr);
245   switch (getKind()) {
246   case AffineExprKind::DimId:
247     [[fallthrough]];
248   case AffineExprKind::SymbolId:
249     return 1;
250   case AffineExprKind::CeilDiv:
251     [[fallthrough]];
252   case AffineExprKind::FloorDiv: {
253     // If the RHS is a constant and divides the known divisor on the LHS, the
254     // quotient is a known divisor of the expression.
255     binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
256     auto rhs = llvm::dyn_cast<AffineConstantExpr>(binExpr.getRHS());
257     // Leave alone undefined expressions.
258     if (rhs && rhs.getValue() != 0) {
259       int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
260       if (lhsDiv % rhs.getValue() == 0)
261         return std::abs(lhsDiv / rhs.getValue());
262     }
263     return 1;
264   }
265   case AffineExprKind::Constant:
266     return std::abs(llvm::cast<AffineConstantExpr>(*this).getValue());
267   case AffineExprKind::Mul: {
268     binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
269     return binExpr.getLHS().getLargestKnownDivisor() *
270            binExpr.getRHS().getLargestKnownDivisor();
271   }
272   case AffineExprKind::Add:
273     [[fallthrough]];
274   case AffineExprKind::Mod: {
275     binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
276     return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
277                     (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
278   }
279   }
280   llvm_unreachable("Unknown AffineExpr");
281 }
282 
283 bool AffineExpr::isMultipleOf(int64_t factor) const {
284   AffineBinaryOpExpr binExpr(nullptr);
285   uint64_t l, u;
286   switch (getKind()) {
287   case AffineExprKind::SymbolId:
288     [[fallthrough]];
289   case AffineExprKind::DimId:
290     return factor * factor == 1;
291   case AffineExprKind::Constant:
292     return llvm::cast<AffineConstantExpr>(*this).getValue() % factor == 0;
293   case AffineExprKind::Mul: {
294     binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
295     // It's probably not worth optimizing this further (to not traverse the
296     // whole sub-tree under - it that would require a version of isMultipleOf
297     // that on a 'false' return also returns the largest known divisor).
298     return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
299            (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
300            (l * u) % factor == 0;
301   }
302   case AffineExprKind::Add:
303   case AffineExprKind::FloorDiv:
304   case AffineExprKind::CeilDiv:
305   case AffineExprKind::Mod: {
306     binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
307     return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
308                     (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
309                factor ==
310            0;
311   }
312   }
313   llvm_unreachable("Unknown AffineExpr");
314 }
315 
316 bool AffineExpr::isFunctionOfDim(unsigned position) const {
317   if (getKind() == AffineExprKind::DimId) {
318     return *this == mlir::getAffineDimExpr(position, getContext());
319   }
320   if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
321     return expr.getLHS().isFunctionOfDim(position) ||
322            expr.getRHS().isFunctionOfDim(position);
323   }
324   return false;
325 }
326 
327 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
328   if (getKind() == AffineExprKind::SymbolId) {
329     return *this == mlir::getAffineSymbolExpr(position, getContext());
330   }
331   if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
332     return expr.getLHS().isFunctionOfSymbol(position) ||
333            expr.getRHS().isFunctionOfSymbol(position);
334   }
335   return false;
336 }
337 
338 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
339     : AffineExpr(ptr) {}
340 AffineExpr AffineBinaryOpExpr::getLHS() const {
341   return static_cast<ImplType *>(expr)->lhs;
342 }
343 AffineExpr AffineBinaryOpExpr::getRHS() const {
344   return static_cast<ImplType *>(expr)->rhs;
345 }
346 
347 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
348 unsigned AffineDimExpr::getPosition() const {
349   return static_cast<ImplType *>(expr)->position;
350 }
351 
352 /// Returns true if the expression is divisible by the given symbol with
353 /// position `symbolPos`. The argument `opKind` specifies here what kind of
354 /// division or mod operation called this division. It helps in implementing the
355 /// commutative property of the floordiv and ceildiv operations. If the argument
356 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
357 /// operation, then the commutative property can be used otherwise, the floordiv
358 /// operation is not divisible. The same argument holds for ceildiv operation.
359 static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
360                                         AffineExprKind opKind,
361                                         bool fromMul = false) {
362   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
363   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
364           opKind == AffineExprKind::CeilDiv) &&
365          "unexpected opKind");
366   switch (expr.getKind()) {
367   case AffineExprKind::Constant:
368     return cast<AffineConstantExpr>(expr).getValue() == 0;
369   case AffineExprKind::DimId:
370     return false;
371   case AffineExprKind::SymbolId:
372     return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
373   // Checks divisibility by the given symbol for both operands.
374   case AffineExprKind::Add: {
375     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
376     return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
377                                        opKind) &&
378            canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
379   }
380   // Checks divisibility by the given symbol for both operands. Consider the
381   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
382   // this is a division by s1 and both the operands of modulo are divisible by
383   // s1 but it is not divisible by s1 always. The third argument is
384   // `AffineExprKind::Mod` for this reason.
385   case AffineExprKind::Mod: {
386     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
387     return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
388                                        AffineExprKind::Mod) &&
389            canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
390                                        AffineExprKind::Mod);
391   }
392   // Checks if any of the operand divisible by the given symbol.
393   case AffineExprKind::Mul: {
394     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
395     return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
396                                        true) ||
397            canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
398                                        true);
399   }
400   // Floordiv and ceildiv are divisible by the given symbol when the first
401   // operand is divisible, and the affine expression kind of the argument expr
402   // is same as the argument `opKind`. This can be inferred from commutative
403   // property of floordiv and ceildiv operations and are as follow:
404   // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
405   // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
406   // It will fail 1.if operations are not same. For example:
407   // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
408   // multiplication operation in the expression. For example:
409   //  (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
410   case AffineExprKind::FloorDiv:
411   case AffineExprKind::CeilDiv: {
412     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
413     if (opKind != expr.getKind())
414       return false;
415     if (fromMul)
416       return false;
417     return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
418                                        expr.getKind());
419   }
420   }
421   llvm_unreachable("Unknown AffineExpr");
422 }
423 
424 /// Divides the given expression by the given symbol at position `symbolPos`. It
425 /// considers the divisibility condition is checked before calling itself. A
426 /// null expression is returned whenever the divisibility condition fails.
427 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
428                                  AffineExprKind opKind) {
429   // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
430   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
431           opKind == AffineExprKind::CeilDiv) &&
432          "unexpected opKind");
433   switch (expr.getKind()) {
434   case AffineExprKind::Constant:
435     if (cast<AffineConstantExpr>(expr).getValue() != 0)
436       return nullptr;
437     return getAffineConstantExpr(0, expr.getContext());
438   case AffineExprKind::DimId:
439     return nullptr;
440   case AffineExprKind::SymbolId:
441     return getAffineConstantExpr(1, expr.getContext());
442   // Dividing both operands by the given symbol.
443   case AffineExprKind::Add: {
444     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
445     return getAffineBinaryOpExpr(
446         expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
447         symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
448   }
449   // Dividing both operands by the given symbol.
450   case AffineExprKind::Mod: {
451     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
452     return getAffineBinaryOpExpr(
453         expr.getKind(),
454         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
455         symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
456   }
457   // Dividing any of the operand by the given symbol.
458   case AffineExprKind::Mul: {
459     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
460     if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
461       return binaryExpr.getLHS() *
462              symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
463     return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
464            binaryExpr.getRHS();
465   }
466   // Dividing first operand only by the given symbol.
467   case AffineExprKind::FloorDiv:
468   case AffineExprKind::CeilDiv: {
469     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
470     return getAffineBinaryOpExpr(
471         expr.getKind(),
472         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
473         binaryExpr.getRHS());
474   }
475   }
476   llvm_unreachable("Unknown AffineExpr");
477 }
478 
479 /// Populate `result` with all summand operands of given (potentially nested)
480 /// addition. If the given expression is not an addition, just populate the
481 /// expression itself.
482 /// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
483 static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
484   auto addExpr = dyn_cast<AffineBinaryOpExpr>(expr);
485   if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
486     result.push_back(expr);
487     return;
488   }
489   getSummandExprs(addExpr.getLHS(), result);
490   getSummandExprs(addExpr.getRHS(), result);
491 }
492 
493 /// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
494 /// If so, also return the non-negated expression via `expr`.
495 static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
496   auto mulExpr = dyn_cast<AffineBinaryOpExpr>(candidate);
497   if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
498     return false;
499   if (auto lhs = dyn_cast<AffineConstantExpr>(mulExpr.getLHS())) {
500     if (lhs.getValue() == -1) {
501       expr = mulExpr.getRHS();
502       return true;
503     }
504   }
505   if (auto rhs = dyn_cast<AffineConstantExpr>(mulExpr.getRHS())) {
506     if (rhs.getValue() == -1) {
507       expr = mulExpr.getLHS();
508       return true;
509     }
510   }
511   return false;
512 }
513 
514 /// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
515 /// the fact that `lhs` contains another modulo expression that ensures that
516 /// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
517 /// after loop peeling.
518 ///
519 /// Example: lhs = ub - ub % step
520 ///          rhs = step
521 ///       => (ub - ub % step) % step is guaranteed to evaluate to 0.
522 static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
523                                   unsigned numDims, unsigned numSymbols) {
524   // TODO: Try to unify this function with `getBoundForAffineExpr`.
525   // Collect all summands in lhs.
526   SmallVector<AffineExpr> summands;
527   getSummandExprs(lhs, summands);
528   // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
529   // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
530   for (int64_t i = 0, e = summands.size(); i < e; ++i) {
531     AffineExpr current = summands[i];
532     AffineExpr beforeNegation;
533     if (!isNegatedAffineExpr(current, beforeNegation))
534       continue;
535     AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(beforeNegation);
536     if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
537       continue;
538     if (innerMod.getRHS() != rhs)
539       continue;
540     // Sum all remaining summands and subtract x. If that expression can be
541     // simplified to zero, then the remaining summands and x are equal.
542     AffineExpr diff = getAffineConstantExpr(0, lhs.getContext());
543     for (int64_t j = 0; j < e; ++j)
544       if (i != j)
545         diff = diff + summands[j];
546     diff = diff - innerMod.getLHS();
547     diff = simplifyAffineExpr(diff, numDims, numSymbols);
548     auto constExpr = dyn_cast<AffineConstantExpr>(diff);
549     if (constExpr && constExpr.getValue() == 0)
550       return true;
551   }
552   return false;
553 }
554 
555 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
556 /// operations when the second operand simplifies to a symbol and the first
557 /// operand is divisible by that symbol. It can be applied to any semi-affine
558 /// expression. Returned expression can either be a semi-affine or pure affine
559 /// expression.
560 static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
561                                      unsigned numSymbols) {
562   switch (expr.getKind()) {
563   case AffineExprKind::Constant:
564   case AffineExprKind::DimId:
565   case AffineExprKind::SymbolId:
566     return expr;
567   case AffineExprKind::Add:
568   case AffineExprKind::Mul: {
569     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
570     return getAffineBinaryOpExpr(
571         expr.getKind(),
572         simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols),
573         simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
574   }
575   // Check if the simplification of the second operand is a symbol, and the
576   // first operand is divisible by it. If the operation is a modulo, a constant
577   // zero expression is returned. In the case of floordiv and ceildiv, the
578   // symbol from the simplification of the second operand divides the first
579   // operand. Otherwise, simplification is not possible.
580   case AffineExprKind::FloorDiv:
581   case AffineExprKind::CeilDiv:
582   case AffineExprKind::Mod: {
583     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
584     AffineExpr sLHS =
585         simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols);
586     AffineExpr sRHS =
587         simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols);
588     if (isModOfModSubtraction(sLHS, sRHS, numDims, numSymbols))
589       return getAffineConstantExpr(0, expr.getContext());
590     AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
591         simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
592     if (!symbolExpr)
593       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
594     unsigned symbolPos = symbolExpr.getPosition();
595     if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
596                                      expr.getKind()))
597       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
598     if (expr.getKind() == AffineExprKind::Mod)
599       return getAffineConstantExpr(0, expr.getContext());
600     return symbolicDivide(sLHS, symbolPos, expr.getKind());
601   }
602   }
603   llvm_unreachable("Unknown AffineExpr");
604 }
605 
606 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
607                                        MLIRContext *context) {
608   auto assignCtx = [context](AffineDimExprStorage *storage) {
609     storage->context = context;
610   };
611 
612   StorageUniquer &uniquer = context->getAffineUniquer();
613   return uniquer.get<AffineDimExprStorage>(
614       assignCtx, static_cast<unsigned>(kind), position);
615 }
616 
617 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
618   return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
619 }
620 
621 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
622     : AffineExpr(ptr) {}
623 unsigned AffineSymbolExpr::getPosition() const {
624   return static_cast<ImplType *>(expr)->position;
625 }
626 
627 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
628   return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
629 }
630 
631 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
632     : AffineExpr(ptr) {}
633 int64_t AffineConstantExpr::getValue() const {
634   return static_cast<ImplType *>(expr)->constant;
635 }
636 
637 bool AffineExpr::operator==(int64_t v) const {
638   return *this == getAffineConstantExpr(v, getContext());
639 }
640 
641 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
642   auto assignCtx = [context](AffineConstantExprStorage *storage) {
643     storage->context = context;
644   };
645 
646   StorageUniquer &uniquer = context->getAffineUniquer();
647   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
648 }
649 
650 SmallVector<AffineExpr>
651 mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
652                              MLIRContext *context) {
653   return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) {
654     return getAffineConstantExpr(constant, context);
655   }));
656 }
657 
658 /// Simplify add expression. Return nullptr if it can't be simplified.
659 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
660   auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
661   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
662   // Fold if both LHS, RHS are a constant and the sum does not overflow.
663   if (lhsConst && rhsConst) {
664     int64_t sum;
665     if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
666       return nullptr;
667     }
668     return getAffineConstantExpr(sum, lhs.getContext());
669   }
670 
671   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
672   // If only one of them is a symbolic expressions, make it the RHS.
673   if (isa<AffineConstantExpr>(lhs) ||
674       (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
675     return rhs + lhs;
676   }
677 
678   // At this point, if there was a constant, it would be on the right.
679 
680   // Addition with a zero is a noop, return the other input.
681   if (rhsConst) {
682     if (rhsConst.getValue() == 0)
683       return lhs;
684   }
685   // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
686   auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
687   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
688     if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
689       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
690   }
691 
692   // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
693   // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
694   // respective multiplicands.
695   std::optional<int64_t> rLhsConst, rRhsConst;
696   AffineExpr firstExpr, secondExpr;
697   AffineConstantExpr rLhsConstExpr;
698   auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lhs);
699   if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
700       (rLhsConstExpr = dyn_cast<AffineConstantExpr>(lBinOpExpr.getRHS()))) {
701     rLhsConst = rLhsConstExpr.getValue();
702     firstExpr = lBinOpExpr.getLHS();
703   } else {
704     rLhsConst = 1;
705     firstExpr = lhs;
706   }
707 
708   auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(rhs);
709   AffineConstantExpr rRhsConstExpr;
710   if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
711       (rRhsConstExpr = dyn_cast<AffineConstantExpr>(rBinOpExpr.getRHS()))) {
712     rRhsConst = rRhsConstExpr.getValue();
713     secondExpr = rBinOpExpr.getLHS();
714   } else {
715     rRhsConst = 1;
716     secondExpr = rhs;
717   }
718 
719   if (rLhsConst && rRhsConst && firstExpr == secondExpr)
720     return getAffineBinaryOpExpr(
721         AffineExprKind::Mul, firstExpr,
722         getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
723 
724   // When doing successive additions, bring constant to the right: turn (d0 + 2)
725   // + d1 into (d0 + d1) + 2.
726   if (lBin && lBin.getKind() == AffineExprKind::Add) {
727     if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
728       return lBin.getLHS() + rhs + lrhs;
729     }
730   }
731 
732   // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
733   // q may be a constant or symbolic expression. This leads to a much more
734   // efficient form when 'c' is a power of two, and in general a more compact
735   // and readable form.
736 
737   // Process '(expr floordiv c) * (-c)'.
738   if (!rBinOpExpr)
739     return nullptr;
740 
741   auto lrhs = rBinOpExpr.getLHS();
742   auto rrhs = rBinOpExpr.getRHS();
743 
744   AffineExpr llrhs, rlrhs;
745 
746   // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
747   // symbolic expression.
748   auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
749   // Check rrhsConstOpExpr = -1.
750   auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rrhs);
751   if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
752       lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
753     // Check llrhs = expr floordiv q.
754     llrhs = lrhsBinOpExpr.getLHS();
755     // Check rlrhs = q.
756     rlrhs = lrhsBinOpExpr.getRHS();
757     auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(llrhs);
758     if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
759       return nullptr;
760     if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
761       return lhs % rlrhs;
762   }
763 
764   // Process lrhs, which is 'expr floordiv c'.
765   // expr + (expr // c * -c) = expr % c
766   AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
767   if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul ||
768       lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
769     return nullptr;
770 
771   llrhs = lrBinOpExpr.getLHS();
772   rlrhs = lrBinOpExpr.getRHS();
773   auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs);
774   // We don't support modulo with a negative RHS.
775   bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
776 
777   if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
778     return lhs % rlrhs;
779   }
780   return nullptr;
781 }
782 
783 AffineExpr AffineExpr::operator+(int64_t v) const {
784   return *this + getAffineConstantExpr(v, getContext());
785 }
786 AffineExpr AffineExpr::operator+(AffineExpr other) const {
787   if (auto simplified = simplifyAdd(*this, other))
788     return simplified;
789 
790   StorageUniquer &uniquer = getContext()->getAffineUniquer();
791   return uniquer.get<AffineBinaryOpExprStorage>(
792       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
793 }
794 
795 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
796 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
797   auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
798   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
799 
800   if (lhsConst && rhsConst) {
801     int64_t product;
802     if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
803       return nullptr;
804     }
805     return getAffineConstantExpr(product, lhs.getContext());
806   }
807 
808   if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
809     return nullptr;
810 
811   // Canonicalize the mul expression so that the constant/symbolic term is the
812   // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
813   // constant. (Note that a constant is trivially symbolic).
814   if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
815     // At least one of them has to be symbolic.
816     return rhs * lhs;
817   }
818 
819   // At this point, if there was a constant, it would be on the right.
820 
821   // Multiplication with a one is a noop, return the other input.
822   if (rhsConst) {
823     if (rhsConst.getValue() == 1)
824       return lhs;
825     // Multiplication with zero.
826     if (rhsConst.getValue() == 0)
827       return rhsConst;
828   }
829 
830   // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
831   auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
832   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
833     if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
834       return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
835   }
836 
837   // When doing successive multiplication, bring constant to the right: turn (d0
838   // * 2) * d1 into (d0 * d1) * 2.
839   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
840     if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
841       return (lBin.getLHS() * rhs) * lrhs;
842     }
843   }
844 
845   return nullptr;
846 }
847 
848 AffineExpr AffineExpr::operator*(int64_t v) const {
849   return *this * getAffineConstantExpr(v, getContext());
850 }
851 AffineExpr AffineExpr::operator*(AffineExpr other) const {
852   if (auto simplified = simplifyMul(*this, other))
853     return simplified;
854 
855   StorageUniquer &uniquer = getContext()->getAffineUniquer();
856   return uniquer.get<AffineBinaryOpExprStorage>(
857       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
858 }
859 
860 // Unary minus, delegate to operator*.
861 AffineExpr AffineExpr::operator-() const {
862   return *this * getAffineConstantExpr(-1, getContext());
863 }
864 
865 // Delegate to operator+.
866 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
867 AffineExpr AffineExpr::operator-(AffineExpr other) const {
868   return *this + (-other);
869 }
870 
871 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
872   auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
873   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
874 
875   if (!rhsConst || rhsConst.getValue() == 0)
876     return nullptr;
877 
878   if (lhsConst) {
879     if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
880       return nullptr;
881     return getAffineConstantExpr(
882         divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
883         lhs.getContext());
884   }
885 
886   // Fold floordiv of a multiply with a constant that is a multiple of the
887   // divisor. Eg: (i * 128) floordiv 64 = i * 2.
888   if (rhsConst == 1)
889     return lhs;
890 
891   // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
892   // multiple of `rhsConst`.
893   auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
894   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
895     if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
896       // `rhsConst` is known to be a nonzero constant.
897       if (lrhs.getValue() % rhsConst.getValue() == 0)
898         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
899     }
900   }
901 
902   // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
903   // known to be a multiple of divConst.
904   if (lBin && lBin.getKind() == AffineExprKind::Add) {
905     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
906     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
907     // rhsConst is known to be a nonzero constant.
908     if (llhsDiv % rhsConst.getValue() == 0 ||
909         lrhsDiv % rhsConst.getValue() == 0)
910       return lBin.getLHS().floorDiv(rhsConst.getValue()) +
911              lBin.getRHS().floorDiv(rhsConst.getValue());
912   }
913 
914   return nullptr;
915 }
916 
917 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
918   return floorDiv(getAffineConstantExpr(v, getContext()));
919 }
920 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
921   if (auto simplified = simplifyFloorDiv(*this, other))
922     return simplified;
923 
924   StorageUniquer &uniquer = getContext()->getAffineUniquer();
925   return uniquer.get<AffineBinaryOpExprStorage>(
926       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
927       other);
928 }
929 
930 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
931   auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
932   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
933 
934   if (!rhsConst || rhsConst.getValue() == 0)
935     return nullptr;
936 
937   if (lhsConst) {
938     if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
939       return nullptr;
940     return getAffineConstantExpr(
941         divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
942         lhs.getContext());
943   }
944 
945   // Fold ceildiv of a multiply with a constant that is a multiple of the
946   // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
947   if (rhsConst.getValue() == 1)
948     return lhs;
949 
950   // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
951   // multiple of `rhsConst`.
952   auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
953   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
954     if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
955       // `rhsConst` is known to be a nonzero constant.
956       if (lrhs.getValue() % rhsConst.getValue() == 0)
957         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
958     }
959   }
960 
961   return nullptr;
962 }
963 
964 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
965   return ceilDiv(getAffineConstantExpr(v, getContext()));
966 }
967 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
968   if (auto simplified = simplifyCeilDiv(*this, other))
969     return simplified;
970 
971   StorageUniquer &uniquer = getContext()->getAffineUniquer();
972   return uniquer.get<AffineBinaryOpExprStorage>(
973       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
974       other);
975 }
976 
977 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
978   auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
979   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
980 
981   // mod w.r.t zero or negative numbers is undefined and preserved as is.
982   if (!rhsConst || rhsConst.getValue() < 1)
983     return nullptr;
984 
985   if (lhsConst) {
986     // mod never overflows.
987     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
988                                  lhs.getContext());
989   }
990 
991   // Fold modulo of an expression that is known to be a multiple of a constant
992   // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
993   // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
994   if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
995     return getAffineConstantExpr(0, lhs.getContext());
996 
997   // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
998   // known to be a multiple of divConst.
999   auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
1000   if (lBin && lBin.getKind() == AffineExprKind::Add) {
1001     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
1002     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
1003     // rhsConst is known to be a positive constant.
1004     if (llhsDiv % rhsConst.getValue() == 0)
1005       return lBin.getRHS() % rhsConst.getValue();
1006     if (lrhsDiv % rhsConst.getValue() == 0)
1007       return lBin.getLHS() % rhsConst.getValue();
1008   }
1009 
1010   // Simplify (e % a) % b to e % b when b evenly divides a
1011   if (lBin && lBin.getKind() == AffineExprKind::Mod) {
1012     auto intermediate = dyn_cast<AffineConstantExpr>(lBin.getRHS());
1013     if (intermediate && intermediate.getValue() >= 1 &&
1014         mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
1015       return lBin.getLHS() % rhsConst.getValue();
1016     }
1017   }
1018 
1019   return nullptr;
1020 }
1021 
1022 AffineExpr AffineExpr::operator%(uint64_t v) const {
1023   return *this % getAffineConstantExpr(v, getContext());
1024 }
1025 AffineExpr AffineExpr::operator%(AffineExpr other) const {
1026   if (auto simplified = simplifyMod(*this, other))
1027     return simplified;
1028 
1029   StorageUniquer &uniquer = getContext()->getAffineUniquer();
1030   return uniquer.get<AffineBinaryOpExprStorage>(
1031       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
1032 }
1033 
1034 AffineExpr AffineExpr::compose(AffineMap map) const {
1035   SmallVector<AffineExpr, 8> dimReplacements(map.getResults());
1036   return replaceDimsAndSymbols(dimReplacements, {});
1037 }
/// Streams `expr` to `os` via AffineExpr::print and returns the stream so that
/// insertions can be chained.
raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
  expr.print(os);
  return os;
}
1042 
1043 /// Constructs an affine expression from a flat ArrayRef. If there are local
1044 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
1045 /// products expression, `localExprs` is expected to have the AffineExpr
1046 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1047 /// in the format [dims, symbols, locals, constant term].
1048 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1049                                            unsigned numDims,
1050                                            unsigned numSymbols,
1051                                            ArrayRef<AffineExpr> localExprs,
1052                                            MLIRContext *context) {
1053   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1054   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1055          "unexpected number of local expressions");
1056 
1057   auto expr = getAffineConstantExpr(0, context);
1058   // Dimensions and symbols.
1059   for (unsigned j = 0; j < numDims + numSymbols; j++) {
1060     if (flatExprs[j] == 0)
1061       continue;
1062     auto id = j < numDims ? getAffineDimExpr(j, context)
1063                           : getAffineSymbolExpr(j - numDims, context);
1064     expr = expr + id * flatExprs[j];
1065   }
1066 
1067   // Local identifiers.
1068   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1069        j++) {
1070     if (flatExprs[j] == 0)
1071       continue;
1072     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1073     expr = expr + term;
1074   }
1075 
1076   // Constant term.
1077   int64_t constTerm = flatExprs[flatExprs.size() - 1];
1078   if (constTerm != 0)
1079     expr = expr + constTerm;
1080   return expr;
1081 }
1082 
1083 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
1084 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
1085 /// of products expression, `localExprs` is expected to have the AffineExprs for
1086 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1087 /// the format [dims, symbols, locals, constant term]. The semi-affine
1088 /// expression is constructed in the sorted order of dimension and symbol
1089 /// position numbers. Note:  local expressions/ids are used for mod, div as well
1090 /// as symbolic RHS terms for terms that are not pure affine.
1091 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1092                                                 unsigned numDims,
1093                                                 unsigned numSymbols,
1094                                                 ArrayRef<AffineExpr> localExprs,
1095                                                 MLIRContext *context) {
1096   assert(!flatExprs.empty() && "flatExprs cannot be empty");
1097 
1098   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1099   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1100          "unexpected number of local expressions");
1101 
1102   AffineExpr expr = getAffineConstantExpr(0, context);
1103 
1104   // We design indices as a pair which help us present the semi-affine map as
1105   // sum of product where terms are sorted based on dimension or symbol
1106   // position: <keyA, keyB> for expressions of the form dimension * symbol,
1107   // where keyA is the position number of the dimension and keyB is the
1108   // position number of the symbol. For dimensional expressions we set the index
1109   // as (position number of the dimension, -1), as we want dimensional
1110   // expressions to appear before symbolic and product of dimensional and
1111   // symbolic expressions having the dimension with the same position number.
1112   // For symbolic expression set the index as (position number of the symbol,
1113   // maximum of last dimension and symbol position) number. For example, we want
1114   // the expression we are constructing to look something like: d0 + d0 * s0 +
1115   // s0 + d1*s1 + s1.
1116 
1117   // Stores the affine expression corresponding to a given index.
1118   DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
1119   // Stores the constant coefficient value corresponding to a given
1120   // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1121   DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
1122   // Stores the indices as defined above, and later sorted to produce
1123   // the semi-affine expression in the desired form.
1124   SmallVector<std::pair<unsigned, signed>, 8> indices;
1125 
1126   // Example: expression = d0 + d0 * s0 + 2 * s0.
1127   // indices = [{0,-1}, {0, 0}, {0, 1}]
1128   // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1129   // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1130 
1131   // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1132   auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1133                       AffineExpr expr) {
1134     assert(!llvm::is_contained(indices, index) &&
1135            "Key is already present in indices vector and overwriting will "
1136            "happen in `indexToExprMap` and `coefficients`!");
1137 
1138     indices.push_back(index);
1139     coefficients.insert({index, coefficient});
1140     indexToExprMap.insert({index, expr});
1141   };
1142 
1143   // Design indices for dimensional or symbolic terms, and store the indices,
1144   // constant coefficient corresponding to the indices in `coefficients` map,
1145   // and affine expression corresponding to indices in `indexToExprMap` map.
1146 
1147   // Ensure we do not have duplicate keys in `indexToExpr` map.
1148   unsigned offsetSym = 0;
1149   signed offsetDim = -1;
1150   for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1151     if (flatExprs[j] == 0)
1152       continue;
1153     // For symbolic expression set the index as <position number
1154     // of the symbol, max(dimCount, symCount)> number,
1155     // as we want symbolic expressions with the same positional number to
1156     // appear after dimensional expressions having the same positional number.
1157     std::pair<unsigned, signed> indexEntry(
1158         j - numDims, std::max(numDims, numSymbols) + offsetSym++);
1159     addEntry(indexEntry, flatExprs[j],
1160              getAffineSymbolExpr(j - numDims, context));
1161   }
1162 
1163   // Denotes semi-affine product, modulo or division terms, which has been added
1164   // to the `indexToExpr` map.
1165   SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1166                                   false);
1167   unsigned lhsPos, rhsPos;
1168   // Construct indices for product terms involving dimension, symbol or constant
1169   // as lhs/rhs, and store the indices, constant coefficient corresponding to
1170   // the indices in `coefficients` map, and affine expression corresponding to
1171   // in indices in `indexToExprMap` map.
1172   for (const auto &it : llvm::enumerate(localExprs)) {
1173     AffineExpr expr = it.value();
1174     if (flatExprs[numDims + numSymbols + it.index()] == 0)
1175       continue;
1176     AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
1177     AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
1178     if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
1179           (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
1180            isa<AffineConstantExpr>(rhs)))) {
1181       continue;
1182     }
1183     if (isa<AffineConstantExpr>(rhs)) {
1184       // For product/modulo/division expressions, when rhs of modulo/division
1185       // expression is constant, we put 0 in place of keyB, because we want
1186       // them to appear earlier in the semi-affine expression we are
1187       // constructing. When rhs is constant, we place 0 in place of keyB.
1188       if (isa<AffineDimExpr>(lhs)) {
1189         lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1190         std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1191         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1192                  expr);
1193       } else {
1194         lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1195         std::pair<unsigned, signed> indexEntry(
1196             lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1197         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1198                  expr);
1199       }
1200     } else if (isa<AffineDimExpr>(lhs)) {
1201       // For product/modulo/division expressions having lhs as dimension and rhs
1202       // as symbol, we order the terms in the semi-affine expression based on
1203       // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1204       // where keyA is the position number of the dimension and keyB is the
1205       // position number of the symbol.
1206       lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1207       rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1208       std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1209       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1210     } else {
1211       // For product/modulo/division expressions having both lhs and rhs as
1212       // symbol, we design indices as a pair: <keyA, keyB> for expressions
1213       // of the form dimension * symbol, where keyA is the position number of
1214       // the dimension and keyB is the position number of the symbol.
1215       lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1216       rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1217       std::pair<unsigned, signed> indexEntry(
1218           lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1219       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1220     }
1221     addedToMap[it.index()] = true;
1222   }
1223 
1224   for (unsigned j = 0; j < numDims; ++j) {
1225     if (flatExprs[j] == 0)
1226       continue;
1227     // For dimensional expressions we set the index as <position number of the
1228     // dimension, 0>, as we want dimensional expressions to appear before
1229     // symbolic ones and products of dimensional and symbolic expressions
1230     // having the dimension with the same position number.
1231     std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1232     addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1233   }
1234 
1235   // Constructing the simplified semi-affine sum of product/division/mod
1236   // expression from the flattened form in the desired sorted order of indices
1237   // of the various individual product/division/mod expressions.
1238   llvm::sort(indices);
1239   for (const std::pair<unsigned, unsigned> index : indices) {
1240     assert(indexToExprMap.lookup(index) &&
1241            "cannot find key in `indexToExprMap` map");
1242     expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1243   }
1244 
1245   // Local identifiers.
1246   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1247        j++) {
1248     // If the coefficient of the local expression is 0, continue as we need not
1249     // add it in out final expression.
1250     if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1251       continue;
1252     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1253     expr = expr + term;
1254   }
1255 
1256   // Constant term.
1257   int64_t constTerm = flatExprs.back();
1258   if (constTerm != 0)
1259     expr = expr + constTerm;
1260   return expr;
1261 }
1262 
1263 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1264                                                      unsigned numSymbols)
1265     : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1266   operandExprStack.reserve(8);
1267 }
1268 
1269 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1270 //
1271 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1272 // introduce a local variable p (= expr * symbolic_expr), and the affine
1273 // expression expr * symbolic_expr is added to `localExprs`.
1274 LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1275   assert(operandExprStack.size() >= 2);
1276   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1277   operandExprStack.pop_back();
1278   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1279 
1280   // Flatten semi-affine multiplication expressions by introducing a local
1281   // variable in place of the product; the affine expression
1282   // corresponding to the quantifier is added to `localExprs`.
1283   if (!isa<AffineConstantExpr>(expr.getRHS())) {
1284     SmallVector<int64_t, 8> mulLhs(lhs);
1285     MLIRContext *context = expr.getContext();
1286     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1287                                              localExprs, context);
1288     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1289                                              localExprs, context);
1290     return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
1291   }
1292 
1293   // Get the RHS constant.
1294   int64_t rhsConst = rhs[getConstantIndex()];
1295   for (int64_t &lhsElt : lhs)
1296     lhsElt *= rhsConst;
1297 
1298   return success();
1299 }
1300 
1301 LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1302   assert(operandExprStack.size() >= 2);
1303   const auto &rhs = operandExprStack.back();
1304   auto &lhs = operandExprStack[operandExprStack.size() - 2];
1305   assert(lhs.size() == rhs.size());
1306   // Update the LHS in place.
1307   for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1308     lhs[i] += rhs[i];
1309   }
1310   // Pop off the RHS.
1311   operandExprStack.pop_back();
1312   return success();
1313 }
1314 
1315 //
1316 // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
1317 //
1318 // A mod expression "expr mod c" is thus flattened by introducing a new local
1319 // variable q (= expr floordiv c), such that expr mod c is replaced with
1320 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1321 //
1322 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1323 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1324 // expression expr mod symbolic_expr is added to `localExprs`.
1325 LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1326   assert(operandExprStack.size() >= 2);
1327 
1328   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1329   operandExprStack.pop_back();
1330   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1331   MLIRContext *context = expr.getContext();
1332 
1333   // Flatten semi affine modulo expressions by introducing a local
1334   // variable in place of the modulo value, and the affine expression
1335   // corresponding to the quantifier is added to `localExprs`.
1336   if (!isa<AffineConstantExpr>(expr.getRHS())) {
1337     SmallVector<int64_t, 8> modLhs(lhs);
1338     AffineExpr dividendExpr = getAffineExprFromFlatForm(
1339         lhs, numDims, numSymbols, localExprs, context);
1340     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1341                                                        localExprs, context);
1342     AffineExpr modExpr = dividendExpr % divisorExpr;
1343     return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
1344   }
1345 
1346   int64_t rhsConst = rhs[getConstantIndex()];
1347   if (rhsConst <= 0)
1348     return failure();
1349 
1350   // Check if the LHS expression is a multiple of modulo factor.
1351   unsigned i, e;
1352   for (i = 0, e = lhs.size(); i < e; i++)
1353     if (lhs[i] % rhsConst != 0)
1354       break;
1355   // If yes, modulo expression here simplifies to zero.
1356   if (i == lhs.size()) {
1357     std::fill(lhs.begin(), lhs.end(), 0);
1358     return success();
1359   }
1360 
1361   // Add a local variable for the quotient, i.e., expr % c is replaced by
1362   // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1363   // the GCD of expr and c.
1364   SmallVector<int64_t, 8> floorDividend(lhs);
1365   uint64_t gcd = rhsConst;
1366   for (int64_t lhsElt : lhs)
1367     gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1368   // Simplify the numerator and the denominator.
1369   if (gcd != 1) {
1370     for (int64_t &floorDividendElt : floorDividend)
1371       floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1372   }
1373   int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1374 
1375   // Construct the AffineExpr form of the floordiv to store in localExprs.
1376 
1377   AffineExpr dividendExpr = getAffineExprFromFlatForm(
1378       floorDividend, numDims, numSymbols, localExprs, context);
1379   AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1380   AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1381   int loc;
1382   if ((loc = findLocalId(floorDivExpr)) == -1) {
1383     addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1384     // Set result at top of stack to "lhs - rhsConst * q".
1385     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1386   } else {
1387     // Reuse the existing local id.
1388     lhs[getLocalVarStartIndex() + loc] -= rhsConst;
1389   }
1390   return success();
1391 }
1392 
1393 LogicalResult
1394 SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1395   return visitDivExpr(expr, /*isCeil=*/true);
1396 }
1397 LogicalResult
1398 SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1399   return visitDivExpr(expr, /*isCeil=*/false);
1400 }
1401 
1402 LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1403   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1404   auto &eq = operandExprStack.back();
1405   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1406   eq[getDimStartIndex() + expr.getPosition()] = 1;
1407   return success();
1408 }
1409 
1410 LogicalResult
1411 SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1412   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1413   auto &eq = operandExprStack.back();
1414   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1415   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1416   return success();
1417 }
1418 
1419 LogicalResult
1420 SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1421   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1422   auto &eq = operandExprStack.back();
1423   eq[getConstantIndex()] = expr.getValue();
1424   return success();
1425 }
1426 
1427 LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1428     ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr,
1429     SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
1430   assert(result.size() == resultSize &&
1431          "`result` vector passed is not of correct size");
1432   int loc;
1433   if ((loc = findLocalId(localExpr)) == -1) {
1434     if (failed(addLocalIdSemiAffine(lhs, rhs, localExpr)))
1435       return failure();
1436   }
1437   std::fill(result.begin(), result.end(), 0);
1438   if (loc == -1)
1439     result[getLocalVarStartIndex() + numLocals - 1] = 1;
1440   else
1441     result[getLocalVarStartIndex() + loc] = 1;
1442   return success();
1443 }
1444 
1445 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
1446 // A floordiv is thus flattened by introducing a new local variable q, and
1447 // replacing that expression with 'q' while adding the constraints
1448 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1449 // IntegerRelation::addLocalFloorDiv).
1450 //
1451 // A ceildiv is similarly flattened:
1452 // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
1453 //
1454 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1455 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1456 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1457 // `localExprs`.
1458 LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1459                                                       bool isCeil) {
1460   assert(operandExprStack.size() >= 2);
1461 
1462   MLIRContext *context = expr.getContext();
1463   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1464   operandExprStack.pop_back();
1465   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1466 
1467   // Flatten semi affine division expressions by introducing a local
1468   // variable in place of the quotient, and the affine expression corresponding
1469   // to the quantifier is added to `localExprs`.
1470   if (!isa<AffineConstantExpr>(expr.getRHS())) {
1471     SmallVector<int64_t, 8> divLhs(lhs);
1472     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1473                                              localExprs, context);
1474     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1475                                              localExprs, context);
1476     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1477     return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
1478   }
1479 
1480   // This is a pure affine expr; the RHS is a positive constant.
1481   int64_t rhsConst = rhs[getConstantIndex()];
1482   if (rhsConst <= 0)
1483     return failure();
1484 
1485   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1486   // common divisors of the numerator and denominator.
1487   uint64_t gcd = std::abs(rhsConst);
1488   for (int64_t lhsElt : lhs)
1489     gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1490   // Simplify the numerator and the denominator.
1491   if (gcd != 1) {
1492     for (int64_t &lhsElt : lhs)
1493       lhsElt = lhsElt / static_cast<int64_t>(gcd);
1494   }
1495   int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1496   // If the divisor becomes 1, the updated LHS is the result. (The
1497   // divisor can't be negative since rhsConst is positive).
1498   if (divisor == 1)
1499     return success();
1500 
1501   // If the divisor cannot be simplified to one, we will have to retain
1502   // the ceil/floor expr (simplified up until here). Add an existential
1503   // quantifier to express its result, i.e., expr1 div expr2 is replaced
1504   // by a new identifier, q.
1505   AffineExpr a =
1506       getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1507   AffineExpr b = getAffineConstantExpr(divisor, context);
1508 
1509   int loc;
1510   AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1511   if ((loc = findLocalId(divExpr)) == -1) {
1512     if (!isCeil) {
1513       SmallVector<int64_t, 8> dividend(lhs);
1514       addLocalFloorDivId(dividend, divisor, divExpr);
1515     } else {
1516       // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
1517       SmallVector<int64_t, 8> dividend(lhs);
1518       dividend.back() += divisor - 1;
1519       addLocalFloorDivId(dividend, divisor, divExpr);
1520     }
1521   }
1522   // Set the expression on stack to the local var introduced to capture the
1523   // result of the division (floor or ceil).
1524   std::fill(lhs.begin(), lhs.end(), 0);
1525   if (loc == -1)
1526     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1527   else
1528     lhs[getLocalVarStartIndex() + loc] = 1;
1529   return success();
1530 }
1531 
1532 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1533 // The local identifier added is always a floordiv of a pure add/mul affine
1534 // function of other identifiers, coefficients of which are specified in
1535 // dividend and with respect to a positive constant divisor. localExpr is the
1536 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1537 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1538                                                    int64_t divisor,
1539                                                    AffineExpr localExpr) {
1540   assert(divisor > 0 && "positive constant divisor expected");
1541   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1542     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1543   localExprs.push_back(localExpr);
1544   numLocals++;
1545   // dividend and divisor are not used here; an override of this method uses it.
1546 }
1547 
1548 LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
1549     ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
1550   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1551     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1552   localExprs.push_back(localExpr);
1553   ++numLocals;
1554   // lhs and rhs are not used here; an override of this method uses them.
1555   return success();
1556 }
1557 
1558 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1559   SmallVectorImpl<AffineExpr>::iterator it;
1560   if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1561     return -1;
1562   return it - localExprs.begin();
1563 }
1564 
1565 /// Simplify the affine expression by flattening it and reconstructing it.
1566 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1567                                     unsigned numSymbols) {
1568   // Simplify semi-affine expressions separately.
1569   if (!expr.isPureAffine())
1570     expr = simplifySemiAffine(expr, numDims, numSymbols);
1571 
1572   SimpleAffineExprFlattener flattener(numDims, numSymbols);
1573   // has poison expression
1574   if (failed(flattener.walkPostOrder(expr)))
1575     return expr;
1576   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1577   if (!expr.isPureAffine() &&
1578       expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1579                                         flattener.localExprs,
1580                                         expr.getContext()))
1581     return expr;
1582   AffineExpr simplifiedExpr =
1583       expr.isPureAffine()
1584           ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1585                                       flattener.localExprs, expr.getContext())
1586           : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1587                                           flattener.localExprs,
1588                                           expr.getContext());
1589 
1590   flattener.operandExprStack.pop_back();
1591   assert(flattener.operandExprStack.empty());
1592   return simplifiedExpr;
1593 }
1594 
1595 std::optional<int64_t> mlir::getBoundForAffineExpr(
1596     AffineExpr expr, unsigned numDims, unsigned numSymbols,
1597     ArrayRef<std::optional<int64_t>> constLowerBounds,
1598     ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1599   // Handle divs and mods.
1600   if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
1601     // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1602     // can compute an upper bound.
1603     if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1604       auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1605       if (!rhsConst || rhsConst.getValue() < 1)
1606         return std::nullopt;
1607       auto bound =
1608           getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1609                                 constLowerBounds, constUpperBounds, isUpper);
1610       if (!bound)
1611         return std::nullopt;
1612       return divideFloorSigned(*bound, rhsConst.getValue());
1613     }
1614     if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1615       auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1616       if (rhsConst && rhsConst.getValue() >= 1) {
1617         auto bound =
1618             getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1619                                   constLowerBounds, constUpperBounds, isUpper);
1620         if (!bound)
1621           return std::nullopt;
1622         return divideCeilSigned(*bound, rhsConst.getValue());
1623       }
1624       return std::nullopt;
1625     }
1626     if (binOpExpr.getKind() == AffineExprKind::Mod) {
1627       // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1628       // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1629       // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1630       auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1631       if (rhsConst && rhsConst.getValue() >= 1) {
1632         int64_t rhsConstVal = rhsConst.getValue();
1633         auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1634                                         constLowerBounds, constUpperBounds,
1635                                         /*isUpper=*/false);
1636         auto ub =
1637             getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1638                                   constLowerBounds, constUpperBounds, isUpper);
1639         if (ub && lb &&
1640             divideFloorSigned(*lb, rhsConstVal) ==
1641                 divideFloorSigned(*ub, rhsConstVal))
1642           return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
1643         return isUpper ? rhsConstVal - 1 : 0;
1644       }
1645     }
1646   }
1647   // Flatten the expression.
1648   SimpleAffineExprFlattener flattener(numDims, numSymbols);
1649   auto simpleResult = flattener.walkPostOrder(expr);
1650   // has poison expression
1651   if (failed(simpleResult))
1652     return std::nullopt;
1653   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1654   // TODO: Handle local variables. We can get hold of flattener.localExprs and
1655   // get bound on the local expr recursively.
1656   if (flattener.numLocals > 0)
1657     return std::nullopt;
1658   int64_t bound = 0;
1659   // Substitute the constant lower or upper bound for the dimensional or
1660   // symbolic input depending on `isUpper` to determine the bound.
1661   for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1662     if (flattenedExpr[i] > 0) {
1663       auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1664       if (!constBound)
1665         return std::nullopt;
1666       bound += *constBound * flattenedExpr[i];
1667     } else if (flattenedExpr[i] < 0) {
1668       auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1669       if (!constBound)
1670         return std::nullopt;
1671       bound += *constBound * flattenedExpr[i];
1672     }
1673   }
1674   // Constant term.
1675   bound += flattenedExpr.back();
1676   return bound;
1677 }
1678