xref: /llvm-project/mlir/lib/Analysis/FlatLinearValueConstraints.cpp (revision a24c468782010e17563f6aa93c5bb173c7f873b2)
1 //===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/FlatLinearValueConstraints.h"
10 
11 #include "mlir/Analysis/Presburger/LinearTransform.h"
12 #include "mlir/Analysis/Presburger/PresburgerSpace.h"
13 #include "mlir/Analysis/Presburger/Simplex.h"
14 #include "mlir/Analysis/Presburger/Utils.h"
15 #include "mlir/IR/AffineExprVisitor.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include <optional>
25 
26 #define DEBUG_TYPE "flat-value-constraints"
27 
28 using namespace mlir;
29 using namespace presburger;
30 
31 //===----------------------------------------------------------------------===//
32 // AffineExprFlattener
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 
37 // See comments for SimpleAffineExprFlattener.
38 // An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by
39 // recording constraint information associated with mod's, floordiv's, and
40 // ceildiv's in FlatLinearConstraints 'localVarCst'.
41 struct AffineExprFlattener : public SimpleAffineExprFlattener {
42   using SimpleAffineExprFlattener::SimpleAffineExprFlattener;
43 
44   // Constraints connecting newly introduced local variables (for mod's and
45   // div's) to existing (dimensional and symbolic) ones. These are always
46   // inequalities.
47   IntegerPolyhedron localVarCst;
48 
49   AffineExprFlattener(unsigned nDims, unsigned nSymbols)
50       : SimpleAffineExprFlattener(nDims, nSymbols),
51         localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {};
52 
53 private:
54   // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
55   // The local variable added is always a floordiv of a pure add/mul affine
56   // function of other variables, coefficients of which are specified in
57   // `dividend' and with respect to the positive constant `divisor'. localExpr
58   // is the simplified tree expression (AffineExpr) corresponding to the
59   // quantifier.
60   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
61                           AffineExpr localExpr) override {
62     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
63     // Update localVarCst.
64     localVarCst.addLocalFloorDiv(dividend, divisor);
65   }
66 
67   LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
68                                      ArrayRef<int64_t> rhs,
69                                      AffineExpr localExpr) override {
70     // AffineExprFlattener does not support semi-affine expressions.
71     return failure();
72   }
73 };
74 
75 // A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds
76 // conservative bounds for semi-affine expressions (given assumptions hold). If
77 // the assumptions required to add the semi-affine bounds are found not to hold
78 // the final constraints set will be empty/inconsistent. If the assumptions are
79 // never contradicted the final bounds still only will be correct if the
80 // assumptions hold.
81 struct SemiAffineExprFlattener : public AffineExprFlattener {
82   using AffineExprFlattener::AffineExprFlattener;
83 
84   LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
85                                      ArrayRef<int64_t> rhs,
86                                      AffineExpr localExpr) override {
87     auto result =
88         SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr);
89     assert(succeeded(result) &&
90            "unexpected failure in SimpleAffineExprFlattener");
91     (void)result;
92 
93     if (localExpr.getKind() == AffineExprKind::Mod) {
94       // Given two numbers a and b, division is defined as:
95       //
96       // a = bq + r
97       // 0 <= r < |b| (where |x| is the absolute value of x)
98       //
99       // q = a floordiv b
100       // r = a mod b
101 
102       // Add a new local variable (r) to represent the mod.
103       unsigned rPos = localVarCst.appendVar(VarKind::Local);
104 
105       // r >= 0 (Can ALWAYS be added)
106       localVarCst.addBound(BoundType::LB, rPos, 0);
107 
108       // r < b (Can be added if b > 0, which we assume here)
109       ArrayRef<int64_t> b = rhs;
110       SmallVector<int64_t> bSubR(b);
111       bSubR.insert(bSubR.begin() + rPos, -1);
112       // Note: bSubR = b - r
113       // So this adds the bound b - r >= 1 (equivalent to r < b)
114       localVarCst.addBound(BoundType::LB, bSubR, 1);
115 
116       // Note: The assumption of b > 0 is based on the affine expression docs,
117       // which state "RHS of mod is always a constant or a symbolic expression
118       // with a positive value." (see AffineExprKind in AffineExpr.h). If this
119       // assumption does not hold constraints (added above) are a contradiction.
120 
121       return success();
122     }
123 
124     // TODO: Support other semi-affine expressions.
125     return failure();
126   }
127 };
128 
129 } // namespace
130 
131 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
132 // flattened. For example two specific cases:
133 // 1. an unhandled semi-affine expressions is found.
134 // 2. has poison expression (i.e., division by zero).
135 static LogicalResult
136 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
137                         unsigned numSymbols,
138                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
139                         FlatLinearConstraints *localVarCst,
140                         bool addConservativeSemiAffineBounds = false) {
141   if (exprs.empty()) {
142     if (localVarCst)
143       *localVarCst = FlatLinearConstraints(numDims, numSymbols);
144     return success();
145   }
146 
147   auto flattenExprs = [&](AffineExprFlattener &flattener) -> LogicalResult {
148     // Use the same flattener to simplify each expression successively. This way
149     // local variables / expressions are shared.
150     for (auto expr : exprs) {
151       auto flattenResult = flattener.walkPostOrder(expr);
152       if (failed(flattenResult))
153         return failure();
154     }
155 
156     assert(flattener.operandExprStack.size() == exprs.size());
157     flattenedExprs->clear();
158     flattenedExprs->assign(flattener.operandExprStack.begin(),
159                            flattener.operandExprStack.end());
160 
161     if (localVarCst)
162       localVarCst->clearAndCopyFrom(flattener.localVarCst);
163 
164     return success();
165   };
166 
167   if (addConservativeSemiAffineBounds) {
168     SemiAffineExprFlattener flattener(numDims, numSymbols);
169     return flattenExprs(flattener);
170   }
171 
172   AffineExprFlattener flattener(numDims, numSymbols);
173   return flattenExprs(flattener);
174 }
175 
176 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
177 // be flattened (an unhandled semi-affine was found).
178 LogicalResult mlir::getFlattenedAffineExpr(
179     AffineExpr expr, unsigned numDims, unsigned numSymbols,
180     SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst,
181     bool addConservativeSemiAffineBounds) {
182   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
183   LogicalResult ret =
184       ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs,
185                                 localVarCst, addConservativeSemiAffineBounds);
186   *flattenedExpr = flattenedExprs[0];
187   return ret;
188 }
189 
190 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
191 /// flattened (i.e., an unhandled semi-affine was found).
192 LogicalResult mlir::getFlattenedAffineExprs(
193     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
194     FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) {
195   if (map.getNumResults() == 0) {
196     if (localVarCst)
197       *localVarCst =
198           FlatLinearConstraints(map.getNumDims(), map.getNumSymbols());
199     return success();
200   }
201   return ::getFlattenedAffineExprs(
202       map.getResults(), map.getNumDims(), map.getNumSymbols(), flattenedExprs,
203       localVarCst, addConservativeSemiAffineBounds);
204 }
205 
206 LogicalResult mlir::getFlattenedAffineExprs(
207     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
208     FlatLinearConstraints *localVarCst) {
209   if (set.getNumConstraints() == 0) {
210     if (localVarCst)
211       *localVarCst =
212           FlatLinearConstraints(set.getNumDims(), set.getNumSymbols());
213     return success();
214   }
215   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
216                                    set.getNumSymbols(), flattenedExprs,
217                                    localVarCst);
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // FlatLinearConstraints
222 //===----------------------------------------------------------------------===//
223 
224 // Similar to `composeMap` except that no Values need be associated with the
225 // constraint system nor are they looked at -- the dimensions and symbols of
226 // `other` are expected to correspond 1:1 to `this` system.
227 LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
228   assert(other.getNumDims() == getNumDimVars() && "dim mismatch");
229   assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
230 
231   std::vector<SmallVector<int64_t, 8>> flatExprs;
232   if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
233     return failure();
234   assert(flatExprs.size() == other.getNumResults());
235 
236   // Add dimensions corresponding to the map's results.
237   insertDimVar(/*pos=*/0, /*num=*/other.getNumResults());
238 
239   // We add one equality for each result connecting the result dim of the map to
240   // the other variables.
241   // E.g.: if the expression is 16*i0 + i1, and this is the r^th
242   // iteration/result of the value map, we are adding the equality:
243   // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
244   // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
245   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
246     const auto &flatExpr = flatExprs[r];
247     assert(flatExpr.size() >= other.getNumInputs() + 1);
248 
249     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
250     // Set the coefficient for this result to one.
251     eqToAdd[r] = 1;
252 
253     // Dims and symbols.
254     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
255       // Negate `eq[r]` since the newly added dimension will be set to this one.
256       eqToAdd[e + i] = -flatExpr[i];
257     }
258     // Local columns of `eq` are at the beginning.
259     unsigned j = getNumDimVars() + getNumSymbolVars();
260     unsigned end = flatExpr.size() - 1;
261     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
262       eqToAdd[j] = -flatExpr[i];
263     }
264 
265     // Constant term.
266     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
267 
268     // Add the equality connecting the result of the map to this constraint set.
269     addEquality(eqToAdd);
270   }
271 
272   return success();
273 }
274 
275 // Determine whether the variable at 'pos' (say var_r) can be expressed as
276 // modulo of another known variable (say var_n) w.r.t a constant. For example,
277 // if the following constraints hold true:
278 // ```
279 // 0 <= var_r <= divisor - 1
280 // var_n - (divisor * q_expr) = var_r
281 // ```
282 // where `var_n` is a known variable (called dividend), and `q_expr` is an
283 // `AffineExpr` (called the quotient expression), `var_r` can be written as:
284 //
285 // `var_r = var_n mod divisor`.
286 //
287 // Additionally, in a special case of the above constaints where `q_expr` is an
288 // variable itself that is not yet known (say `var_q`), it can be written as a
289 // floordiv in the following way:
290 //
291 // `var_q = var_n floordiv divisor`.
292 //
293 // First 'num' dimensional variables starting at 'offset' are
294 // derived/to-be-derived in terms of the remaining variables. The remaining
295 // variables are assigned trivial affine expressions in `memo`. For example,
296 // memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2:
297 // memo ==>  d0  d1  .   .   d2 ...
298 // cst  ==>  c0  c1  c2  c3  c4 ...
299 //
300 // Returns true if the above mod or floordiv are detected, updating 'memo' with
301 // these new expressions. Returns false otherwise.
302 static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
303                         unsigned offset, unsigned num, int64_t lbConst,
304                         int64_t ubConst, MLIRContext *context,
305                         SmallVectorImpl<AffineExpr> &memo) {
306   assert(pos < cst.getNumVars() && "invalid position");
307 
308   // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
309   // be determined.
310   if (lbConst != 0 || ubConst < 1)
311     return false;
312   int64_t divisor = ubConst + 1;
313 
314   // Check for the aforementioned conditions in each equality.
315   for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
316        curEquality < numEqualities; curEquality++) {
317     int64_t coefficientAtPos = cst.atEq64(curEquality, pos);
318     // If current equality does not involve `var_r`, continue to the next
319     // equality.
320     if (coefficientAtPos == 0)
321       continue;
322 
323     // Constant term should be 0 in this equality.
324     if (cst.atEq64(curEquality, cst.getNumCols() - 1) != 0)
325       continue;
326 
327     // Traverse through the equality and construct the dividend expression
328     // `dividendExpr`, to contain all the variables which are known and are
329     // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
330     // `dividendExpr` gets simplified into a single variable `var_n` discussed
331     // above.
332     auto dividendExpr = getAffineConstantExpr(0, context);
333 
334     // Track the terms that go into quotient expression, later used to detect
335     // additional floordiv.
336     unsigned quotientCount = 0;
337     int quotientPosition = -1;
338     int quotientSign = 1;
339 
340     // Consider each term in the current equality.
341     unsigned curVar, e;
342     for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) {
343       // Ignore var_r.
344       if (curVar == pos)
345         continue;
346       int64_t coefficientOfCurVar = cst.atEq64(curEquality, curVar);
347       // Ignore vars that do not contribute to the current equality.
348       if (coefficientOfCurVar == 0)
349         continue;
350       // Check if the current var goes into the quotient expression.
351       if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) {
352         quotientCount++;
353         quotientPosition = curVar;
354         quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1;
355         continue;
356       }
357       // Variables that are part of dividendExpr should be known.
358       if (!memo[curVar])
359         break;
360       // Append the current variable to the dividend expression.
361       dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar;
362     }
363 
364     // Can't construct expression as it depends on a yet uncomputed var.
365     if (curVar < e)
366       continue;
367 
368     // Express `var_r` in terms of the other vars collected so far.
369     if (coefficientAtPos > 0)
370       dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
371     else
372       dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
373 
374     // Simplify the expression.
375     dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimVars(),
376                                       cst.getNumSymbolVars());
377     // Only if the final dividend expression is just a single var (which we call
378     // `var_n`), we can proceed.
379     // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
380     // to dims themselves.
381     auto dimExpr = dyn_cast<AffineDimExpr>(dividendExpr);
382     if (!dimExpr)
383       continue;
384 
385     // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
386     if (quotientCount >= 1) {
387       // Find the column corresponding to `dimExpr`. `num` columns starting at
388       // `offset` correspond to previously unknown variables. The column
389       // corresponding to the trivially known `dimExpr` can be on either side
390       // of these.
391       unsigned dimExprPos = dimExpr.getPosition();
392       unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num;
393       auto ub = cst.getConstantBound64(BoundType::UB, dimExprCol);
394       // If `var_n` has an upperbound that is less than the divisor, mod can be
395       // eliminated altogether.
396       if (ub && *ub < divisor)
397         memo[pos] = dimExpr;
398       else
399         memo[pos] = dimExpr % divisor;
400       // If a unique quotient `var_q` was seen, it can be expressed as
401       // `var_n floordiv divisor`.
402       if (quotientCount == 1 && !memo[quotientPosition])
403         memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
404 
405       return true;
406     }
407   }
408   return false;
409 }
410 
411 /// Check if the pos^th variable can be expressed as a floordiv of an affine
412 /// function of other variables (where the divisor is a positive constant)
413 /// given the initial set of expressions in `exprs`. If it can be, the
414 /// corresponding position in `exprs` is set as the detected affine expr. For
415 /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
416 /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
417 /// <= i <= 32q + 31 => q = i floordiv 32.
418 static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos,
419                              MLIRContext *context,
420                              SmallVectorImpl<AffineExpr> &exprs) {
421   assert(pos < cst.getNumVars() && "invalid position");
422 
423   // Get upper-lower bound pair for this variable.
424   SmallVector<bool, 8> foundRepr(cst.getNumVars(), false);
425   for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
426     if (exprs[i])
427       foundRepr[i] = true;
428 
429   SmallVector<int64_t, 8> dividend(cst.getNumCols());
430   unsigned divisor;
431   auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
432 
433   // No upper-lower bound pair found for this var.
434   if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
435     return false;
436 
437   // Construct the dividend expression.
438   auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
439   for (unsigned c = 0, f = cst.getNumVars(); c < f; c++)
440     if (dividend[c] != 0)
441       dividendExpr = dividendExpr + dividend[c] * exprs[c];
442 
443   // Successfully detected the floordiv.
444   exprs[pos] = dividendExpr.floorDiv(divisor);
445   return true;
446 }
447 
448 std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
449     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
450     ArrayRef<AffineExpr> localExprs, MLIRContext *context,
451     bool closedUB) const {
452   assert(pos + offset < getNumDimVars() && "invalid dim start pos");
453   assert(symStartPos >= (pos + offset) && "invalid sym start pos");
454   assert(getNumLocalVars() == localExprs.size() &&
455          "incorrect local exprs count");
456 
457   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
458   getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
459                                offset, num);
460 
461   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
462   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
463     b.clear();
464     for (unsigned i = 0, e = a.size(); i < e; ++i) {
465       if (i < offset || i >= offset + num)
466         b.push_back(a[i]);
467     }
468   };
469 
470   SmallVector<int64_t, 8> lb, ub;
471   SmallVector<AffineExpr, 4> lbExprs;
472   unsigned dimCount = symStartPos - num;
473   unsigned symCount = getNumDimAndSymbolVars() - symStartPos;
474   lbExprs.reserve(lbIndices.size() + eqIndices.size());
475   // Lower bound expressions.
476   for (auto idx : lbIndices) {
477     auto ineq = getInequality64(idx);
478     // Extract the lower bound (in terms of other coeff's + const), i.e., if
479     // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
480     // - 1.
481     addCoeffs(ineq, lb);
482     std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
483     auto expr =
484         getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
485     // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
486     int64_t divisor = std::abs(ineq[pos + offset]);
487     expr = (expr + divisor - 1).floorDiv(divisor);
488     lbExprs.push_back(expr);
489   }
490 
491   SmallVector<AffineExpr, 4> ubExprs;
492   ubExprs.reserve(ubIndices.size() + eqIndices.size());
493   // Upper bound expressions.
494   for (auto idx : ubIndices) {
495     auto ineq = getInequality64(idx);
496     // Extract the upper bound (in terms of other coeff's + const).
497     addCoeffs(ineq, ub);
498     auto expr =
499         getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
500     expr = expr.floorDiv(std::abs(ineq[pos + offset]));
501     int64_t ubAdjustment = closedUB ? 0 : 1;
502     ubExprs.push_back(expr + ubAdjustment);
503   }
504 
505   // Equalities. It's both a lower and a upper bound.
506   SmallVector<int64_t, 4> b;
507   for (auto idx : eqIndices) {
508     auto eq = getEquality64(idx);
509     addCoeffs(eq, b);
510     if (eq[pos + offset] > 0)
511       std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
512 
513     // Extract the upper bound (in terms of other coeff's + const).
514     auto expr =
515         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
516     expr = expr.floorDiv(std::abs(eq[pos + offset]));
517     // Upper bound is exclusive.
518     ubExprs.push_back(expr + 1);
519     // Lower bound.
520     expr =
521         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
522     expr = expr.ceilDiv(std::abs(eq[pos + offset]));
523     lbExprs.push_back(expr);
524   }
525 
526   auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
527   auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
528 
529   return {lbMap, ubMap};
530 }
531 
532 /// Computes the lower and upper bounds of the first 'num' dimensional
533 /// variables (starting at 'offset') as affine maps of the remaining
534 /// variables (dimensional and symbolic variables). Local variables are
535 /// themselves explicitly computed as affine functions of other variables in
536 /// this process if needed.
537 void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
538                                            MLIRContext *context,
539                                            SmallVectorImpl<AffineMap> *lbMaps,
540                                            SmallVectorImpl<AffineMap> *ubMaps,
541                                            bool closedUB) {
542   assert(offset + num <= getNumDimVars() && "invalid range");
543 
544   // Basic simplification.
545   normalizeConstraintsByGCD();
546 
547   LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
548                           << " variables\n");
549   LLVM_DEBUG(dump());
550 
551   // Record computed/detected variables.
552   SmallVector<AffineExpr, 8> memo(getNumVars());
553   // Initialize dimensional and symbolic variables.
554   for (unsigned i = 0, e = getNumDimVars(); i < e; i++) {
555     if (i < offset)
556       memo[i] = getAffineDimExpr(i, context);
557     else if (i >= offset + num)
558       memo[i] = getAffineDimExpr(i - num, context);
559   }
560   for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++)
561     memo[i] = getAffineSymbolExpr(i - getNumDimVars(), context);
562 
563   bool changed;
564   do {
565     changed = false;
566     // Identify yet unknown variables as constants or mod's / floordiv's of
567     // other variables if possible.
568     for (unsigned pos = 0; pos < getNumVars(); pos++) {
569       if (memo[pos])
570         continue;
571 
572       auto lbConst = getConstantBound64(BoundType::LB, pos);
573       auto ubConst = getConstantBound64(BoundType::UB, pos);
574       if (lbConst.has_value() && ubConst.has_value()) {
575         // Detect equality to a constant.
576         if (*lbConst == *ubConst) {
577           memo[pos] = getAffineConstantExpr(*lbConst, context);
578           changed = true;
579           continue;
580         }
581 
582         // Detect a variable as modulo of another variable w.r.t a
583         // constant.
584         if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context,
585                         memo)) {
586           changed = true;
587           continue;
588         }
589       }
590 
591       // Detect a variable as a floordiv of an affine function of other
592       // variables (divisor is a positive constant).
593       if (detectAsFloorDiv(*this, pos, context, memo)) {
594         changed = true;
595         continue;
596       }
597 
598       // Detect a variable as an expression of other variables.
599       unsigned idx;
600       if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
601         continue;
602       }
603 
604       // Build AffineExpr solving for variable 'pos' in terms of all others.
605       auto expr = getAffineConstantExpr(0, context);
606       unsigned j, e;
607       for (j = 0, e = getNumVars(); j < e; ++j) {
608         if (j == pos)
609           continue;
610         int64_t c = atEq64(idx, j);
611         if (c == 0)
612           continue;
613         // If any of the involved IDs hasn't been found yet, we can't proceed.
614         if (!memo[j])
615           break;
616         expr = expr + memo[j] * c;
617       }
618       if (j < e)
619         // Can't construct expression as it depends on a yet uncomputed
620         // variable.
621         continue;
622 
623       // Add constant term to AffineExpr.
624       expr = expr + atEq64(idx, getNumVars());
625       int64_t vPos = atEq64(idx, pos);
626       assert(vPos != 0 && "expected non-zero here");
627       if (vPos > 0)
628         expr = (-expr).floorDiv(vPos);
629       else
630         // vPos < 0.
631         expr = expr.floorDiv(-vPos);
632       // Successfully constructed expression.
633       memo[pos] = expr;
634       changed = true;
635     }
636     // This loop is guaranteed to reach a fixed point - since once an
637     // variable's explicit form is computed (in memo[pos]), it's not updated
638     // again.
639   } while (changed);
640 
641   int64_t ubAdjustment = closedUB ? 0 : 1;
642 
643   // Set the lower and upper bound maps for all the variables that were
644   // computed as affine expressions of the rest as the "detected expr" and
645   // "detected expr + 1" respectively; set the undetected ones to null.
646   std::optional<FlatLinearConstraints> tmpClone;
647   for (unsigned pos = 0; pos < num; pos++) {
648     unsigned numMapDims = getNumDimVars() - num;
649     unsigned numMapSymbols = getNumSymbolVars();
650     AffineExpr expr = memo[pos + offset];
651     if (expr)
652       expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
653 
654     AffineMap &lbMap = (*lbMaps)[pos];
655     AffineMap &ubMap = (*ubMaps)[pos];
656 
657     if (expr) {
658       lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
659       ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment);
660     } else {
661       // TODO: Whenever there are local variables in the dependence
662       // constraints, we'll conservatively over-approximate, since we don't
663       // always explicitly compute them above (in the while loop).
664       if (getNumLocalVars() == 0) {
665         // Work on a copy so that we don't update this constraint system.
666         if (!tmpClone) {
667           tmpClone.emplace(FlatLinearConstraints(*this));
668           // Removing redundant inequalities is necessary so that we don't get
669           // redundant loop bounds.
670           tmpClone->removeRedundantInequalities();
671         }
672         std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
673             pos, offset, num, getNumDimVars(), /*localExprs=*/{}, context,
674             closedUB);
675       }
676 
677       // If the above fails, we'll just use the constant lower bound and the
678       // constant upper bound (if they exist) as the slice bounds.
679       // TODO: being conservative for the moment in cases that
680       // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
681       // fixed (b/126426796).
682       if (!lbMap || lbMap.getNumResults() > 1) {
683         LLVM_DEBUG(llvm::dbgs()
684                    << "WARNING: Potentially over-approximating slice lb\n");
685         auto lbConst = getConstantBound64(BoundType::LB, pos + offset);
686         if (lbConst.has_value()) {
687           lbMap = AffineMap::get(numMapDims, numMapSymbols,
688                                  getAffineConstantExpr(*lbConst, context));
689         }
690       }
691       if (!ubMap || ubMap.getNumResults() > 1) {
692         LLVM_DEBUG(llvm::dbgs()
693                    << "WARNING: Potentially over-approximating slice ub\n");
694         auto ubConst = getConstantBound64(BoundType::UB, pos + offset);
695         if (ubConst.has_value()) {
696           ubMap = AffineMap::get(
697               numMapDims, numMapSymbols,
698               getAffineConstantExpr(*ubConst + ubAdjustment, context));
699         }
700       }
701     }
702     LLVM_DEBUG(llvm::dbgs()
703                << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
704     LLVM_DEBUG(lbMap.dump(););
705     LLVM_DEBUG(llvm::dbgs()
706                << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
707     LLVM_DEBUG(ubMap.dump(););
708   }
709 }
710 
711 LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
712     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
713     bool addConservativeSemiAffineBounds) {
714   FlatLinearConstraints localCst;
715   if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst,
716                                      addConservativeSemiAffineBounds))) {
717     LLVM_DEBUG(llvm::dbgs()
718                << "composition unimplemented for semi-affine maps\n");
719     return failure();
720   }
721 
722   // Add localCst information.
723   if (localCst.getNumLocalVars() > 0) {
724     unsigned numLocalVars = getNumLocalVars();
725     // Insert local dims of localCst at the beginning.
726     insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars());
727     // Insert local dims of `this` at the end of localCst.
728     localCst.appendLocalVar(/*num=*/numLocalVars);
729     // Dimensions of localCst and this constraint set match. Append localCst to
730     // this constraint set.
731     append(localCst);
732   }
733 
734   return success();
735 }
736 
/// Adds one bound of kind `type` on the variable at position `pos` for each
/// result of `boundMap`. `boundMap`'s dims and symbols must match this
/// system's dim and symbol variables.
///
/// For LB/EQ the added row is `var - expr + adj` and for UB it is
/// `expr - var + adj`. The row is added as an equality for EQ and as an
/// inequality otherwise. `adj` is 0 for closed bounds and -1 for open
/// (exclusive) ones. EQ bounds must be closed and single-result.
///
/// Results whose flattened form already involves `pos` are silently skipped
/// (see the TODO below). Returns failure if `boundMap` cannot be flattened.
LogicalResult FlatLinearConstraints::addBound(
    BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound,
    AddConservativeSemiAffineBounds addSemiAffineBounds) {
  assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
  assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
  assert(pos < getNumDimAndSymbolVars() && "invalid position");
  assert((type != BoundType::EQ || isClosedBound) &&
         "EQ bound must be closed.");

  // Equality follows the logic of lower bound except that we add an equality
  // instead of an inequality.
  assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
         "single result expected");
  bool lower = type == BoundType::LB || type == BoundType::EQ;

  // Flattening may introduce new local variables. These are merged into the
  // front of this system's locals.
  std::vector<SmallVector<int64_t, 8>> flatExprs;
  if (failed(flattenAlignedMapAndMergeLocals(
          boundMap, &flatExprs,
          addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes)))
    return failure();
  assert(flatExprs.size() == boundMap.getNumResults());

  // Add one (in)equality for each result.
  for (const auto &flatExpr : flatExprs) {
    // Layout of flatExpr: [map inputs..., map locals..., constant].
    SmallVector<int64_t> ineq(getNumCols(), 0);
    // Dims and symbols. Lower bounds negate the expression so that the row
    // reads `var - expr`; upper bounds keep it so the row reads `expr - var`.
    for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
      ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
    }
    // Invalid bound: pos appears in `boundMap`.
    // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
    // its callers to prevent invalid bounds from being added.
    if (ineq[pos] != 0)
      continue;
    ineq[pos] = lower ? 1 : -1;
    // The map's local columns map to the first local columns of `ineq`. Those
    // start right after the dim and symbol columns, because the flattening
    // locals were inserted at the front of this system's locals.
    unsigned j = getNumDimVars() + getNumSymbolVars();
    unsigned end = flatExpr.size() - 1;
    for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
      ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
    }
    // Turn an open (strict) bound into a closed one. Over integers, a strict
    // `row > 0` equals `row - 1 >= 0`, so the adjustment is -1 for both the
    // lower and the upper form (assuming the usual `>= 0` inequality
    // convention).
    int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1;
    // Constant term.
    ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1]
                                    : flatExpr[flatExpr.size() - 1]) +
                             boundAdjustment;
    type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
  }

  return success();
}
790 
791 LogicalResult FlatLinearConstraints::addBound(
792     BoundType type, unsigned pos, AffineMap boundMap,
793     AddConservativeSemiAffineBounds addSemiAffineBounds) {
794   return addBound(type, pos, boundMap,
795                   /*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds);
796 }
797 
798 /// Compute an explicit representation for local vars. For all systems coming
799 /// from MLIR integer sets, maps, or expressions where local vars were
800 /// introduced to model floordivs and mods, this always succeeds.
801 LogicalResult
802 FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
803                                         MLIRContext *context) const {
804   unsigned numDims = getNumDimVars();
805   unsigned numSyms = getNumSymbolVars();
806 
807   // Initialize dimensional and symbolic variables.
808   for (unsigned i = 0; i < numDims; i++)
809     memo[i] = getAffineDimExpr(i, context);
810   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
811     memo[i] = getAffineSymbolExpr(i - numDims, context);
812 
813   bool changed;
814   do {
815     // Each time `changed` is true at the end of this iteration, one or more
816     // local vars would have been detected as floordivs and set in memo; so the
817     // number of null entries in memo[...] strictly reduces; so this converges.
818     changed = false;
819     for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i)
820       if (!memo[numDims + numSyms + i] &&
821           detectAsFloorDiv(*this, /*pos=*/numDims + numSyms + i, context, memo))
822         changed = true;
823   } while (changed);
824 
825   ArrayRef<AffineExpr> localExprs =
826       ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
827   return success(
828       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
829 }
830 
831 IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const {
832   if (getNumConstraints() == 0)
833     // Return universal set (always true): 0 == 0.
834     return IntegerSet::get(getNumDimVars(), getNumSymbolVars(),
835                            getAffineConstantExpr(/*constant=*/0, context),
836                            /*eqFlags=*/true);
837 
838   // Construct local references.
839   SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
840 
841   if (failed(computeLocalVars(memo, context))) {
842     // Check if the local variables without an explicit representation have
843     // zero coefficients everywhere.
844     SmallVector<unsigned> noLocalRepVars;
845     unsigned numDimsSymbols = getNumDimAndSymbolVars();
846     for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) {
847       if (!memo[i] && !isColZero(/*pos=*/i))
848         noLocalRepVars.push_back(i - numDimsSymbols);
849     }
850     if (!noLocalRepVars.empty()) {
851       LLVM_DEBUG({
852         llvm::dbgs() << "local variables at position(s) ";
853         llvm::interleaveComma(noLocalRepVars, llvm::dbgs());
854         llvm::dbgs() << " do not have an explicit representation in:\n";
855         this->dump();
856       });
857       return IntegerSet();
858     }
859   }
860 
861   ArrayRef<AffineExpr> localExprs =
862       ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
863 
864   // Construct the IntegerSet from the equalities/inequalities.
865   unsigned numDims = getNumDimVars();
866   unsigned numSyms = getNumSymbolVars();
867 
868   SmallVector<bool, 16> eqFlags(getNumConstraints());
869   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
870   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
871 
872   SmallVector<AffineExpr, 8> exprs;
873   exprs.reserve(getNumConstraints());
874 
875   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
876     exprs.push_back(getAffineExprFromFlatForm(getEquality64(i), numDims,
877                                               numSyms, localExprs, context));
878   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
879     exprs.push_back(getAffineExprFromFlatForm(getInequality64(i), numDims,
880                                               numSyms, localExprs, context));
881   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
882 }
883 
884 //===----------------------------------------------------------------------===//
885 // FlatLinearValueConstraints
886 //===----------------------------------------------------------------------===//
887 
888 // Construct from an IntegerSet.
889 FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set,
890                                                        ValueRange operands)
891     : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(),
892                             set.getNumDims() + set.getNumSymbols() + 1,
893                             set.getNumDims(), set.getNumSymbols(),
894                             /*numLocals=*/0) {
895   assert((operands.empty() || set.getNumInputs() == operands.size()) &&
896          "operand count mismatch");
897   // Set the values for the non-local variables.
898   for (unsigned i = 0, e = operands.size(); i < e; ++i)
899     setValue(i, operands[i]);
900 
901   // Flatten expressions and add them to the constraint system.
902   std::vector<SmallVector<int64_t, 8>> flatExprs;
903   FlatLinearConstraints localVarCst;
904   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
905     assert(false && "flattening unimplemented for semi-affine integer sets");
906     return;
907   }
908   assert(flatExprs.size() == set.getNumConstraints());
909   insertVar(VarKind::Local, getNumVarKind(VarKind::Local),
910             /*num=*/localVarCst.getNumLocalVars());
911 
912   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
913     const auto &flatExpr = flatExprs[i];
914     assert(flatExpr.size() == getNumCols());
915     if (set.getEqFlags()[i]) {
916       addEquality(flatExpr);
917     } else {
918       addInequality(flatExpr);
919     }
920   }
921   // Add the other constraints involving local vars from flattening.
922   append(localVarCst);
923 }
924 
925 unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) {
926   unsigned pos = getNumDimVars();
927   return insertVar(VarKind::SetDim, pos, vals);
928 }
929 
930 unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) {
931   unsigned pos = getNumSymbolVars();
932   return insertVar(VarKind::Symbol, pos, vals);
933 }
934 
935 unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos,
936                                                   ValueRange vals) {
937   return insertVar(VarKind::SetDim, pos, vals);
938 }
939 
940 unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos,
941                                                      ValueRange vals) {
942   return insertVar(VarKind::Symbol, pos, vals);
943 }
944 
945 unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
946                                                unsigned num) {
947   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
948 
949   return absolutePos;
950 }
951 
952 unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
953                                                ValueRange vals) {
954   assert(!vals.empty() && "expected ValueRange with Values.");
955   assert(kind != VarKind::Local &&
956          "values cannot be attached to local variables.");
957   unsigned num = vals.size();
958   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
959 
960   // If a Value is provided, insert it; otherwise use std::nullopt.
961   for (unsigned i = 0, e = vals.size(); i < e; ++i)
962     if (vals[i])
963       setValue(absolutePos + i, vals[i]);
964 
965   return absolutePos;
966 }
967 
968 /// Checks if two constraint systems are in the same space, i.e., if they are
969 /// associated with the same set of variables, appearing in the same order.
970 static bool areVarsAligned(const FlatLinearValueConstraints &a,
971                            const FlatLinearValueConstraints &b) {
972   if (a.getNumDomainVars() != b.getNumDomainVars() ||
973       a.getNumRangeVars() != b.getNumRangeVars() ||
974       a.getNumSymbolVars() != b.getNumSymbolVars())
975     return false;
976   SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(),
977                                     bMaybeValues = b.getMaybeValues();
978   return std::equal(aMaybeValues.begin(), aMaybeValues.end(),
979                     bMaybeValues.begin(), bMaybeValues.end());
980 }
981 
/// Returns true if `this` and `other` have the same number of domain, range
/// and symbol variables with the same attached (optional) Values in the same
/// order, as decided by `areVarsAligned`.
/// NOTE(review): `areVarsAligned` takes a `FlatLinearValueConstraints`, but
/// `other` is a `FlatLinearConstraints`. It is presumably converted to a
/// temporary `FlatLinearValueConstraints` through a converting constructor
/// declared elsewhere. Confirm this, and what Values such a temporary
/// carries.
bool FlatLinearValueConstraints::areVarsAlignedWithOther(
    const FlatLinearConstraints &other) {
  return areVarsAligned(*this, other);
}
988 
989 /// Checks if the SSA values associated with `cst`'s variables in range
990 /// [start, end) are unique.
991 static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
992     const FlatLinearValueConstraints &cst, unsigned start, unsigned end) {
993 
994   assert(start <= cst.getNumDimAndSymbolVars() &&
995          "Start position out of bounds");
996   assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds");
997 
998   if (start >= end)
999     return true;
1000 
1001   SmallPtrSet<Value, 8> uniqueVars;
1002   SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues();
1003   ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start,
1004                                                 maybeValuesAll.data() + end};
1005 
1006   for (std::optional<Value> val : maybeValues)
1007     if (val && !uniqueVars.insert(*val).second)
1008       return false;
1009 
1010   return true;
1011 }
1012 
1013 /// Checks if the SSA values associated with `cst`'s variables are unique.
1014 static bool LLVM_ATTRIBUTE_UNUSED
1015 areVarsUnique(const FlatLinearValueConstraints &cst) {
1016   return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars());
1017 }
1018 
1019 /// Checks if the SSA values associated with `cst`'s variables of kind `kind`
1020 /// are unique.
1021 static bool LLVM_ATTRIBUTE_UNUSED
1022 areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
1023 
1024   if (kind == VarKind::SetDim)
1025     return areVarsUnique(cst, 0, cst.getNumDimVars());
1026   if (kind == VarKind::Symbol)
1027     return areVarsUnique(cst, cst.getNumDimVars(),
1028                          cst.getNumDimAndSymbolVars());
1029   llvm_unreachable("Unexpected VarKind");
1030 }
1031 
/// Merges and aligns the variables of A and B, starting at dim position
/// 'offset'. Both systems are updated in place so that each ends up with the
/// union of their variables, with matching dim, symbol and local layouts:
///  - Dims: A's dims (from `offset` on) come first, in A's order, followed by
///    any of B's dims that did not appear in A.
///  - Symbols: merged by `mergeSymbolVars`.
///  - Locals: merged by `mergeLocalVars`. Local variables in B that have the
///    same division representation as local variables in A are merged into
///    one.
/// A and B may have non-unique Values for their variables. In that case,
/// duplicates are matched from left to right: the first occurrence in one
/// system aligns with the first occurrence in the other, and so on.
/// Precondition (asserted): every variable at position >= `offset` in both
/// systems has an attached Value.
//  E.g.: Input: A has ((%i, %j) [%M, %N]) and B has ((%k, %j) [%P, %N, %M])
//        Output: both A, B have (%i, %j, %k) [%M, %N, %P]
static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
                              FlatLinearValueConstraints *b) {
  assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());

  assert(llvm::all_of(
      llvm::drop_begin(a->getMaybeValues(), offset),
      [](const std::optional<Value> &var) { return var.has_value(); }));

  assert(llvm::all_of(
      llvm::drop_begin(b->getMaybeValues(), offset),
      [](const std::optional<Value> &var) { return var.has_value(); }));

  // Snapshot A's dim Values; A is only appended to (not reordered) below.
  SmallVector<Value, 4> aDimValues;
  a->getValues(offset, a->getNumDimVars(), &aDimValues);

  {
    // Merge dims from A into B.
    // Invariant: B's dims in [offset, d) already match A's dims at those
    // positions.
    unsigned d = offset;
    for (Value aDimValue : aDimValues) {
      unsigned loc;
      // Find from the position `d` since we'd like to also consider the
      // possibility of multiple variables with the same `Value`. We align with
      // the next appearing one.
      if (b->findVar(aDimValue, &loc, d)) {
        assert(loc >= offset && "A's dim appears in B's aligned range");
        assert(loc < b->getNumDimVars() &&
               "A's dim appears in B's non-dim position");
        // Move the matching dim of B into position `d`.
        b->swapVar(d, loc);
      } else {
        // A's dim is absent from B: add it to B at the same position.
        b->insertDimVar(d, aDimValue);
      }
      d++;
    }
    // Dimensions that are in B, but not in A, are added at the end.
    for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) {
      a->appendDimVar(b->getValue(t));
    }
    assert(a->getNumDimVars() == b->getNumDimVars() &&
           "expected same number of dims");
  }

  // Merge and align symbols of A and B
  a->mergeSymbolVars(*b);
  // Merge and align locals of A and B
  a->mergeLocalVars(*b);

  assert(areVarsAligned(*a, *b) && "IDs expected to be aligned");
}
1092 
/// Merges and aligns the variables of `this` and `other`, starting at dim
/// position `offset`, by calling the file-local `mergeAndAlignVars`. Both
/// `this` and `*other` are modified in place.
void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
    unsigned offset, FlatLinearValueConstraints *other) {
  mergeAndAlignVars(offset, this, other);
}
1098 
/// Merges and aligns the symbols of `this` and `other` so that both end up with
/// the union of their symbols, in the same order: `this`'s symbols first,
/// followed by `other`'s symbols that did not match any of them. Existing
/// symbols need not be unique; duplicates are aligned from left to right, in
/// order of appearance. A symbol matches only a symbol (never a dim) in
/// `other` that carries an equal Value. Symbols with Value as `None` are
/// considered to be unequal to all other symbols.
/// NOTE(review): `getValues` is used below to collect `this`'s symbol Values.
/// Confirm that it tolerates symbols without a Value, as the sentence above
/// implies.
void FlatLinearValueConstraints::mergeSymbolVars(
    FlatLinearValueConstraints &other) {

  SmallVector<Value, 4> aSymValues;
  getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);

  // Merge symbols: merge symbols into `other` first from `this`.
  // `s` is an absolute position in `other`; symbols start after its dims.
  unsigned s = other.getNumDimVars();
  for (Value aSymValue : aSymValues) {
    unsigned loc;
    // If the var is a symbol in `other`, then align it, otherwise assume that
    // it is a new symbol. Search in `other` starting at position `s` since the
    // left of it is aligned.
    if (other.findVar(aSymValue, &loc, s) && loc >= other.getNumDimVars() &&
        loc < other.getNumDimAndSymbolVars())
      other.swapVar(s, loc);
    else
      // insertSymbolVar takes a symbol-relative position.
      other.insertSymbolVar(s - other.getNumDimVars(), aSymValue);
    s++;
  }

  // Symbols that are in other, but not in this, are added at the end.
  // At this point, `other`'s first getNumSymbolVars() symbols match ours, so
  // the unmatched ones start right after them. Both loop bounds are evaluated
  // once, before any insertion.
  for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
                e = other.getNumDimAndSymbolVars();
       t < e; t++)
    insertSymbolVar(getNumSymbolVars(), other.getValue(t));

  assert(getNumSymbolVars() == other.getNumSymbolVars() &&
         "expected same number of symbols");
}
/// Removes the variables of kind `kind` in the range [varStart, varLimit)
/// (presumably positions relative to `kind`; confirm against the
/// IntegerPolyhedron API). This forwards directly to
/// IntegerPolyhedron::removeVarRange. Attached Values live in the space's
/// identifiers (see computeAlignedMap), so no separate bookkeeping is done
/// here.
void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
                                                unsigned varLimit) {
  IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
}
1138 
1139 AffineMap
1140 FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
1141                                               ValueRange operands) const {
1142   assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
1143 
1144   SmallVector<Value> dims, syms;
1145 #ifndef NDEBUG
1146   SmallVector<Value> newSyms;
1147   SmallVector<Value> *newSymsPtr = &newSyms;
1148 #else
1149   SmallVector<Value> *newSymsPtr = nullptr;
1150 #endif // NDEBUG
1151 
1152   dims.reserve(getNumDimVars());
1153   syms.reserve(getNumSymbolVars());
1154   for (unsigned i = 0, e = getNumVarKind(VarKind::SetDim); i < e; ++i) {
1155     Identifier id = space.getId(VarKind::SetDim, i);
1156     dims.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
1157   }
1158   for (unsigned i = 0, e = getNumVarKind(VarKind::Symbol); i < e; ++i) {
1159     Identifier id = space.getId(VarKind::Symbol, i);
1160     syms.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
1161   }
1162 
1163   AffineMap alignedMap =
1164       alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
1165   // All symbols are already part of this FlatAffineValueConstraints.
1166   assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
1167   assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
1168          "unexpected new/missing symbols");
1169   return alignedMap;
1170 }
1171 
1172 bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
1173                                          unsigned offset) const {
1174   SmallVector<std::optional<Value>> maybeValues = getMaybeValues();
1175   for (unsigned i = offset, e = maybeValues.size(); i < e; ++i)
1176     if (maybeValues[i] && maybeValues[i].value() == val) {
1177       *pos = i;
1178       return true;
1179     }
1180   return false;
1181 }
1182 
1183 bool FlatLinearValueConstraints::containsVar(Value val) const {
1184   unsigned pos;
1185   return findVar(val, &pos, 0);
1186 }
1187 
1188 void FlatLinearValueConstraints::addBound(BoundType type, Value val,
1189                                           int64_t value) {
1190   unsigned pos;
1191   if (!findVar(val, &pos))
1192     // This is a pre-condition for this method.
1193     assert(0 && "var not found");
1194   addBound(type, pos, value);
1195 }
1196 
1197 void FlatLinearConstraints::printSpace(raw_ostream &os) const {
1198   IntegerPolyhedron::printSpace(os);
1199   os << "(";
1200   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++)
1201     os << "None\t";
1202   for (unsigned i = getVarKindOffset(VarKind::Local),
1203                 e = getVarKindEnd(VarKind::Local);
1204        i < e; ++i)
1205     os << "Local\t";
1206   os << "const)\n";
1207 }
1208 
1209 void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
1210   IntegerPolyhedron::printSpace(os);
1211   os << "(";
1212   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
1213     if (hasValue(i))
1214       os << "Value\t";
1215     else
1216       os << "None\t";
1217   }
1218   for (unsigned i = getVarKindOffset(VarKind::Local),
1219                 e = getVarKindEnd(VarKind::Local);
1220        i < e; ++i)
1221     os << "Local\t";
1222   os << "const)\n";
1223 }
1224 
1225 void FlatLinearValueConstraints::projectOut(Value val) {
1226   unsigned pos;
1227   bool ret = findVar(val, &pos);
1228   assert(ret);
1229   (void)ret;
1230   fourierMotzkinEliminate(pos);
1231 }
1232 
1233 LogicalResult FlatLinearValueConstraints::unionBoundingBox(
1234     const FlatLinearValueConstraints &otherCst) {
1235   assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
1236   SmallVector<std::optional<Value>> maybeValues = getMaybeValues(),
1237                                     otherMaybeValues =
1238                                         otherCst.getMaybeValues();
1239   assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(),
1240                     otherMaybeValues.begin(),
1241                     otherMaybeValues.begin() + getNumDimVars()) &&
1242          "dim values mismatch");
1243   assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
1244   assert(getNumLocalVars() == 0 && "local vars not supported yet here");
1245 
1246   // Align `other` to this.
1247   if (!areVarsAligned(*this, otherCst)) {
1248     FlatLinearValueConstraints otherCopy(otherCst);
1249     mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
1250     return IntegerPolyhedron::unionBoundingBox(otherCopy);
1251   }
1252 
1253   return IntegerPolyhedron::unionBoundingBox(otherCst);
1254 }
1255 
1256 //===----------------------------------------------------------------------===//
1257 // Helper functions
1258 //===----------------------------------------------------------------------===//
1259 
1260 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
1261                                          ValueRange dims, ValueRange syms,
1262                                          SmallVector<Value> *newSyms) {
1263   assert(operands.size() == map.getNumInputs() &&
1264          "expected same number of operands and map inputs");
1265   MLIRContext *ctx = map.getContext();
1266   Builder builder(ctx);
1267   SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
1268   unsigned numSymbols = syms.size();
1269   SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
1270   if (newSyms) {
1271     newSyms->clear();
1272     newSyms->append(syms.begin(), syms.end());
1273   }
1274 
1275   for (const auto &operand : llvm::enumerate(operands)) {
1276     // Compute replacement dim/sym of operand.
1277     AffineExpr replacement;
1278     auto dimIt = llvm::find(dims, operand.value());
1279     auto symIt = llvm::find(syms, operand.value());
1280     if (dimIt != dims.end()) {
1281       replacement =
1282           builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
1283     } else if (symIt != syms.end()) {
1284       replacement =
1285           builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
1286     } else {
1287       // This operand is neither a dimension nor a symbol. Add it as a new
1288       // symbol.
1289       replacement = builder.getAffineSymbolExpr(numSymbols++);
1290       if (newSyms)
1291         newSyms->push_back(operand.value());
1292     }
1293     // Add to corresponding replacements vector.
1294     if (operand.index() < map.getNumDims()) {
1295       dimReplacements[operand.index()] = replacement;
1296     } else {
1297       symReplacements[operand.index() - map.getNumDims()] = replacement;
1298     }
1299   }
1300 
1301   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
1302                                    dims.size(), numSymbols);
1303 }
1304 
1305 LogicalResult
1306 mlir::getMultiAffineFunctionFromMap(AffineMap map,
1307                                     MultiAffineFunction &multiAff) {
1308   FlatLinearConstraints cst;
1309   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
1310   LogicalResult result = getFlattenedAffineExprs(map, &flattenedExprs, &cst);
1311 
1312   if (result.failed())
1313     return failure();
1314 
1315   DivisionRepr divs = cst.getLocalReprs();
1316   assert(divs.hasAllReprs() &&
1317          "AffineMap cannot produce divs without local representation");
1318 
1319   // TODO: We shouldn't have to do this conversion.
1320   Matrix<DynamicAPInt> mat(map.getNumResults(),
1321                            map.getNumInputs() + divs.getNumDivs() + 1);
1322   for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
1323     for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
1324       mat(i, j) = flattenedExprs[i][j];
1325 
1326   multiAff = MultiAffineFunction(
1327       PresburgerSpace::getRelationSpace(map.getNumDims(), map.getNumResults(),
1328                                         map.getNumSymbols(), divs.getNumDivs()),
1329       mat, divs);
1330 
1331   return success();
1332 }
1333