xref: /llvm-project/mlir/lib/Analysis/FlatLinearValueConstraints.cpp (revision a24c468782010e17563f6aa93c5bb173c7f873b2)
15b0055a4SMatthias Springer //===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===//
25b0055a4SMatthias Springer //
35b0055a4SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45b0055a4SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
55b0055a4SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65b0055a4SMatthias Springer //
75b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
85b0055a4SMatthias Springer 
95b0055a4SMatthias Springer #include "mlir/Analysis/FlatLinearValueConstraints.h"
105b0055a4SMatthias Springer 
115b0055a4SMatthias Springer #include "mlir/Analysis/Presburger/LinearTransform.h"
1224da7fa0SBharathi Ramana Joshi #include "mlir/Analysis/Presburger/PresburgerSpace.h"
135b0055a4SMatthias Springer #include "mlir/Analysis/Presburger/Simplex.h"
145b0055a4SMatthias Springer #include "mlir/Analysis/Presburger/Utils.h"
155b0055a4SMatthias Springer #include "mlir/IR/AffineExprVisitor.h"
165b0055a4SMatthias Springer #include "mlir/IR/Builders.h"
175b0055a4SMatthias Springer #include "mlir/IR/IntegerSet.h"
185b0055a4SMatthias Springer #include "mlir/Support/LLVM.h"
195b0055a4SMatthias Springer #include "llvm/ADT/STLExtras.h"
205b0055a4SMatthias Springer #include "llvm/ADT/SmallPtrSet.h"
215b0055a4SMatthias Springer #include "llvm/ADT/SmallVector.h"
225b0055a4SMatthias Springer #include "llvm/Support/Debug.h"
235b0055a4SMatthias Springer #include "llvm/Support/raw_ostream.h"
245b0055a4SMatthias Springer #include <optional>
255b0055a4SMatthias Springer 
265b0055a4SMatthias Springer #define DEBUG_TYPE "flat-value-constraints"
275b0055a4SMatthias Springer 
285b0055a4SMatthias Springer using namespace mlir;
295b0055a4SMatthias Springer using namespace presburger;
305b0055a4SMatthias Springer 
315b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
325b0055a4SMatthias Springer // AffineExprFlattener
335b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
345b0055a4SMatthias Springer 
355b0055a4SMatthias Springer namespace {
365b0055a4SMatthias Springer 
375b0055a4SMatthias Springer // See comments for SimpleAffineExprFlattener.
3829a925abSBenjamin Maxwell // An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by
3929a925abSBenjamin Maxwell // recording constraint information associated with mod's, floordiv's, and
4029a925abSBenjamin Maxwell // ceildiv's in FlatLinearConstraints 'localVarCst'.
415b0055a4SMatthias Springer struct AffineExprFlattener : public SimpleAffineExprFlattener {
4229a925abSBenjamin Maxwell   using SimpleAffineExprFlattener::SimpleAffineExprFlattener;
4329a925abSBenjamin Maxwell 
445b0055a4SMatthias Springer   // Constraints connecting newly introduced local variables (for mod's and
455b0055a4SMatthias Springer   // div's) to existing (dimensional and symbolic) ones. These are always
465b0055a4SMatthias Springer   // inequalities.
475b0055a4SMatthias Springer   IntegerPolyhedron localVarCst;
485b0055a4SMatthias Springer 
495b0055a4SMatthias Springer   AffineExprFlattener(unsigned nDims, unsigned nSymbols)
505b0055a4SMatthias Springer       : SimpleAffineExprFlattener(nDims, nSymbols),
5129a925abSBenjamin Maxwell         localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {};
525b0055a4SMatthias Springer 
535b0055a4SMatthias Springer private:
545b0055a4SMatthias Springer   // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
555b0055a4SMatthias Springer   // The local variable added is always a floordiv of a pure add/mul affine
565b0055a4SMatthias Springer   // function of other variables, coefficients of which are specified in
575b0055a4SMatthias Springer   // `dividend' and with respect to the positive constant `divisor'. localExpr
585b0055a4SMatthias Springer   // is the simplified tree expression (AffineExpr) corresponding to the
595b0055a4SMatthias Springer   // quantifier.
605b0055a4SMatthias Springer   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
615b0055a4SMatthias Springer                           AffineExpr localExpr) override {
625b0055a4SMatthias Springer     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
635b0055a4SMatthias Springer     // Update localVarCst.
645b0055a4SMatthias Springer     localVarCst.addLocalFloorDiv(dividend, divisor);
655b0055a4SMatthias Springer   }
6629a925abSBenjamin Maxwell 
6729a925abSBenjamin Maxwell   LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
6829a925abSBenjamin Maxwell                                      ArrayRef<int64_t> rhs,
6929a925abSBenjamin Maxwell                                      AffineExpr localExpr) override {
7029a925abSBenjamin Maxwell     // AffineExprFlattener does not support semi-affine expressions.
7129a925abSBenjamin Maxwell     return failure();
7229a925abSBenjamin Maxwell   }
7329a925abSBenjamin Maxwell };
7429a925abSBenjamin Maxwell 
7529a925abSBenjamin Maxwell // A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds
7629a925abSBenjamin Maxwell // conservative bounds for semi-affine expressions (given assumptions hold). If
7729a925abSBenjamin Maxwell // the assumptions required to add the semi-affine bounds are found not to hold
7829a925abSBenjamin Maxwell // the final constraints set will be empty/inconsistent. If the assumptions are
7929a925abSBenjamin Maxwell // never contradicted the final bounds still only will be correct if the
8029a925abSBenjamin Maxwell // assumptions hold.
8129a925abSBenjamin Maxwell struct SemiAffineExprFlattener : public AffineExprFlattener {
8229a925abSBenjamin Maxwell   using AffineExprFlattener::AffineExprFlattener;
8329a925abSBenjamin Maxwell 
8429a925abSBenjamin Maxwell   LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
8529a925abSBenjamin Maxwell                                      ArrayRef<int64_t> rhs,
8629a925abSBenjamin Maxwell                                      AffineExpr localExpr) override {
8729a925abSBenjamin Maxwell     auto result =
8829a925abSBenjamin Maxwell         SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr);
8929a925abSBenjamin Maxwell     assert(succeeded(result) &&
9029a925abSBenjamin Maxwell            "unexpected failure in SimpleAffineExprFlattener");
9129a925abSBenjamin Maxwell     (void)result;
9229a925abSBenjamin Maxwell 
9329a925abSBenjamin Maxwell     if (localExpr.getKind() == AffineExprKind::Mod) {
9429a925abSBenjamin Maxwell       // Given two numbers a and b, division is defined as:
9529a925abSBenjamin Maxwell       //
9629a925abSBenjamin Maxwell       // a = bq + r
9729a925abSBenjamin Maxwell       // 0 <= r < |b| (where |x| is the absolute value of x)
9829a925abSBenjamin Maxwell       //
9929a925abSBenjamin Maxwell       // q = a floordiv b
10029a925abSBenjamin Maxwell       // r = a mod b
10129a925abSBenjamin Maxwell 
10229a925abSBenjamin Maxwell       // Add a new local variable (r) to represent the mod.
10329a925abSBenjamin Maxwell       unsigned rPos = localVarCst.appendVar(VarKind::Local);
10429a925abSBenjamin Maxwell 
10529a925abSBenjamin Maxwell       // r >= 0 (Can ALWAYS be added)
10629a925abSBenjamin Maxwell       localVarCst.addBound(BoundType::LB, rPos, 0);
10729a925abSBenjamin Maxwell 
10829a925abSBenjamin Maxwell       // r < b (Can be added if b > 0, which we assume here)
10929a925abSBenjamin Maxwell       ArrayRef<int64_t> b = rhs;
11029a925abSBenjamin Maxwell       SmallVector<int64_t> bSubR(b);
11129a925abSBenjamin Maxwell       bSubR.insert(bSubR.begin() + rPos, -1);
11229a925abSBenjamin Maxwell       // Note: bSubR = b - r
11329a925abSBenjamin Maxwell       // So this adds the bound b - r >= 1 (equivalent to r < b)
11429a925abSBenjamin Maxwell       localVarCst.addBound(BoundType::LB, bSubR, 1);
11529a925abSBenjamin Maxwell 
11629a925abSBenjamin Maxwell       // Note: The assumption of b > 0 is based on the affine expression docs,
11729a925abSBenjamin Maxwell       // which state "RHS of mod is always a constant or a symbolic expression
11829a925abSBenjamin Maxwell       // with a positive value." (see AffineExprKind in AffineExpr.h). If this
11929a925abSBenjamin Maxwell       // assumption does not hold constraints (added above) are a contradiction.
12029a925abSBenjamin Maxwell 
12129a925abSBenjamin Maxwell       return success();
12229a925abSBenjamin Maxwell     }
12329a925abSBenjamin Maxwell 
12429a925abSBenjamin Maxwell     // TODO: Support other semi-affine expressions.
12529a925abSBenjamin Maxwell     return failure();
12629a925abSBenjamin Maxwell   }
1275b0055a4SMatthias Springer };
1285b0055a4SMatthias Springer 
1295b0055a4SMatthias Springer } // namespace
1305b0055a4SMatthias Springer 
1315b0055a4SMatthias Springer // Flattens the expressions in map. Returns failure if 'expr' was unable to be
132dc4786b4Slong.chen // flattened. For example two specific cases:
13329a925abSBenjamin Maxwell // 1. an unhandled semi-affine expressions is found.
134dc4786b4Slong.chen // 2. has poison expression (i.e., division by zero).
1355b0055a4SMatthias Springer static LogicalResult
1365b0055a4SMatthias Springer getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
1375b0055a4SMatthias Springer                         unsigned numSymbols,
1385b0055a4SMatthias Springer                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
13929a925abSBenjamin Maxwell                         FlatLinearConstraints *localVarCst,
14029a925abSBenjamin Maxwell                         bool addConservativeSemiAffineBounds = false) {
1415b0055a4SMatthias Springer   if (exprs.empty()) {
1425b0055a4SMatthias Springer     if (localVarCst)
1435b0055a4SMatthias Springer       *localVarCst = FlatLinearConstraints(numDims, numSymbols);
1445b0055a4SMatthias Springer     return success();
1455b0055a4SMatthias Springer   }
1465b0055a4SMatthias Springer 
14729a925abSBenjamin Maxwell   auto flattenExprs = [&](AffineExprFlattener &flattener) -> LogicalResult {
1485b0055a4SMatthias Springer     // Use the same flattener to simplify each expression successively. This way
1495b0055a4SMatthias Springer     // local variables / expressions are shared.
1505b0055a4SMatthias Springer     for (auto expr : exprs) {
151dc4786b4Slong.chen       auto flattenResult = flattener.walkPostOrder(expr);
152dc4786b4Slong.chen       if (failed(flattenResult))
153dc4786b4Slong.chen         return failure();
1545b0055a4SMatthias Springer     }
1555b0055a4SMatthias Springer 
1565b0055a4SMatthias Springer     assert(flattener.operandExprStack.size() == exprs.size());
1575b0055a4SMatthias Springer     flattenedExprs->clear();
1585b0055a4SMatthias Springer     flattenedExprs->assign(flattener.operandExprStack.begin(),
1595b0055a4SMatthias Springer                            flattener.operandExprStack.end());
1605b0055a4SMatthias Springer 
1615b0055a4SMatthias Springer     if (localVarCst)
1625b0055a4SMatthias Springer       localVarCst->clearAndCopyFrom(flattener.localVarCst);
1635b0055a4SMatthias Springer 
1645b0055a4SMatthias Springer     return success();
16529a925abSBenjamin Maxwell   };
16629a925abSBenjamin Maxwell 
16729a925abSBenjamin Maxwell   if (addConservativeSemiAffineBounds) {
16829a925abSBenjamin Maxwell     SemiAffineExprFlattener flattener(numDims, numSymbols);
16929a925abSBenjamin Maxwell     return flattenExprs(flattener);
17029a925abSBenjamin Maxwell   }
17129a925abSBenjamin Maxwell 
17229a925abSBenjamin Maxwell   AffineExprFlattener flattener(numDims, numSymbols);
17329a925abSBenjamin Maxwell   return flattenExprs(flattener);
1745b0055a4SMatthias Springer }
1755b0055a4SMatthias Springer 
1765b0055a4SMatthias Springer // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
17729a925abSBenjamin Maxwell // be flattened (an unhandled semi-affine was found).
17829a925abSBenjamin Maxwell LogicalResult mlir::getFlattenedAffineExpr(
17929a925abSBenjamin Maxwell     AffineExpr expr, unsigned numDims, unsigned numSymbols,
18029a925abSBenjamin Maxwell     SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst,
18129a925abSBenjamin Maxwell     bool addConservativeSemiAffineBounds) {
1825b0055a4SMatthias Springer   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
18329a925abSBenjamin Maxwell   LogicalResult ret =
18429a925abSBenjamin Maxwell       ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs,
18529a925abSBenjamin Maxwell                                 localVarCst, addConservativeSemiAffineBounds);
1865b0055a4SMatthias Springer   *flattenedExpr = flattenedExprs[0];
1875b0055a4SMatthias Springer   return ret;
1885b0055a4SMatthias Springer }
1895b0055a4SMatthias Springer 
1905b0055a4SMatthias Springer /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
19129a925abSBenjamin Maxwell /// flattened (i.e., an unhandled semi-affine was found).
1925b0055a4SMatthias Springer LogicalResult mlir::getFlattenedAffineExprs(
1935b0055a4SMatthias Springer     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
19429a925abSBenjamin Maxwell     FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) {
1955b0055a4SMatthias Springer   if (map.getNumResults() == 0) {
1965b0055a4SMatthias Springer     if (localVarCst)
1975b0055a4SMatthias Springer       *localVarCst =
1985b0055a4SMatthias Springer           FlatLinearConstraints(map.getNumDims(), map.getNumSymbols());
1995b0055a4SMatthias Springer     return success();
2005b0055a4SMatthias Springer   }
20129a925abSBenjamin Maxwell   return ::getFlattenedAffineExprs(
20229a925abSBenjamin Maxwell       map.getResults(), map.getNumDims(), map.getNumSymbols(), flattenedExprs,
20329a925abSBenjamin Maxwell       localVarCst, addConservativeSemiAffineBounds);
2045b0055a4SMatthias Springer }
2055b0055a4SMatthias Springer 
2065b0055a4SMatthias Springer LogicalResult mlir::getFlattenedAffineExprs(
2075b0055a4SMatthias Springer     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
2085b0055a4SMatthias Springer     FlatLinearConstraints *localVarCst) {
2095b0055a4SMatthias Springer   if (set.getNumConstraints() == 0) {
2105b0055a4SMatthias Springer     if (localVarCst)
2115b0055a4SMatthias Springer       *localVarCst =
2125b0055a4SMatthias Springer           FlatLinearConstraints(set.getNumDims(), set.getNumSymbols());
2135b0055a4SMatthias Springer     return success();
2145b0055a4SMatthias Springer   }
2155b0055a4SMatthias Springer   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
2165b0055a4SMatthias Springer                                    set.getNumSymbols(), flattenedExprs,
2175b0055a4SMatthias Springer                                    localVarCst);
2185b0055a4SMatthias Springer }
2195b0055a4SMatthias Springer 
2205b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
2215b0055a4SMatthias Springer // FlatLinearConstraints
2225b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
2235b0055a4SMatthias Springer 
2245b0055a4SMatthias Springer // Similar to `composeMap` except that no Values need be associated with the
2255b0055a4SMatthias Springer // constraint system nor are they looked at -- the dimensions and symbols of
2265b0055a4SMatthias Springer // `other` are expected to correspond 1:1 to `this` system.
2275b0055a4SMatthias Springer LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
2285b0055a4SMatthias Springer   assert(other.getNumDims() == getNumDimVars() && "dim mismatch");
2295b0055a4SMatthias Springer   assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
2305b0055a4SMatthias Springer 
2315b0055a4SMatthias Springer   std::vector<SmallVector<int64_t, 8>> flatExprs;
2325b0055a4SMatthias Springer   if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
2335b0055a4SMatthias Springer     return failure();
2345b0055a4SMatthias Springer   assert(flatExprs.size() == other.getNumResults());
2355b0055a4SMatthias Springer 
2365b0055a4SMatthias Springer   // Add dimensions corresponding to the map's results.
2375b0055a4SMatthias Springer   insertDimVar(/*pos=*/0, /*num=*/other.getNumResults());
2385b0055a4SMatthias Springer 
2395b0055a4SMatthias Springer   // We add one equality for each result connecting the result dim of the map to
2405b0055a4SMatthias Springer   // the other variables.
2415b0055a4SMatthias Springer   // E.g.: if the expression is 16*i0 + i1, and this is the r^th
2425b0055a4SMatthias Springer   // iteration/result of the value map, we are adding the equality:
2435b0055a4SMatthias Springer   // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
2445b0055a4SMatthias Springer   // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
2455b0055a4SMatthias Springer   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
2465b0055a4SMatthias Springer     const auto &flatExpr = flatExprs[r];
2475b0055a4SMatthias Springer     assert(flatExpr.size() >= other.getNumInputs() + 1);
2485b0055a4SMatthias Springer 
2495b0055a4SMatthias Springer     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
2505b0055a4SMatthias Springer     // Set the coefficient for this result to one.
2515b0055a4SMatthias Springer     eqToAdd[r] = 1;
2525b0055a4SMatthias Springer 
2535b0055a4SMatthias Springer     // Dims and symbols.
2545b0055a4SMatthias Springer     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
2555b0055a4SMatthias Springer       // Negate `eq[r]` since the newly added dimension will be set to this one.
2565b0055a4SMatthias Springer       eqToAdd[e + i] = -flatExpr[i];
2575b0055a4SMatthias Springer     }
2585b0055a4SMatthias Springer     // Local columns of `eq` are at the beginning.
2595b0055a4SMatthias Springer     unsigned j = getNumDimVars() + getNumSymbolVars();
2605b0055a4SMatthias Springer     unsigned end = flatExpr.size() - 1;
2615b0055a4SMatthias Springer     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
2625b0055a4SMatthias Springer       eqToAdd[j] = -flatExpr[i];
2635b0055a4SMatthias Springer     }
2645b0055a4SMatthias Springer 
2655b0055a4SMatthias Springer     // Constant term.
2665b0055a4SMatthias Springer     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
2675b0055a4SMatthias Springer 
2685b0055a4SMatthias Springer     // Add the equality connecting the result of the map to this constraint set.
2695b0055a4SMatthias Springer     addEquality(eqToAdd);
2705b0055a4SMatthias Springer   }
2715b0055a4SMatthias Springer 
2725b0055a4SMatthias Springer   return success();
2735b0055a4SMatthias Springer }
2745b0055a4SMatthias Springer 
2755b0055a4SMatthias Springer // Determine whether the variable at 'pos' (say var_r) can be expressed as
2765b0055a4SMatthias Springer // modulo of another known variable (say var_n) w.r.t a constant. For example,
2775b0055a4SMatthias Springer // if the following constraints hold true:
2785b0055a4SMatthias Springer // ```
2795b0055a4SMatthias Springer // 0 <= var_r <= divisor - 1
2805b0055a4SMatthias Springer // var_n - (divisor * q_expr) = var_r
2815b0055a4SMatthias Springer // ```
2825b0055a4SMatthias Springer // where `var_n` is a known variable (called dividend), and `q_expr` is an
2835b0055a4SMatthias Springer // `AffineExpr` (called the quotient expression), `var_r` can be written as:
2845b0055a4SMatthias Springer //
2855b0055a4SMatthias Springer // `var_r = var_n mod divisor`.
2865b0055a4SMatthias Springer //
2875b0055a4SMatthias Springer // Additionally, in a special case of the above constaints where `q_expr` is an
2885b0055a4SMatthias Springer // variable itself that is not yet known (say `var_q`), it can be written as a
2895b0055a4SMatthias Springer // floordiv in the following way:
2905b0055a4SMatthias Springer //
2915b0055a4SMatthias Springer // `var_q = var_n floordiv divisor`.
2925b0055a4SMatthias Springer //
293c45c9625SVinayaka Bandishti // First 'num' dimensional variables starting at 'offset' are
294c45c9625SVinayaka Bandishti // derived/to-be-derived in terms of the remaining variables. The remaining
295c45c9625SVinayaka Bandishti // variables are assigned trivial affine expressions in `memo`. For example,
296c45c9625SVinayaka Bandishti // memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2:
297c45c9625SVinayaka Bandishti // memo ==>  d0  d1  .   .   d2 ...
298c45c9625SVinayaka Bandishti // cst  ==>  c0  c1  c2  c3  c4 ...
299c45c9625SVinayaka Bandishti //
3005b0055a4SMatthias Springer // Returns true if the above mod or floordiv are detected, updating 'memo' with
3015b0055a4SMatthias Springer // these new expressions. Returns false otherwise.
3025b0055a4SMatthias Springer static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
303c45c9625SVinayaka Bandishti                         unsigned offset, unsigned num, int64_t lbConst,
304c45c9625SVinayaka Bandishti                         int64_t ubConst, MLIRContext *context,
305c45c9625SVinayaka Bandishti                         SmallVectorImpl<AffineExpr> &memo) {
3065b0055a4SMatthias Springer   assert(pos < cst.getNumVars() && "invalid position");
3075b0055a4SMatthias Springer 
3085b0055a4SMatthias Springer   // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
3095b0055a4SMatthias Springer   // be determined.
3105b0055a4SMatthias Springer   if (lbConst != 0 || ubConst < 1)
3115b0055a4SMatthias Springer     return false;
3125b0055a4SMatthias Springer   int64_t divisor = ubConst + 1;
3135b0055a4SMatthias Springer 
3145b0055a4SMatthias Springer   // Check for the aforementioned conditions in each equality.
3155b0055a4SMatthias Springer   for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
3165b0055a4SMatthias Springer        curEquality < numEqualities; curEquality++) {
3175b0055a4SMatthias Springer     int64_t coefficientAtPos = cst.atEq64(curEquality, pos);
3185b0055a4SMatthias Springer     // If current equality does not involve `var_r`, continue to the next
3195b0055a4SMatthias Springer     // equality.
3205b0055a4SMatthias Springer     if (coefficientAtPos == 0)
3215b0055a4SMatthias Springer       continue;
3225b0055a4SMatthias Springer 
3235b0055a4SMatthias Springer     // Constant term should be 0 in this equality.
3245b0055a4SMatthias Springer     if (cst.atEq64(curEquality, cst.getNumCols() - 1) != 0)
3255b0055a4SMatthias Springer       continue;
3265b0055a4SMatthias Springer 
3275b0055a4SMatthias Springer     // Traverse through the equality and construct the dividend expression
3285b0055a4SMatthias Springer     // `dividendExpr`, to contain all the variables which are known and are
3295b0055a4SMatthias Springer     // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
3305b0055a4SMatthias Springer     // `dividendExpr` gets simplified into a single variable `var_n` discussed
3315b0055a4SMatthias Springer     // above.
3325b0055a4SMatthias Springer     auto dividendExpr = getAffineConstantExpr(0, context);
3335b0055a4SMatthias Springer 
3345b0055a4SMatthias Springer     // Track the terms that go into quotient expression, later used to detect
3355b0055a4SMatthias Springer     // additional floordiv.
3365b0055a4SMatthias Springer     unsigned quotientCount = 0;
3375b0055a4SMatthias Springer     int quotientPosition = -1;
3385b0055a4SMatthias Springer     int quotientSign = 1;
3395b0055a4SMatthias Springer 
3405b0055a4SMatthias Springer     // Consider each term in the current equality.
3415b0055a4SMatthias Springer     unsigned curVar, e;
3425b0055a4SMatthias Springer     for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) {
3435b0055a4SMatthias Springer       // Ignore var_r.
3445b0055a4SMatthias Springer       if (curVar == pos)
3455b0055a4SMatthias Springer         continue;
3465b0055a4SMatthias Springer       int64_t coefficientOfCurVar = cst.atEq64(curEquality, curVar);
3475b0055a4SMatthias Springer       // Ignore vars that do not contribute to the current equality.
3485b0055a4SMatthias Springer       if (coefficientOfCurVar == 0)
3495b0055a4SMatthias Springer         continue;
3505b0055a4SMatthias Springer       // Check if the current var goes into the quotient expression.
3515b0055a4SMatthias Springer       if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) {
3525b0055a4SMatthias Springer         quotientCount++;
3535b0055a4SMatthias Springer         quotientPosition = curVar;
3545b0055a4SMatthias Springer         quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1;
3555b0055a4SMatthias Springer         continue;
3565b0055a4SMatthias Springer       }
3575b0055a4SMatthias Springer       // Variables that are part of dividendExpr should be known.
3585b0055a4SMatthias Springer       if (!memo[curVar])
3595b0055a4SMatthias Springer         break;
3605b0055a4SMatthias Springer       // Append the current variable to the dividend expression.
3615b0055a4SMatthias Springer       dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar;
3625b0055a4SMatthias Springer     }
3635b0055a4SMatthias Springer 
3645b0055a4SMatthias Springer     // Can't construct expression as it depends on a yet uncomputed var.
3655b0055a4SMatthias Springer     if (curVar < e)
3665b0055a4SMatthias Springer       continue;
3675b0055a4SMatthias Springer 
3685b0055a4SMatthias Springer     // Express `var_r` in terms of the other vars collected so far.
3695b0055a4SMatthias Springer     if (coefficientAtPos > 0)
3705b0055a4SMatthias Springer       dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
3715b0055a4SMatthias Springer     else
3725b0055a4SMatthias Springer       dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
3735b0055a4SMatthias Springer 
3745b0055a4SMatthias Springer     // Simplify the expression.
3755b0055a4SMatthias Springer     dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimVars(),
3765b0055a4SMatthias Springer                                       cst.getNumSymbolVars());
3775b0055a4SMatthias Springer     // Only if the final dividend expression is just a single var (which we call
3785b0055a4SMatthias Springer     // `var_n`), we can proceed.
3795b0055a4SMatthias Springer     // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
3805b0055a4SMatthias Springer     // to dims themselves.
3811609f1c2Slong.chen     auto dimExpr = dyn_cast<AffineDimExpr>(dividendExpr);
3825b0055a4SMatthias Springer     if (!dimExpr)
3835b0055a4SMatthias Springer       continue;
3845b0055a4SMatthias Springer 
3855b0055a4SMatthias Springer     // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
3865b0055a4SMatthias Springer     if (quotientCount >= 1) {
387c45c9625SVinayaka Bandishti       // Find the column corresponding to `dimExpr`. `num` columns starting at
388c45c9625SVinayaka Bandishti       // `offset` correspond to previously unknown variables. The column
389c45c9625SVinayaka Bandishti       // corresponding to the trivially known `dimExpr` can be on either side
390c45c9625SVinayaka Bandishti       // of these.
391c45c9625SVinayaka Bandishti       unsigned dimExprPos = dimExpr.getPosition();
392c45c9625SVinayaka Bandishti       unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num;
393c45c9625SVinayaka Bandishti       auto ub = cst.getConstantBound64(BoundType::UB, dimExprCol);
3945b0055a4SMatthias Springer       // If `var_n` has an upperbound that is less than the divisor, mod can be
3955b0055a4SMatthias Springer       // eliminated altogether.
3965b0055a4SMatthias Springer       if (ub && *ub < divisor)
3975b0055a4SMatthias Springer         memo[pos] = dimExpr;
3985b0055a4SMatthias Springer       else
3995b0055a4SMatthias Springer         memo[pos] = dimExpr % divisor;
4005b0055a4SMatthias Springer       // If a unique quotient `var_q` was seen, it can be expressed as
4015b0055a4SMatthias Springer       // `var_n floordiv divisor`.
4025b0055a4SMatthias Springer       if (quotientCount == 1 && !memo[quotientPosition])
4035b0055a4SMatthias Springer         memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
4045b0055a4SMatthias Springer 
4055b0055a4SMatthias Springer       return true;
4065b0055a4SMatthias Springer     }
4075b0055a4SMatthias Springer   }
4085b0055a4SMatthias Springer   return false;
4095b0055a4SMatthias Springer }
4105b0055a4SMatthias Springer 
4115b0055a4SMatthias Springer /// Check if the pos^th variable can be expressed as a floordiv of an affine
4125b0055a4SMatthias Springer /// function of other variables (where the divisor is a positive constant)
4135b0055a4SMatthias Springer /// given the initial set of expressions in `exprs`. If it can be, the
4145b0055a4SMatthias Springer /// corresponding position in `exprs` is set as the detected affine expr. For
4155b0055a4SMatthias Springer /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
4165b0055a4SMatthias Springer /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
4175b0055a4SMatthias Springer /// <= i <= 32q + 31 => q = i floordiv 32.
4185b0055a4SMatthias Springer static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos,
4195b0055a4SMatthias Springer                              MLIRContext *context,
4205b0055a4SMatthias Springer                              SmallVectorImpl<AffineExpr> &exprs) {
4215b0055a4SMatthias Springer   assert(pos < cst.getNumVars() && "invalid position");
4225b0055a4SMatthias Springer 
4235b0055a4SMatthias Springer   // Get upper-lower bound pair for this variable.
4245b0055a4SMatthias Springer   SmallVector<bool, 8> foundRepr(cst.getNumVars(), false);
4255b0055a4SMatthias Springer   for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
4265b0055a4SMatthias Springer     if (exprs[i])
4275b0055a4SMatthias Springer       foundRepr[i] = true;
4285b0055a4SMatthias Springer 
4295b0055a4SMatthias Springer   SmallVector<int64_t, 8> dividend(cst.getNumCols());
4305b0055a4SMatthias Springer   unsigned divisor;
4315b0055a4SMatthias Springer   auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
4325b0055a4SMatthias Springer 
4335b0055a4SMatthias Springer   // No upper-lower bound pair found for this var.
4345b0055a4SMatthias Springer   if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
4355b0055a4SMatthias Springer     return false;
4365b0055a4SMatthias Springer 
4375b0055a4SMatthias Springer   // Construct the dividend expression.
4385b0055a4SMatthias Springer   auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
4395b0055a4SMatthias Springer   for (unsigned c = 0, f = cst.getNumVars(); c < f; c++)
4405b0055a4SMatthias Springer     if (dividend[c] != 0)
4415b0055a4SMatthias Springer       dividendExpr = dividendExpr + dividend[c] * exprs[c];
4425b0055a4SMatthias Springer 
4435b0055a4SMatthias Springer   // Successfully detected the floordiv.
4445b0055a4SMatthias Springer   exprs[pos] = dividendExpr.floorDiv(divisor);
4455b0055a4SMatthias Springer   return true;
4465b0055a4SMatthias Springer }
4475b0055a4SMatthias Springer 
4485b0055a4SMatthias Springer std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
4495b0055a4SMatthias Springer     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
4505b0055a4SMatthias Springer     ArrayRef<AffineExpr> localExprs, MLIRContext *context,
4515b0055a4SMatthias Springer     bool closedUB) const {
4525b0055a4SMatthias Springer   assert(pos + offset < getNumDimVars() && "invalid dim start pos");
4535b0055a4SMatthias Springer   assert(symStartPos >= (pos + offset) && "invalid sym start pos");
4545b0055a4SMatthias Springer   assert(getNumLocalVars() == localExprs.size() &&
4555b0055a4SMatthias Springer          "incorrect local exprs count");
4565b0055a4SMatthias Springer 
4575b0055a4SMatthias Springer   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
4585b0055a4SMatthias Springer   getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
4595b0055a4SMatthias Springer                                offset, num);
4605b0055a4SMatthias Springer 
4615b0055a4SMatthias Springer   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
4625b0055a4SMatthias Springer   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
4635b0055a4SMatthias Springer     b.clear();
4645b0055a4SMatthias Springer     for (unsigned i = 0, e = a.size(); i < e; ++i) {
4655b0055a4SMatthias Springer       if (i < offset || i >= offset + num)
4665b0055a4SMatthias Springer         b.push_back(a[i]);
4675b0055a4SMatthias Springer     }
4685b0055a4SMatthias Springer   };
4695b0055a4SMatthias Springer 
4705b0055a4SMatthias Springer   SmallVector<int64_t, 8> lb, ub;
4715b0055a4SMatthias Springer   SmallVector<AffineExpr, 4> lbExprs;
4725b0055a4SMatthias Springer   unsigned dimCount = symStartPos - num;
4735b0055a4SMatthias Springer   unsigned symCount = getNumDimAndSymbolVars() - symStartPos;
4745b0055a4SMatthias Springer   lbExprs.reserve(lbIndices.size() + eqIndices.size());
4755b0055a4SMatthias Springer   // Lower bound expressions.
4765b0055a4SMatthias Springer   for (auto idx : lbIndices) {
4775b0055a4SMatthias Springer     auto ineq = getInequality64(idx);
4785b0055a4SMatthias Springer     // Extract the lower bound (in terms of other coeff's + const), i.e., if
4795b0055a4SMatthias Springer     // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
4805b0055a4SMatthias Springer     // - 1.
4815b0055a4SMatthias Springer     addCoeffs(ineq, lb);
4825b0055a4SMatthias Springer     std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
4835b0055a4SMatthias Springer     auto expr =
4845b0055a4SMatthias Springer         getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
4855b0055a4SMatthias Springer     // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
4865b0055a4SMatthias Springer     int64_t divisor = std::abs(ineq[pos + offset]);
4875b0055a4SMatthias Springer     expr = (expr + divisor - 1).floorDiv(divisor);
4885b0055a4SMatthias Springer     lbExprs.push_back(expr);
4895b0055a4SMatthias Springer   }
4905b0055a4SMatthias Springer 
4915b0055a4SMatthias Springer   SmallVector<AffineExpr, 4> ubExprs;
4925b0055a4SMatthias Springer   ubExprs.reserve(ubIndices.size() + eqIndices.size());
4935b0055a4SMatthias Springer   // Upper bound expressions.
4945b0055a4SMatthias Springer   for (auto idx : ubIndices) {
4955b0055a4SMatthias Springer     auto ineq = getInequality64(idx);
4965b0055a4SMatthias Springer     // Extract the upper bound (in terms of other coeff's + const).
4975b0055a4SMatthias Springer     addCoeffs(ineq, ub);
4985b0055a4SMatthias Springer     auto expr =
4995b0055a4SMatthias Springer         getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
5005b0055a4SMatthias Springer     expr = expr.floorDiv(std::abs(ineq[pos + offset]));
5015b0055a4SMatthias Springer     int64_t ubAdjustment = closedUB ? 0 : 1;
5025b0055a4SMatthias Springer     ubExprs.push_back(expr + ubAdjustment);
5035b0055a4SMatthias Springer   }
5045b0055a4SMatthias Springer 
5055b0055a4SMatthias Springer   // Equalities. It's both a lower and a upper bound.
5065b0055a4SMatthias Springer   SmallVector<int64_t, 4> b;
5075b0055a4SMatthias Springer   for (auto idx : eqIndices) {
5085b0055a4SMatthias Springer     auto eq = getEquality64(idx);
5095b0055a4SMatthias Springer     addCoeffs(eq, b);
5105b0055a4SMatthias Springer     if (eq[pos + offset] > 0)
5115b0055a4SMatthias Springer       std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
5125b0055a4SMatthias Springer 
5135b0055a4SMatthias Springer     // Extract the upper bound (in terms of other coeff's + const).
5145b0055a4SMatthias Springer     auto expr =
5155b0055a4SMatthias Springer         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
5165b0055a4SMatthias Springer     expr = expr.floorDiv(std::abs(eq[pos + offset]));
5175b0055a4SMatthias Springer     // Upper bound is exclusive.
5185b0055a4SMatthias Springer     ubExprs.push_back(expr + 1);
5195b0055a4SMatthias Springer     // Lower bound.
5205b0055a4SMatthias Springer     expr =
5215b0055a4SMatthias Springer         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
5225b0055a4SMatthias Springer     expr = expr.ceilDiv(std::abs(eq[pos + offset]));
5235b0055a4SMatthias Springer     lbExprs.push_back(expr);
5245b0055a4SMatthias Springer   }
5255b0055a4SMatthias Springer 
5265b0055a4SMatthias Springer   auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
5275b0055a4SMatthias Springer   auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
5285b0055a4SMatthias Springer 
5295b0055a4SMatthias Springer   return {lbMap, ubMap};
5305b0055a4SMatthias Springer }
5315b0055a4SMatthias Springer 
/// Computes the lower and upper bounds of the `num` dimensional variables
/// starting at position `offset` as affine maps of the remaining variables
/// (dimensional and symbolic). The maps have `getNumDimVars() - num` dims and
/// `getNumSymbolVars()` symbols. Dimensional variables after the bounded
/// group are shifted down by `num` in the map's dim numbering. Local
/// variables are themselves explicitly computed as affine functions of other
/// variables in this process if needed.
///
/// Strategy:
///  1. Iterate to a fixed point, detecting variables that equal a constant, a
///     mod, a floordiv, or an expression solved from an equality.
///  2. For each bounded variable with such an expression, lb = expr and
///     ub = expr + (closedUB ? 0 : 1).
///  3. Otherwise, when there are no local variables, derive bounds from the
///     constraints of a redundancy-free copy of this system.
///  4. If that gives no map or a multi-result map, fall back to constant
///     bounds when they exist.
///
/// `lbMaps` and `ubMaps` must have at least `num` entries. Entry `i` receives
/// the bounds of variable `offset + i`. This function first normalizes the
/// constraints by GCD, which mutates this system.
void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
                                           MLIRContext *context,
                                           SmallVectorImpl<AffineMap> *lbMaps,
                                           SmallVectorImpl<AffineMap> *ubMaps,
                                           bool closedUB) {
  assert(offset + num <= getNumDimVars() && "invalid range");

  // Basic simplification.
  normalizeConstraintsByGCD();

  LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
                          << " variables\n");
  LLVM_DEBUG(dump());

  // Record computed/detected variables. memo[i] holds an explicit affine
  // expression (over the map dims/symbols) for variable i, or null if none is
  // known yet.
  SmallVector<AffineExpr, 8> memo(getNumVars());
  // Initialize dimensional and symbolic variables. The variables being
  // bounded, [offset, offset + num), stay null. Dims after them become map
  // dims shifted down by `num`.
  for (unsigned i = 0, e = getNumDimVars(); i < e; i++) {
    if (i < offset)
      memo[i] = getAffineDimExpr(i, context);
    else if (i >= offset + num)
      memo[i] = getAffineDimExpr(i - num, context);
  }
  for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++)
    memo[i] = getAffineSymbolExpr(i - getNumDimVars(), context);

  bool changed;
  do {
    changed = false;
    // Identify yet unknown variables as constants or mod's / floordiv's of
    // other variables if possible.
    for (unsigned pos = 0; pos < getNumVars(); pos++) {
      if (memo[pos])
        continue;

      auto lbConst = getConstantBound64(BoundType::LB, pos);
      auto ubConst = getConstantBound64(BoundType::UB, pos);
      if (lbConst.has_value() && ubConst.has_value()) {
        // Detect equality to a constant.
        if (*lbConst == *ubConst) {
          memo[pos] = getAffineConstantExpr(*lbConst, context);
          changed = true;
          continue;
        }

        // Detect a variable as modulo of another variable w.r.t a
        // constant.
        if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context,
                        memo)) {
          changed = true;
          continue;
        }
      }

      // Detect a variable as a floordiv of an affine function of other
      // variables (divisor is a positive constant).
      if (detectAsFloorDiv(*this, pos, context, memo)) {
        changed = true;
        continue;
      }

      // Detect a variable as an expression of other variables. Only the one
      // equality returned by findConstraintWithNonZeroAt is tried here.
      unsigned idx;
      if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
        continue;
      }

      // Build AffineExpr solving for variable 'pos' in terms of all others.
      auto expr = getAffineConstantExpr(0, context);
      unsigned j, e;
      for (j = 0, e = getNumVars(); j < e; ++j) {
        if (j == pos)
          continue;
        int64_t c = atEq64(idx, j);
        if (c == 0)
          continue;
        // If any of the involved IDs hasn't been found yet, we can't proceed.
        if (!memo[j])
          break;
        expr = expr + memo[j] * c;
      }
      if (j < e)
        // Can't construct expression as it depends on a yet uncomputed
        // variable.
        continue;

      // Add constant term to AffineExpr.
      expr = expr + atEq64(idx, getNumVars());
      // The equality reads vPos * var + expr == 0, so var == -expr / vPos.
      // Both branches divide by a positive divisor.
      int64_t vPos = atEq64(idx, pos);
      assert(vPos != 0 && "expected non-zero here");
      if (vPos > 0)
        expr = (-expr).floorDiv(vPos);
      else
        // vPos < 0.
        expr = expr.floorDiv(-vPos);
      // Successfully constructed expression.
      memo[pos] = expr;
      changed = true;
    }
    // This loop is guaranteed to reach a fixed point, since once a
    // variable's explicit form is computed (in memo[pos]), it's not updated
    // again.
  } while (changed);

  int64_t ubAdjustment = closedUB ? 0 : 1;

  // Variables computed above as affine expressions of the rest get bound maps
  // "detected expr" (lb) and "detected expr + ubAdjustment" (ub). For the
  // undetected ones, bounds come from the constraints (only when there are no
  // locals) and then from constant bounds. Entries are never reset to null
  // here: if locals exist and no constant bound is found, an entry keeps
  // whatever value the caller passed in.
  std::optional<FlatLinearConstraints> tmpClone;
  for (unsigned pos = 0; pos < num; pos++) {
    // Maps range over the remaining dims (all dims minus the `num` bounded
    // ones) and all symbols.
    unsigned numMapDims = getNumDimVars() - num;
    unsigned numMapSymbols = getNumSymbolVars();
    AffineExpr expr = memo[pos + offset];
    if (expr)
      expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);

    AffineMap &lbMap = (*lbMaps)[pos];
    AffineMap &ubMap = (*ubMaps)[pos];

    if (expr) {
      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment);
    } else {
      // TODO: Whenever there are local variables in the dependence
      // constraints, we'll conservatively over-approximate, since we don't
      // always explicitly compute them above (in the while loop).
      if (getNumLocalVars() == 0) {
        // Work on a copy so that we don't update this constraint system. The
        // copy is created lazily once and reused for later variables.
        if (!tmpClone) {
          tmpClone.emplace(FlatLinearConstraints(*this));
          // Removing redundant inequalities is necessary so that we don't get
          // redundant loop bounds.
          tmpClone->removeRedundantInequalities();
        }
        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
            pos, offset, num, getNumDimVars(), /*localExprs=*/{}, context,
            closedUB);
      }

      // If the above fails, we'll just use the constant lower bound and the
      // constant upper bound (if they exist) as the slice bounds.
      // TODO: being conservative for the moment in cases that
      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
      // fixed (b/126426796).
      if (!lbMap || lbMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice lb\n");
        auto lbConst = getConstantBound64(BoundType::LB, pos + offset);
        if (lbConst.has_value()) {
          lbMap = AffineMap::get(numMapDims, numMapSymbols,
                                 getAffineConstantExpr(*lbConst, context));
        }
      }
      if (!ubMap || ubMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice ub\n");
        auto ubConst = getConstantBound64(BoundType::UB, pos + offset);
        if (ubConst.has_value()) {
          ubMap = AffineMap::get(
              numMapDims, numMapSymbols,
              getAffineConstantExpr(*ubConst + ubAdjustment, context));
        }
      }
    }
    LLVM_DEBUG(llvm::dbgs()
               << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(lbMap.dump(););
    LLVM_DEBUG(llvm::dbgs()
               << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(ubMap.dump(););
  }
}
7105b0055a4SMatthias Springer 
7115b0055a4SMatthias Springer LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
71229a925abSBenjamin Maxwell     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
71329a925abSBenjamin Maxwell     bool addConservativeSemiAffineBounds) {
7145b0055a4SMatthias Springer   FlatLinearConstraints localCst;
71529a925abSBenjamin Maxwell   if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst,
71629a925abSBenjamin Maxwell                                      addConservativeSemiAffineBounds))) {
7175b0055a4SMatthias Springer     LLVM_DEBUG(llvm::dbgs()
7185b0055a4SMatthias Springer                << "composition unimplemented for semi-affine maps\n");
7195b0055a4SMatthias Springer     return failure();
7205b0055a4SMatthias Springer   }
7215b0055a4SMatthias Springer 
7225b0055a4SMatthias Springer   // Add localCst information.
7235b0055a4SMatthias Springer   if (localCst.getNumLocalVars() > 0) {
7245b0055a4SMatthias Springer     unsigned numLocalVars = getNumLocalVars();
7255b0055a4SMatthias Springer     // Insert local dims of localCst at the beginning.
7265b0055a4SMatthias Springer     insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars());
7275b0055a4SMatthias Springer     // Insert local dims of `this` at the end of localCst.
7285b0055a4SMatthias Springer     localCst.appendLocalVar(/*num=*/numLocalVars);
7295b0055a4SMatthias Springer     // Dimensions of localCst and this constraint set match. Append localCst to
7305b0055a4SMatthias Springer     // this constraint set.
7315b0055a4SMatthias Springer     append(localCst);
7325b0055a4SMatthias Springer   }
7335b0055a4SMatthias Springer 
7345b0055a4SMatthias Springer   return success();
7355b0055a4SMatthias Springer }
7365b0055a4SMatthias Springer 
73729a925abSBenjamin Maxwell LogicalResult FlatLinearConstraints::addBound(
73829a925abSBenjamin Maxwell     BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound,
73929a925abSBenjamin Maxwell     AddConservativeSemiAffineBounds addSemiAffineBounds) {
7405b0055a4SMatthias Springer   assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
7415b0055a4SMatthias Springer   assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
7425b0055a4SMatthias Springer   assert(pos < getNumDimAndSymbolVars() && "invalid position");
7435b0055a4SMatthias Springer   assert((type != BoundType::EQ || isClosedBound) &&
7445b0055a4SMatthias Springer          "EQ bound must be closed.");
7455b0055a4SMatthias Springer 
7465b0055a4SMatthias Springer   // Equality follows the logic of lower bound except that we add an equality
7475b0055a4SMatthias Springer   // instead of an inequality.
7485b0055a4SMatthias Springer   assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
7495b0055a4SMatthias Springer          "single result expected");
7505b0055a4SMatthias Springer   bool lower = type == BoundType::LB || type == BoundType::EQ;
7515b0055a4SMatthias Springer 
7525b0055a4SMatthias Springer   std::vector<SmallVector<int64_t, 8>> flatExprs;
75329a925abSBenjamin Maxwell   if (failed(flattenAlignedMapAndMergeLocals(
75429a925abSBenjamin Maxwell           boundMap, &flatExprs,
75529a925abSBenjamin Maxwell           addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes)))
7565b0055a4SMatthias Springer     return failure();
7575b0055a4SMatthias Springer   assert(flatExprs.size() == boundMap.getNumResults());
7585b0055a4SMatthias Springer 
7595b0055a4SMatthias Springer   // Add one (in)equality for each result.
7605b0055a4SMatthias Springer   for (const auto &flatExpr : flatExprs) {
7615b0055a4SMatthias Springer     SmallVector<int64_t> ineq(getNumCols(), 0);
7625b0055a4SMatthias Springer     // Dims and symbols.
7635b0055a4SMatthias Springer     for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
7645b0055a4SMatthias Springer       ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
7655b0055a4SMatthias Springer     }
7665b0055a4SMatthias Springer     // Invalid bound: pos appears in `boundMap`.
7675b0055a4SMatthias Springer     // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
7685b0055a4SMatthias Springer     // its callers to prevent invalid bounds from being added.
7695b0055a4SMatthias Springer     if (ineq[pos] != 0)
7705b0055a4SMatthias Springer       continue;
7715b0055a4SMatthias Springer     ineq[pos] = lower ? 1 : -1;
7725b0055a4SMatthias Springer     // Local columns of `ineq` are at the beginning.
7735b0055a4SMatthias Springer     unsigned j = getNumDimVars() + getNumSymbolVars();
7745b0055a4SMatthias Springer     unsigned end = flatExpr.size() - 1;
7755b0055a4SMatthias Springer     for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
7765b0055a4SMatthias Springer       ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
7775b0055a4SMatthias Springer     }
7785b0055a4SMatthias Springer     // Make the bound closed in if flatExpr is open. The inequality is always
7795b0055a4SMatthias Springer     // created in the upper bound form, so the adjustment is -1.
7805b0055a4SMatthias Springer     int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1;
7815b0055a4SMatthias Springer     // Constant term.
7825b0055a4SMatthias Springer     ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1]
7835b0055a4SMatthias Springer                                     : flatExpr[flatExpr.size() - 1]) +
7845b0055a4SMatthias Springer                              boundAdjustment;
7855b0055a4SMatthias Springer     type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
7865b0055a4SMatthias Springer   }
7875b0055a4SMatthias Springer 
7885b0055a4SMatthias Springer   return success();
7895b0055a4SMatthias Springer }
7905b0055a4SMatthias Springer 
79129a925abSBenjamin Maxwell LogicalResult FlatLinearConstraints::addBound(
79229a925abSBenjamin Maxwell     BoundType type, unsigned pos, AffineMap boundMap,
79329a925abSBenjamin Maxwell     AddConservativeSemiAffineBounds addSemiAffineBounds) {
79429a925abSBenjamin Maxwell   return addBound(type, pos, boundMap,
79529a925abSBenjamin Maxwell                   /*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds);
7965b0055a4SMatthias Springer }
7975b0055a4SMatthias Springer 
7985b0055a4SMatthias Springer /// Compute an explicit representation for local vars. For all systems coming
7995b0055a4SMatthias Springer /// from MLIR integer sets, maps, or expressions where local vars were
8005b0055a4SMatthias Springer /// introduced to model floordivs and mods, this always succeeds.
8015b0055a4SMatthias Springer LogicalResult
8025b0055a4SMatthias Springer FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
8035b0055a4SMatthias Springer                                         MLIRContext *context) const {
8045b0055a4SMatthias Springer   unsigned numDims = getNumDimVars();
8055b0055a4SMatthias Springer   unsigned numSyms = getNumSymbolVars();
8065b0055a4SMatthias Springer 
8075b0055a4SMatthias Springer   // Initialize dimensional and symbolic variables.
8085b0055a4SMatthias Springer   for (unsigned i = 0; i < numDims; i++)
8095b0055a4SMatthias Springer     memo[i] = getAffineDimExpr(i, context);
8105b0055a4SMatthias Springer   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
8115b0055a4SMatthias Springer     memo[i] = getAffineSymbolExpr(i - numDims, context);
8125b0055a4SMatthias Springer 
8135b0055a4SMatthias Springer   bool changed;
8145b0055a4SMatthias Springer   do {
8155b0055a4SMatthias Springer     // Each time `changed` is true at the end of this iteration, one or more
8165b0055a4SMatthias Springer     // local vars would have been detected as floordivs and set in memo; so the
8175b0055a4SMatthias Springer     // number of null entries in memo[...] strictly reduces; so this converges.
8185b0055a4SMatthias Springer     changed = false;
8195b0055a4SMatthias Springer     for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i)
8205b0055a4SMatthias Springer       if (!memo[numDims + numSyms + i] &&
8215b0055a4SMatthias Springer           detectAsFloorDiv(*this, /*pos=*/numDims + numSyms + i, context, memo))
8225b0055a4SMatthias Springer         changed = true;
8235b0055a4SMatthias Springer   } while (changed);
8245b0055a4SMatthias Springer 
8255b0055a4SMatthias Springer   ArrayRef<AffineExpr> localExprs =
8265b0055a4SMatthias Springer       ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
8275b0055a4SMatthias Springer   return success(
8285b0055a4SMatthias Springer       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
8295b0055a4SMatthias Springer }
8305b0055a4SMatthias Springer 
8315b0055a4SMatthias Springer IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const {
8325b0055a4SMatthias Springer   if (getNumConstraints() == 0)
8335b0055a4SMatthias Springer     // Return universal set (always true): 0 == 0.
8345b0055a4SMatthias Springer     return IntegerSet::get(getNumDimVars(), getNumSymbolVars(),
8355b0055a4SMatthias Springer                            getAffineConstantExpr(/*constant=*/0, context),
8365b0055a4SMatthias Springer                            /*eqFlags=*/true);
8375b0055a4SMatthias Springer 
8385b0055a4SMatthias Springer   // Construct local references.
8395b0055a4SMatthias Springer   SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
8405b0055a4SMatthias Springer 
8415b0055a4SMatthias Springer   if (failed(computeLocalVars(memo, context))) {
8425b0055a4SMatthias Springer     // Check if the local variables without an explicit representation have
8435b0055a4SMatthias Springer     // zero coefficients everywhere.
8445b0055a4SMatthias Springer     SmallVector<unsigned> noLocalRepVars;
8455b0055a4SMatthias Springer     unsigned numDimsSymbols = getNumDimAndSymbolVars();
8465b0055a4SMatthias Springer     for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) {
8475b0055a4SMatthias Springer       if (!memo[i] && !isColZero(/*pos=*/i))
8485b0055a4SMatthias Springer         noLocalRepVars.push_back(i - numDimsSymbols);
8495b0055a4SMatthias Springer     }
8505b0055a4SMatthias Springer     if (!noLocalRepVars.empty()) {
8515b0055a4SMatthias Springer       LLVM_DEBUG({
8525b0055a4SMatthias Springer         llvm::dbgs() << "local variables at position(s) ";
8535b0055a4SMatthias Springer         llvm::interleaveComma(noLocalRepVars, llvm::dbgs());
8545b0055a4SMatthias Springer         llvm::dbgs() << " do not have an explicit representation in:\n";
8555b0055a4SMatthias Springer         this->dump();
8565b0055a4SMatthias Springer       });
8575b0055a4SMatthias Springer       return IntegerSet();
8585b0055a4SMatthias Springer     }
8595b0055a4SMatthias Springer   }
8605b0055a4SMatthias Springer 
8615b0055a4SMatthias Springer   ArrayRef<AffineExpr> localExprs =
8625b0055a4SMatthias Springer       ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
8635b0055a4SMatthias Springer 
8645b0055a4SMatthias Springer   // Construct the IntegerSet from the equalities/inequalities.
8655b0055a4SMatthias Springer   unsigned numDims = getNumDimVars();
8665b0055a4SMatthias Springer   unsigned numSyms = getNumSymbolVars();
8675b0055a4SMatthias Springer 
8685b0055a4SMatthias Springer   SmallVector<bool, 16> eqFlags(getNumConstraints());
8695b0055a4SMatthias Springer   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
8705b0055a4SMatthias Springer   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
8715b0055a4SMatthias Springer 
8725b0055a4SMatthias Springer   SmallVector<AffineExpr, 8> exprs;
8735b0055a4SMatthias Springer   exprs.reserve(getNumConstraints());
8745b0055a4SMatthias Springer 
8755b0055a4SMatthias Springer   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
8765b0055a4SMatthias Springer     exprs.push_back(getAffineExprFromFlatForm(getEquality64(i), numDims,
8775b0055a4SMatthias Springer                                               numSyms, localExprs, context));
8785b0055a4SMatthias Springer   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
8795b0055a4SMatthias Springer     exprs.push_back(getAffineExprFromFlatForm(getInequality64(i), numDims,
8805b0055a4SMatthias Springer                                               numSyms, localExprs, context));
8815b0055a4SMatthias Springer   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
8825b0055a4SMatthias Springer }
8835b0055a4SMatthias Springer 
8845b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
8855b0055a4SMatthias Springer // FlatLinearValueConstraints
8865b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
8875b0055a4SMatthias Springer 
8885b0055a4SMatthias Springer // Construct from an IntegerSet.
8895b0055a4SMatthias Springer FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set,
8905b0055a4SMatthias Springer                                                        ValueRange operands)
8915b0055a4SMatthias Springer     : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(),
8925b0055a4SMatthias Springer                             set.getNumDims() + set.getNumSymbols() + 1,
8935b0055a4SMatthias Springer                             set.getNumDims(), set.getNumSymbols(),
8945b0055a4SMatthias Springer                             /*numLocals=*/0) {
895*a24c4687SAlexander Pivovarov   assert((operands.empty() || set.getNumInputs() == operands.size()) &&
896*a24c4687SAlexander Pivovarov          "operand count mismatch");
89724da7fa0SBharathi Ramana Joshi   // Set the values for the non-local variables.
89824da7fa0SBharathi Ramana Joshi   for (unsigned i = 0, e = operands.size(); i < e; ++i)
89924da7fa0SBharathi Ramana Joshi     setValue(i, operands[i]);
9005b0055a4SMatthias Springer 
9015b0055a4SMatthias Springer   // Flatten expressions and add them to the constraint system.
9025b0055a4SMatthias Springer   std::vector<SmallVector<int64_t, 8>> flatExprs;
9035b0055a4SMatthias Springer   FlatLinearConstraints localVarCst;
9045b0055a4SMatthias Springer   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
9055b0055a4SMatthias Springer     assert(false && "flattening unimplemented for semi-affine integer sets");
9065b0055a4SMatthias Springer     return;
9075b0055a4SMatthias Springer   }
9085b0055a4SMatthias Springer   assert(flatExprs.size() == set.getNumConstraints());
9095b0055a4SMatthias Springer   insertVar(VarKind::Local, getNumVarKind(VarKind::Local),
9105b0055a4SMatthias Springer             /*num=*/localVarCst.getNumLocalVars());
9115b0055a4SMatthias Springer 
9125b0055a4SMatthias Springer   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
9135b0055a4SMatthias Springer     const auto &flatExpr = flatExprs[i];
9145b0055a4SMatthias Springer     assert(flatExpr.size() == getNumCols());
9155b0055a4SMatthias Springer     if (set.getEqFlags()[i]) {
9165b0055a4SMatthias Springer       addEquality(flatExpr);
9175b0055a4SMatthias Springer     } else {
9185b0055a4SMatthias Springer       addInequality(flatExpr);
9195b0055a4SMatthias Springer     }
9205b0055a4SMatthias Springer   }
9215b0055a4SMatthias Springer   // Add the other constraints involving local vars from flattening.
9225b0055a4SMatthias Springer   append(localVarCst);
9235b0055a4SMatthias Springer }
9245b0055a4SMatthias Springer 
9255b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) {
9265b0055a4SMatthias Springer   unsigned pos = getNumDimVars();
9275b0055a4SMatthias Springer   return insertVar(VarKind::SetDim, pos, vals);
9285b0055a4SMatthias Springer }
9295b0055a4SMatthias Springer 
9305b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) {
9315b0055a4SMatthias Springer   unsigned pos = getNumSymbolVars();
9325b0055a4SMatthias Springer   return insertVar(VarKind::Symbol, pos, vals);
9335b0055a4SMatthias Springer }
9345b0055a4SMatthias Springer 
9355b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos,
9365b0055a4SMatthias Springer                                                   ValueRange vals) {
9375b0055a4SMatthias Springer   return insertVar(VarKind::SetDim, pos, vals);
9385b0055a4SMatthias Springer }
9395b0055a4SMatthias Springer 
9405b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos,
9415b0055a4SMatthias Springer                                                      ValueRange vals) {
9425b0055a4SMatthias Springer   return insertVar(VarKind::Symbol, pos, vals);
9435b0055a4SMatthias Springer }
9445b0055a4SMatthias Springer 
9455b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
9465b0055a4SMatthias Springer                                                unsigned num) {
9475b0055a4SMatthias Springer   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
9485b0055a4SMatthias Springer 
9495b0055a4SMatthias Springer   return absolutePos;
9505b0055a4SMatthias Springer }
9515b0055a4SMatthias Springer 
9525b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
9535b0055a4SMatthias Springer                                                ValueRange vals) {
9545b0055a4SMatthias Springer   assert(!vals.empty() && "expected ValueRange with Values.");
9555b0055a4SMatthias Springer   assert(kind != VarKind::Local &&
9565b0055a4SMatthias Springer          "values cannot be attached to local variables.");
9575b0055a4SMatthias Springer   unsigned num = vals.size();
9585b0055a4SMatthias Springer   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
9595b0055a4SMatthias Springer 
9601b2247d9SKazu Hirata   // If a Value is provided, insert it; otherwise use std::nullopt.
96124da7fa0SBharathi Ramana Joshi   for (unsigned i = 0, e = vals.size(); i < e; ++i)
96224da7fa0SBharathi Ramana Joshi     if (vals[i])
96324da7fa0SBharathi Ramana Joshi       setValue(absolutePos + i, vals[i]);
9645b0055a4SMatthias Springer 
9655b0055a4SMatthias Springer   return absolutePos;
9665b0055a4SMatthias Springer }
9675b0055a4SMatthias Springer 
9685b0055a4SMatthias Springer /// Checks if two constraint systems are in the same space, i.e., if they are
9695b0055a4SMatthias Springer /// associated with the same set of variables, appearing in the same order.
9705b0055a4SMatthias Springer static bool areVarsAligned(const FlatLinearValueConstraints &a,
9715b0055a4SMatthias Springer                            const FlatLinearValueConstraints &b) {
97224da7fa0SBharathi Ramana Joshi   if (a.getNumDomainVars() != b.getNumDomainVars() ||
97324da7fa0SBharathi Ramana Joshi       a.getNumRangeVars() != b.getNumRangeVars() ||
97424da7fa0SBharathi Ramana Joshi       a.getNumSymbolVars() != b.getNumSymbolVars())
97524da7fa0SBharathi Ramana Joshi     return false;
97624da7fa0SBharathi Ramana Joshi   SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(),
97724da7fa0SBharathi Ramana Joshi                                     bMaybeValues = b.getMaybeValues();
97824da7fa0SBharathi Ramana Joshi   return std::equal(aMaybeValues.begin(), aMaybeValues.end(),
97924da7fa0SBharathi Ramana Joshi                     bMaybeValues.begin(), bMaybeValues.end());
9805b0055a4SMatthias Springer }
9815b0055a4SMatthias Springer 
9825b0055a4SMatthias Springer /// Calls areVarsAligned to check if two constraint systems have the same set
9835b0055a4SMatthias Springer /// of variables in the same order.
9845b0055a4SMatthias Springer bool FlatLinearValueConstraints::areVarsAlignedWithOther(
9855b0055a4SMatthias Springer     const FlatLinearConstraints &other) {
9865b0055a4SMatthias Springer   return areVarsAligned(*this, other);
9875b0055a4SMatthias Springer }
9885b0055a4SMatthias Springer 
9895b0055a4SMatthias Springer /// Checks if the SSA values associated with `cst`'s variables in range
9905b0055a4SMatthias Springer /// [start, end) are unique.
9915b0055a4SMatthias Springer static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
9925b0055a4SMatthias Springer     const FlatLinearValueConstraints &cst, unsigned start, unsigned end) {
9935b0055a4SMatthias Springer 
9945b0055a4SMatthias Springer   assert(start <= cst.getNumDimAndSymbolVars() &&
9955b0055a4SMatthias Springer          "Start position out of bounds");
9965b0055a4SMatthias Springer   assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds");
9975b0055a4SMatthias Springer 
9985b0055a4SMatthias Springer   if (start >= end)
9995b0055a4SMatthias Springer     return true;
10005b0055a4SMatthias Springer 
10015b0055a4SMatthias Springer   SmallPtrSet<Value, 8> uniqueVars;
100224da7fa0SBharathi Ramana Joshi   SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues();
100324da7fa0SBharathi Ramana Joshi   ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start,
100424da7fa0SBharathi Ramana Joshi                                                 maybeValuesAll.data() + end};
100524da7fa0SBharathi Ramana Joshi 
100624da7fa0SBharathi Ramana Joshi   for (std::optional<Value> val : maybeValues)
10075b0055a4SMatthias Springer     if (val && !uniqueVars.insert(*val).second)
10085b0055a4SMatthias Springer       return false;
100924da7fa0SBharathi Ramana Joshi 
10105b0055a4SMatthias Springer   return true;
10115b0055a4SMatthias Springer }
10125b0055a4SMatthias Springer 
10135b0055a4SMatthias Springer /// Checks if the SSA values associated with `cst`'s variables are unique.
10145b0055a4SMatthias Springer static bool LLVM_ATTRIBUTE_UNUSED
10155b0055a4SMatthias Springer areVarsUnique(const FlatLinearValueConstraints &cst) {
10165b0055a4SMatthias Springer   return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars());
10175b0055a4SMatthias Springer }
10185b0055a4SMatthias Springer 
10195b0055a4SMatthias Springer /// Checks if the SSA values associated with `cst`'s variables of kind `kind`
10205b0055a4SMatthias Springer /// are unique.
10215b0055a4SMatthias Springer static bool LLVM_ATTRIBUTE_UNUSED
10225b0055a4SMatthias Springer areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
10235b0055a4SMatthias Springer 
10245b0055a4SMatthias Springer   if (kind == VarKind::SetDim)
10255b0055a4SMatthias Springer     return areVarsUnique(cst, 0, cst.getNumDimVars());
10265b0055a4SMatthias Springer   if (kind == VarKind::Symbol)
10275b0055a4SMatthias Springer     return areVarsUnique(cst, cst.getNumDimVars(),
10285b0055a4SMatthias Springer                          cst.getNumDimAndSymbolVars());
10295b0055a4SMatthias Springer   llvm_unreachable("Unexpected VarKind");
10305b0055a4SMatthias Springer }
10315b0055a4SMatthias Springer 
10325b0055a4SMatthias Springer /// Merge and align the variables of A and B starting at 'offset', so that
10335b0055a4SMatthias Springer /// both constraint systems get the union of the contained variables that is
10345b0055a4SMatthias Springer /// dimension-wise and symbol-wise unique; both constraint systems are updated
10355b0055a4SMatthias Springer /// so that they have the union of all variables, with A's original
10365b0055a4SMatthias Springer /// variables appearing first followed by any of B's variables that didn't
10375b0055a4SMatthias Springer /// appear in A. Local variables in B that have the same division
1038c79ffb02SUday Bondhugula /// representation as local variables in A are merged into one. We allow A
1039c79ffb02SUday Bondhugula /// and B to have non-unique values for their variables; in such cases, they are
1040c79ffb02SUday Bondhugula /// still aligned with the variables appearing first aligned with those
1041c79ffb02SUday Bondhugula /// appearing first in the other system from left to right.
10425b0055a4SMatthias Springer //  E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
10435b0055a4SMatthias Springer //        Output: both A, B have (%i, %j, %k) [%M, %N, %P]
10445b0055a4SMatthias Springer static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
10455b0055a4SMatthias Springer                               FlatLinearValueConstraints *b) {
10465b0055a4SMatthias Springer   assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());
10475b0055a4SMatthias Springer 
10485b0055a4SMatthias Springer   assert(llvm::all_of(
10495b0055a4SMatthias Springer       llvm::drop_begin(a->getMaybeValues(), offset),
10505b0055a4SMatthias Springer       [](const std::optional<Value> &var) { return var.has_value(); }));
10515b0055a4SMatthias Springer 
10525b0055a4SMatthias Springer   assert(llvm::all_of(
10535b0055a4SMatthias Springer       llvm::drop_begin(b->getMaybeValues(), offset),
10545b0055a4SMatthias Springer       [](const std::optional<Value> &var) { return var.has_value(); }));
10555b0055a4SMatthias Springer 
10565b0055a4SMatthias Springer   SmallVector<Value, 4> aDimValues;
10575b0055a4SMatthias Springer   a->getValues(offset, a->getNumDimVars(), &aDimValues);
10585b0055a4SMatthias Springer 
10595b0055a4SMatthias Springer   {
10605b0055a4SMatthias Springer     // Merge dims from A into B.
10615b0055a4SMatthias Springer     unsigned d = offset;
1062c79ffb02SUday Bondhugula     for (Value aDimValue : aDimValues) {
10635b0055a4SMatthias Springer       unsigned loc;
1064c79ffb02SUday Bondhugula       // Find from the position `d` since we'd like to also consider the
1065c79ffb02SUday Bondhugula       // possibility of multiple variables with the same `Value`. We align with
1066c79ffb02SUday Bondhugula       // the next appearing one.
1067c79ffb02SUday Bondhugula       if (b->findVar(aDimValue, &loc, d)) {
10685b0055a4SMatthias Springer         assert(loc >= offset && "A's dim appears in B's aligned range");
10695b0055a4SMatthias Springer         assert(loc < b->getNumDimVars() &&
10705b0055a4SMatthias Springer                "A's dim appears in B's non-dim position");
10715b0055a4SMatthias Springer         b->swapVar(d, loc);
10725b0055a4SMatthias Springer       } else {
10735b0055a4SMatthias Springer         b->insertDimVar(d, aDimValue);
10745b0055a4SMatthias Springer       }
10755b0055a4SMatthias Springer       d++;
10765b0055a4SMatthias Springer     }
10775b0055a4SMatthias Springer     // Dimensions that are in B, but not in A, are added at the end.
10785b0055a4SMatthias Springer     for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) {
10795b0055a4SMatthias Springer       a->appendDimVar(b->getValue(t));
10805b0055a4SMatthias Springer     }
10815b0055a4SMatthias Springer     assert(a->getNumDimVars() == b->getNumDimVars() &&
10825b0055a4SMatthias Springer            "expected same number of dims");
10835b0055a4SMatthias Springer   }
10845b0055a4SMatthias Springer 
10855b0055a4SMatthias Springer   // Merge and align symbols of A and B
10865b0055a4SMatthias Springer   a->mergeSymbolVars(*b);
10875b0055a4SMatthias Springer   // Merge and align locals of A and B
10885b0055a4SMatthias Springer   a->mergeLocalVars(*b);
10895b0055a4SMatthias Springer 
10905b0055a4SMatthias Springer   assert(areVarsAligned(*a, *b) && "IDs expected to be aligned");
10915b0055a4SMatthias Springer }
10925b0055a4SMatthias Springer 
10935b0055a4SMatthias Springer // Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'.
10945b0055a4SMatthias Springer void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
10955b0055a4SMatthias Springer     unsigned offset, FlatLinearValueConstraints *other) {
10965b0055a4SMatthias Springer   mergeAndAlignVars(offset, this, other);
10975b0055a4SMatthias Springer }
10985b0055a4SMatthias Springer 
10995b0055a4SMatthias Springer /// Merge and align symbols of `this` and `other` such that both get union of
1100c79ffb02SUday Bondhugula /// of symbols. Existing symbols need not be unique; they will be aligned from
1101c79ffb02SUday Bondhugula /// left to right with duplicates aligned in the same order. Symbols with Value
1102c79ffb02SUday Bondhugula /// as `None` are considered to be inequal to all other symbols.
11035b0055a4SMatthias Springer void FlatLinearValueConstraints::mergeSymbolVars(
11045b0055a4SMatthias Springer     FlatLinearValueConstraints &other) {
11055b0055a4SMatthias Springer 
11065b0055a4SMatthias Springer   SmallVector<Value, 4> aSymValues;
11075b0055a4SMatthias Springer   getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);
11085b0055a4SMatthias Springer 
11095b0055a4SMatthias Springer   // Merge symbols: merge symbols into `other` first from `this`.
11105b0055a4SMatthias Springer   unsigned s = other.getNumDimVars();
11115b0055a4SMatthias Springer   for (Value aSymValue : aSymValues) {
11125b0055a4SMatthias Springer     unsigned loc;
11135b0055a4SMatthias Springer     // If the var is a symbol in `other`, then align it, otherwise assume that
1114c79ffb02SUday Bondhugula     // it is a new symbol. Search in `other` starting at position `s` since the
1115c79ffb02SUday Bondhugula     // left of it is aligned.
1116c79ffb02SUday Bondhugula     if (other.findVar(aSymValue, &loc, s) && loc >= other.getNumDimVars() &&
11175b0055a4SMatthias Springer         loc < other.getNumDimAndSymbolVars())
11185b0055a4SMatthias Springer       other.swapVar(s, loc);
11195b0055a4SMatthias Springer     else
11205b0055a4SMatthias Springer       other.insertSymbolVar(s - other.getNumDimVars(), aSymValue);
11215b0055a4SMatthias Springer     s++;
11225b0055a4SMatthias Springer   }
11235b0055a4SMatthias Springer 
11245b0055a4SMatthias Springer   // Symbols that are in other, but not in this, are added at the end.
11255b0055a4SMatthias Springer   for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
11265b0055a4SMatthias Springer                 e = other.getNumDimAndSymbolVars();
11275b0055a4SMatthias Springer        t < e; t++)
11285b0055a4SMatthias Springer     insertSymbolVar(getNumSymbolVars(), other.getValue(t));
11295b0055a4SMatthias Springer 
11305b0055a4SMatthias Springer   assert(getNumSymbolVars() == other.getNumSymbolVars() &&
11315b0055a4SMatthias Springer          "expected same number of symbols");
11325b0055a4SMatthias Springer }
11335b0055a4SMatthias Springer 
11345b0055a4SMatthias Springer void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
11355b0055a4SMatthias Springer                                                 unsigned varLimit) {
11365b0055a4SMatthias Springer   IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
11375b0055a4SMatthias Springer }
11385b0055a4SMatthias Springer 
11395b0055a4SMatthias Springer AffineMap
11405b0055a4SMatthias Springer FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
11415b0055a4SMatthias Springer                                               ValueRange operands) const {
11425b0055a4SMatthias Springer   assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
11435b0055a4SMatthias Springer 
11445b0055a4SMatthias Springer   SmallVector<Value> dims, syms;
11455b0055a4SMatthias Springer #ifndef NDEBUG
11465b0055a4SMatthias Springer   SmallVector<Value> newSyms;
11475b0055a4SMatthias Springer   SmallVector<Value> *newSymsPtr = &newSyms;
11485b0055a4SMatthias Springer #else
11495b0055a4SMatthias Springer   SmallVector<Value> *newSymsPtr = nullptr;
11505b0055a4SMatthias Springer #endif // NDEBUG
11515b0055a4SMatthias Springer 
11525b0055a4SMatthias Springer   dims.reserve(getNumDimVars());
11535b0055a4SMatthias Springer   syms.reserve(getNumSymbolVars());
115424da7fa0SBharathi Ramana Joshi   for (unsigned i = 0, e = getNumVarKind(VarKind::SetDim); i < e; ++i) {
115524da7fa0SBharathi Ramana Joshi     Identifier id = space.getId(VarKind::SetDim, i);
115624da7fa0SBharathi Ramana Joshi     dims.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
115724da7fa0SBharathi Ramana Joshi   }
115824da7fa0SBharathi Ramana Joshi   for (unsigned i = 0, e = getNumVarKind(VarKind::Symbol); i < e; ++i) {
115924da7fa0SBharathi Ramana Joshi     Identifier id = space.getId(VarKind::Symbol, i);
116024da7fa0SBharathi Ramana Joshi     syms.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
116124da7fa0SBharathi Ramana Joshi   }
11625b0055a4SMatthias Springer 
11635b0055a4SMatthias Springer   AffineMap alignedMap =
11645b0055a4SMatthias Springer       alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
11655b0055a4SMatthias Springer   // All symbols are already part of this FlatAffineValueConstraints.
11665b0055a4SMatthias Springer   assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
11675b0055a4SMatthias Springer   assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
11685b0055a4SMatthias Springer          "unexpected new/missing symbols");
11695b0055a4SMatthias Springer   return alignedMap;
11705b0055a4SMatthias Springer }
11715b0055a4SMatthias Springer 
1172c79ffb02SUday Bondhugula bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
1173c79ffb02SUday Bondhugula                                          unsigned offset) const {
117424da7fa0SBharathi Ramana Joshi   SmallVector<std::optional<Value>> maybeValues = getMaybeValues();
117524da7fa0SBharathi Ramana Joshi   for (unsigned i = offset, e = maybeValues.size(); i < e; ++i)
117624da7fa0SBharathi Ramana Joshi     if (maybeValues[i] && maybeValues[i].value() == val) {
11775b0055a4SMatthias Springer       *pos = i;
11785b0055a4SMatthias Springer       return true;
11795b0055a4SMatthias Springer     }
11805b0055a4SMatthias Springer   return false;
11815b0055a4SMatthias Springer }
11825b0055a4SMatthias Springer 
11835b0055a4SMatthias Springer bool FlatLinearValueConstraints::containsVar(Value val) const {
118424da7fa0SBharathi Ramana Joshi   unsigned pos;
118524da7fa0SBharathi Ramana Joshi   return findVar(val, &pos, 0);
11865b0055a4SMatthias Springer }
11875b0055a4SMatthias Springer 
11885b0055a4SMatthias Springer void FlatLinearValueConstraints::addBound(BoundType type, Value val,
11895b0055a4SMatthias Springer                                           int64_t value) {
11905b0055a4SMatthias Springer   unsigned pos;
11915b0055a4SMatthias Springer   if (!findVar(val, &pos))
11925b0055a4SMatthias Springer     // This is a pre-condition for this method.
11935b0055a4SMatthias Springer     assert(0 && "var not found");
11945b0055a4SMatthias Springer   addBound(type, pos, value);
11955b0055a4SMatthias Springer }
11965b0055a4SMatthias Springer 
11975b0055a4SMatthias Springer void FlatLinearConstraints::printSpace(raw_ostream &os) const {
11985b0055a4SMatthias Springer   IntegerPolyhedron::printSpace(os);
11995b0055a4SMatthias Springer   os << "(";
12005b0055a4SMatthias Springer   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++)
12015b0055a4SMatthias Springer     os << "None\t";
12025b0055a4SMatthias Springer   for (unsigned i = getVarKindOffset(VarKind::Local),
12035b0055a4SMatthias Springer                 e = getVarKindEnd(VarKind::Local);
12045b0055a4SMatthias Springer        i < e; ++i)
12055b0055a4SMatthias Springer     os << "Local\t";
12065b0055a4SMatthias Springer   os << "const)\n";
12075b0055a4SMatthias Springer }
12085b0055a4SMatthias Springer 
12095b0055a4SMatthias Springer void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
12105b0055a4SMatthias Springer   IntegerPolyhedron::printSpace(os);
12115b0055a4SMatthias Springer   os << "(";
12125b0055a4SMatthias Springer   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
12135b0055a4SMatthias Springer     if (hasValue(i))
12145b0055a4SMatthias Springer       os << "Value\t";
12155b0055a4SMatthias Springer     else
12165b0055a4SMatthias Springer       os << "None\t";
12175b0055a4SMatthias Springer   }
12185b0055a4SMatthias Springer   for (unsigned i = getVarKindOffset(VarKind::Local),
12195b0055a4SMatthias Springer                 e = getVarKindEnd(VarKind::Local);
12205b0055a4SMatthias Springer        i < e; ++i)
12215b0055a4SMatthias Springer     os << "Local\t";
12225b0055a4SMatthias Springer   os << "const)\n";
12235b0055a4SMatthias Springer }
12245b0055a4SMatthias Springer 
12255b0055a4SMatthias Springer void FlatLinearValueConstraints::projectOut(Value val) {
12265b0055a4SMatthias Springer   unsigned pos;
12275b0055a4SMatthias Springer   bool ret = findVar(val, &pos);
12285b0055a4SMatthias Springer   assert(ret);
12295b0055a4SMatthias Springer   (void)ret;
12305b0055a4SMatthias Springer   fourierMotzkinEliminate(pos);
12315b0055a4SMatthias Springer }
12325b0055a4SMatthias Springer 
12335b0055a4SMatthias Springer LogicalResult FlatLinearValueConstraints::unionBoundingBox(
12345b0055a4SMatthias Springer     const FlatLinearValueConstraints &otherCst) {
12355b0055a4SMatthias Springer   assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
123624da7fa0SBharathi Ramana Joshi   SmallVector<std::optional<Value>> maybeValues = getMaybeValues(),
123724da7fa0SBharathi Ramana Joshi                                     otherMaybeValues =
123824da7fa0SBharathi Ramana Joshi                                         otherCst.getMaybeValues();
123924da7fa0SBharathi Ramana Joshi   assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(),
124024da7fa0SBharathi Ramana Joshi                     otherMaybeValues.begin(),
124124da7fa0SBharathi Ramana Joshi                     otherMaybeValues.begin() + getNumDimVars()) &&
12425b0055a4SMatthias Springer          "dim values mismatch");
12435b0055a4SMatthias Springer   assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
12445b0055a4SMatthias Springer   assert(getNumLocalVars() == 0 && "local vars not supported yet here");
12455b0055a4SMatthias Springer 
12465b0055a4SMatthias Springer   // Align `other` to this.
12475b0055a4SMatthias Springer   if (!areVarsAligned(*this, otherCst)) {
12485b0055a4SMatthias Springer     FlatLinearValueConstraints otherCopy(otherCst);
12495b0055a4SMatthias Springer     mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
1250f819302aSRamkumar Ramachandra     return IntegerPolyhedron::unionBoundingBox(otherCopy);
12515b0055a4SMatthias Springer   }
12525b0055a4SMatthias Springer 
1253f819302aSRamkumar Ramachandra   return IntegerPolyhedron::unionBoundingBox(otherCst);
12545b0055a4SMatthias Springer }
12555b0055a4SMatthias Springer 
12565b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
12575b0055a4SMatthias Springer // Helper functions
12585b0055a4SMatthias Springer //===----------------------------------------------------------------------===//
12595b0055a4SMatthias Springer 
12605b0055a4SMatthias Springer AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
12615b0055a4SMatthias Springer                                          ValueRange dims, ValueRange syms,
12625b0055a4SMatthias Springer                                          SmallVector<Value> *newSyms) {
12635b0055a4SMatthias Springer   assert(operands.size() == map.getNumInputs() &&
12645b0055a4SMatthias Springer          "expected same number of operands and map inputs");
12655b0055a4SMatthias Springer   MLIRContext *ctx = map.getContext();
12665b0055a4SMatthias Springer   Builder builder(ctx);
12675b0055a4SMatthias Springer   SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
12685b0055a4SMatthias Springer   unsigned numSymbols = syms.size();
12695b0055a4SMatthias Springer   SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
12705b0055a4SMatthias Springer   if (newSyms) {
12715b0055a4SMatthias Springer     newSyms->clear();
12725b0055a4SMatthias Springer     newSyms->append(syms.begin(), syms.end());
12735b0055a4SMatthias Springer   }
12745b0055a4SMatthias Springer 
12755b0055a4SMatthias Springer   for (const auto &operand : llvm::enumerate(operands)) {
12765b0055a4SMatthias Springer     // Compute replacement dim/sym of operand.
12775b0055a4SMatthias Springer     AffineExpr replacement;
12783f0bddb5SKazu Hirata     auto dimIt = llvm::find(dims, operand.value());
12793f0bddb5SKazu Hirata     auto symIt = llvm::find(syms, operand.value());
12805b0055a4SMatthias Springer     if (dimIt != dims.end()) {
12815b0055a4SMatthias Springer       replacement =
12825b0055a4SMatthias Springer           builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
12835b0055a4SMatthias Springer     } else if (symIt != syms.end()) {
12845b0055a4SMatthias Springer       replacement =
12855b0055a4SMatthias Springer           builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
12865b0055a4SMatthias Springer     } else {
12875b0055a4SMatthias Springer       // This operand is neither a dimension nor a symbol. Add it as a new
12885b0055a4SMatthias Springer       // symbol.
12895b0055a4SMatthias Springer       replacement = builder.getAffineSymbolExpr(numSymbols++);
12905b0055a4SMatthias Springer       if (newSyms)
12915b0055a4SMatthias Springer         newSyms->push_back(operand.value());
12925b0055a4SMatthias Springer     }
12935b0055a4SMatthias Springer     // Add to corresponding replacements vector.
12945b0055a4SMatthias Springer     if (operand.index() < map.getNumDims()) {
12955b0055a4SMatthias Springer       dimReplacements[operand.index()] = replacement;
12965b0055a4SMatthias Springer     } else {
12975b0055a4SMatthias Springer       symReplacements[operand.index() - map.getNumDims()] = replacement;
12985b0055a4SMatthias Springer     }
12995b0055a4SMatthias Springer   }
13005b0055a4SMatthias Springer 
13015b0055a4SMatthias Springer   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
13025b0055a4SMatthias Springer                                    dims.size(), numSymbols);
13035b0055a4SMatthias Springer }
13045b0055a4SMatthias Springer 
13055b0055a4SMatthias Springer LogicalResult
13065b0055a4SMatthias Springer mlir::getMultiAffineFunctionFromMap(AffineMap map,
13075b0055a4SMatthias Springer                                     MultiAffineFunction &multiAff) {
13085b0055a4SMatthias Springer   FlatLinearConstraints cst;
13095b0055a4SMatthias Springer   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
13105b0055a4SMatthias Springer   LogicalResult result = getFlattenedAffineExprs(map, &flattenedExprs, &cst);
13115b0055a4SMatthias Springer 
13125b0055a4SMatthias Springer   if (result.failed())
13135b0055a4SMatthias Springer     return failure();
13145b0055a4SMatthias Springer 
13155b0055a4SMatthias Springer   DivisionRepr divs = cst.getLocalReprs();
13165b0055a4SMatthias Springer   assert(divs.hasAllReprs() &&
13175b0055a4SMatthias Springer          "AffineMap cannot produce divs without local representation");
13185b0055a4SMatthias Springer 
13195b0055a4SMatthias Springer   // TODO: We shouldn't have to do this conversion.
13201a0e67d7SRamkumar Ramachandra   Matrix<DynamicAPInt> mat(map.getNumResults(),
13211a0e67d7SRamkumar Ramachandra                            map.getNumInputs() + divs.getNumDivs() + 1);
13225b0055a4SMatthias Springer   for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
13235b0055a4SMatthias Springer     for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
13245b0055a4SMatthias Springer       mat(i, j) = flattenedExprs[i][j];
13255b0055a4SMatthias Springer 
13265b0055a4SMatthias Springer   multiAff = MultiAffineFunction(
13275b0055a4SMatthias Springer       PresburgerSpace::getRelationSpace(map.getNumDims(), map.getNumResults(),
13285b0055a4SMatthias Springer                                         map.getNumSymbols(), divs.getNumDivs()),
13295b0055a4SMatthias Springer       mat, divs);
13305b0055a4SMatthias Springer 
13315b0055a4SMatthias Springer   return success();
13325b0055a4SMatthias Springer }
1333