15b0055a4SMatthias Springer //===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===// 25b0055a4SMatthias Springer // 35b0055a4SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45b0055a4SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 55b0055a4SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65b0055a4SMatthias Springer // 75b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 85b0055a4SMatthias Springer 95b0055a4SMatthias Springer #include "mlir/Analysis/FlatLinearValueConstraints.h" 105b0055a4SMatthias Springer 115b0055a4SMatthias Springer #include "mlir/Analysis/Presburger/LinearTransform.h" 1224da7fa0SBharathi Ramana Joshi #include "mlir/Analysis/Presburger/PresburgerSpace.h" 135b0055a4SMatthias Springer #include "mlir/Analysis/Presburger/Simplex.h" 145b0055a4SMatthias Springer #include "mlir/Analysis/Presburger/Utils.h" 155b0055a4SMatthias Springer #include "mlir/IR/AffineExprVisitor.h" 165b0055a4SMatthias Springer #include "mlir/IR/Builders.h" 175b0055a4SMatthias Springer #include "mlir/IR/IntegerSet.h" 185b0055a4SMatthias Springer #include "mlir/Support/LLVM.h" 195b0055a4SMatthias Springer #include "llvm/ADT/STLExtras.h" 205b0055a4SMatthias Springer #include "llvm/ADT/SmallPtrSet.h" 215b0055a4SMatthias Springer #include "llvm/ADT/SmallVector.h" 225b0055a4SMatthias Springer #include "llvm/Support/Debug.h" 235b0055a4SMatthias Springer #include "llvm/Support/raw_ostream.h" 245b0055a4SMatthias Springer #include <optional> 255b0055a4SMatthias Springer 265b0055a4SMatthias Springer #define DEBUG_TYPE "flat-value-constraints" 275b0055a4SMatthias Springer 285b0055a4SMatthias Springer using namespace mlir; 295b0055a4SMatthias Springer using namespace presburger; 305b0055a4SMatthias Springer 315b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 325b0055a4SMatthias Springer // AffineExprFlattener 335b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 345b0055a4SMatthias Springer 355b0055a4SMatthias Springer namespace { 365b0055a4SMatthias Springer 375b0055a4SMatthias Springer // See comments for SimpleAffineExprFlattener. 3829a925abSBenjamin Maxwell // An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by 3929a925abSBenjamin Maxwell // recording constraint information associated with mod's, floordiv's, and 4029a925abSBenjamin Maxwell // ceildiv's in FlatLinearConstraints 'localVarCst'. 415b0055a4SMatthias Springer struct AffineExprFlattener : public SimpleAffineExprFlattener { 4229a925abSBenjamin Maxwell using SimpleAffineExprFlattener::SimpleAffineExprFlattener; 4329a925abSBenjamin Maxwell 445b0055a4SMatthias Springer // Constraints connecting newly introduced local variables (for mod's and 455b0055a4SMatthias Springer // div's) to existing (dimensional and symbolic) ones. These are always 465b0055a4SMatthias Springer // inequalities. 475b0055a4SMatthias Springer IntegerPolyhedron localVarCst; 485b0055a4SMatthias Springer 495b0055a4SMatthias Springer AffineExprFlattener(unsigned nDims, unsigned nSymbols) 505b0055a4SMatthias Springer : SimpleAffineExprFlattener(nDims, nSymbols), 5129a925abSBenjamin Maxwell localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}; 525b0055a4SMatthias Springer 535b0055a4SMatthias Springer private: 545b0055a4SMatthias Springer // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr). 555b0055a4SMatthias Springer // The local variable added is always a floordiv of a pure add/mul affine 565b0055a4SMatthias Springer // function of other variables, coefficients of which are specified in 575b0055a4SMatthias Springer // `dividend' and with respect to the positive constant `divisor'. localExpr 585b0055a4SMatthias Springer // is the simplified tree expression (AffineExpr) corresponding to the 595b0055a4SMatthias Springer // quantifier. 605b0055a4SMatthias Springer void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, 615b0055a4SMatthias Springer AffineExpr localExpr) override { 625b0055a4SMatthias Springer SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); 635b0055a4SMatthias Springer // Update localVarCst. 645b0055a4SMatthias Springer localVarCst.addLocalFloorDiv(dividend, divisor); 655b0055a4SMatthias Springer } 6629a925abSBenjamin Maxwell 6729a925abSBenjamin Maxwell LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, 6829a925abSBenjamin Maxwell ArrayRef<int64_t> rhs, 6929a925abSBenjamin Maxwell AffineExpr localExpr) override { 7029a925abSBenjamin Maxwell // AffineExprFlattener does not support semi-affine expressions. 7129a925abSBenjamin Maxwell return failure(); 7229a925abSBenjamin Maxwell } 7329a925abSBenjamin Maxwell }; 7429a925abSBenjamin Maxwell 7529a925abSBenjamin Maxwell // A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds 7629a925abSBenjamin Maxwell // conservative bounds for semi-affine expressions (given assumptions hold). If 7729a925abSBenjamin Maxwell // the assumptions required to add the semi-affine bounds are found not to hold 7829a925abSBenjamin Maxwell // the final constraints set will be empty/inconsistent. If the assumptions are 7929a925abSBenjamin Maxwell // never contradicted the final bounds still only will be correct if the 8029a925abSBenjamin Maxwell // assumptions hold. 8129a925abSBenjamin Maxwell struct SemiAffineExprFlattener : public AffineExprFlattener { 8229a925abSBenjamin Maxwell using AffineExprFlattener::AffineExprFlattener; 8329a925abSBenjamin Maxwell 8429a925abSBenjamin Maxwell LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, 8529a925abSBenjamin Maxwell ArrayRef<int64_t> rhs, 8629a925abSBenjamin Maxwell AffineExpr localExpr) override { 8729a925abSBenjamin Maxwell auto result = 8829a925abSBenjamin Maxwell SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr); 8929a925abSBenjamin Maxwell assert(succeeded(result) && 9029a925abSBenjamin Maxwell "unexpected failure in SimpleAffineExprFlattener"); 9129a925abSBenjamin Maxwell (void)result; 9229a925abSBenjamin Maxwell 9329a925abSBenjamin Maxwell if (localExpr.getKind() == AffineExprKind::Mod) { 9429a925abSBenjamin Maxwell // Given two numbers a and b, division is defined as: 9529a925abSBenjamin Maxwell // 9629a925abSBenjamin Maxwell // a = bq + r 9729a925abSBenjamin Maxwell // 0 <= r < |b| (where |x| is the absolute value of x) 9829a925abSBenjamin Maxwell // 9929a925abSBenjamin Maxwell // q = a floordiv b 10029a925abSBenjamin Maxwell // r = a mod b 10129a925abSBenjamin Maxwell 10229a925abSBenjamin Maxwell // Add a new local variable (r) to represent the mod. 10329a925abSBenjamin Maxwell unsigned rPos = localVarCst.appendVar(VarKind::Local); 10429a925abSBenjamin Maxwell 10529a925abSBenjamin Maxwell // r >= 0 (Can ALWAYS be added) 10629a925abSBenjamin Maxwell localVarCst.addBound(BoundType::LB, rPos, 0); 10729a925abSBenjamin Maxwell 10829a925abSBenjamin Maxwell // r < b (Can be added if b > 0, which we assume here) 10929a925abSBenjamin Maxwell ArrayRef<int64_t> b = rhs; 11029a925abSBenjamin Maxwell SmallVector<int64_t> bSubR(b); 11129a925abSBenjamin Maxwell bSubR.insert(bSubR.begin() + rPos, -1); 11229a925abSBenjamin Maxwell // Note: bSubR = b - r 11329a925abSBenjamin Maxwell // So this adds the bound b - r >= 1 (equivalent to r < b) 11429a925abSBenjamin Maxwell localVarCst.addBound(BoundType::LB, bSubR, 1); 11529a925abSBenjamin Maxwell 11629a925abSBenjamin Maxwell // Note: The assumption of b > 0 is based on the affine expression docs, 11729a925abSBenjamin Maxwell // which state "RHS of mod is always a constant or a symbolic expression 11829a925abSBenjamin Maxwell // with a positive value." (see AffineExprKind in AffineExpr.h). If this 11929a925abSBenjamin Maxwell // assumption does not hold constraints (added above) are a contradiction. 12029a925abSBenjamin Maxwell 12129a925abSBenjamin Maxwell return success(); 12229a925abSBenjamin Maxwell } 12329a925abSBenjamin Maxwell 12429a925abSBenjamin Maxwell // TODO: Support other semi-affine expressions. 12529a925abSBenjamin Maxwell return failure(); 12629a925abSBenjamin Maxwell } 1275b0055a4SMatthias Springer }; 1285b0055a4SMatthias Springer 1295b0055a4SMatthias Springer } // namespace 1305b0055a4SMatthias Springer 1315b0055a4SMatthias Springer // Flattens the expressions in map. Returns failure if 'expr' was unable to be 132dc4786b4Slong.chen // flattened. For example two specific cases: 13329a925abSBenjamin Maxwell // 1. an unhandled semi-affine expressions is found. 134dc4786b4Slong.chen // 2. has poison expression (i.e., division by zero). 1355b0055a4SMatthias Springer static LogicalResult 1365b0055a4SMatthias Springer getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, 1375b0055a4SMatthias Springer unsigned numSymbols, 1385b0055a4SMatthias Springer std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 13929a925abSBenjamin Maxwell FlatLinearConstraints *localVarCst, 14029a925abSBenjamin Maxwell bool addConservativeSemiAffineBounds = false) { 1415b0055a4SMatthias Springer if (exprs.empty()) { 1425b0055a4SMatthias Springer if (localVarCst) 1435b0055a4SMatthias Springer *localVarCst = FlatLinearConstraints(numDims, numSymbols); 1445b0055a4SMatthias Springer return success(); 1455b0055a4SMatthias Springer } 1465b0055a4SMatthias Springer 14729a925abSBenjamin Maxwell auto flattenExprs = [&](AffineExprFlattener &flattener) -> LogicalResult { 1485b0055a4SMatthias Springer // Use the same flattener to simplify each expression successively. This way 1495b0055a4SMatthias Springer // local variables / expressions are shared. 1505b0055a4SMatthias Springer for (auto expr : exprs) { 151dc4786b4Slong.chen auto flattenResult = flattener.walkPostOrder(expr); 152dc4786b4Slong.chen if (failed(flattenResult)) 153dc4786b4Slong.chen return failure(); 1545b0055a4SMatthias Springer } 1555b0055a4SMatthias Springer 1565b0055a4SMatthias Springer assert(flattener.operandExprStack.size() == exprs.size()); 1575b0055a4SMatthias Springer flattenedExprs->clear(); 1585b0055a4SMatthias Springer flattenedExprs->assign(flattener.operandExprStack.begin(), 1595b0055a4SMatthias Springer flattener.operandExprStack.end()); 1605b0055a4SMatthias Springer 1615b0055a4SMatthias Springer if (localVarCst) 1625b0055a4SMatthias Springer localVarCst->clearAndCopyFrom(flattener.localVarCst); 1635b0055a4SMatthias Springer 1645b0055a4SMatthias Springer return success(); 16529a925abSBenjamin Maxwell }; 16629a925abSBenjamin Maxwell 16729a925abSBenjamin Maxwell if (addConservativeSemiAffineBounds) { 16829a925abSBenjamin Maxwell SemiAffineExprFlattener flattener(numDims, numSymbols); 16929a925abSBenjamin Maxwell return flattenExprs(flattener); 17029a925abSBenjamin Maxwell } 17129a925abSBenjamin Maxwell 17229a925abSBenjamin Maxwell AffineExprFlattener flattener(numDims, numSymbols); 17329a925abSBenjamin Maxwell return flattenExprs(flattener); 1745b0055a4SMatthias Springer } 1755b0055a4SMatthias Springer 1765b0055a4SMatthias Springer // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to 17729a925abSBenjamin Maxwell // be flattened (an unhandled semi-affine was found). 17829a925abSBenjamin Maxwell LogicalResult mlir::getFlattenedAffineExpr( 17929a925abSBenjamin Maxwell AffineExpr expr, unsigned numDims, unsigned numSymbols, 18029a925abSBenjamin Maxwell SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst, 18129a925abSBenjamin Maxwell bool addConservativeSemiAffineBounds) { 1825b0055a4SMatthias Springer std::vector<SmallVector<int64_t, 8>> flattenedExprs; 18329a925abSBenjamin Maxwell LogicalResult ret = 18429a925abSBenjamin Maxwell ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs, 18529a925abSBenjamin Maxwell localVarCst, addConservativeSemiAffineBounds); 1865b0055a4SMatthias Springer *flattenedExpr = flattenedExprs[0]; 1875b0055a4SMatthias Springer return ret; 1885b0055a4SMatthias Springer } 1895b0055a4SMatthias Springer 1905b0055a4SMatthias Springer /// Flattens the expressions in map. Returns failure if 'expr' was unable to be 19129a925abSBenjamin Maxwell /// flattened (i.e., an unhandled semi-affine was found). 1925b0055a4SMatthias Springer LogicalResult mlir::getFlattenedAffineExprs( 1935b0055a4SMatthias Springer AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 19429a925abSBenjamin Maxwell FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) { 1955b0055a4SMatthias Springer if (map.getNumResults() == 0) { 1965b0055a4SMatthias Springer if (localVarCst) 1975b0055a4SMatthias Springer *localVarCst = 1985b0055a4SMatthias Springer FlatLinearConstraints(map.getNumDims(), map.getNumSymbols()); 1995b0055a4SMatthias Springer return success(); 2005b0055a4SMatthias Springer } 20129a925abSBenjamin Maxwell return ::getFlattenedAffineExprs( 20229a925abSBenjamin Maxwell map.getResults(), map.getNumDims(), map.getNumSymbols(), flattenedExprs, 20329a925abSBenjamin Maxwell localVarCst, addConservativeSemiAffineBounds); 2045b0055a4SMatthias Springer } 2055b0055a4SMatthias Springer 2065b0055a4SMatthias Springer LogicalResult mlir::getFlattenedAffineExprs( 2075b0055a4SMatthias Springer IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 2085b0055a4SMatthias Springer FlatLinearConstraints *localVarCst) { 2095b0055a4SMatthias Springer if (set.getNumConstraints() == 0) { 2105b0055a4SMatthias Springer if (localVarCst) 2115b0055a4SMatthias Springer *localVarCst = 2125b0055a4SMatthias Springer FlatLinearConstraints(set.getNumDims(), set.getNumSymbols()); 2135b0055a4SMatthias Springer return success(); 2145b0055a4SMatthias Springer } 2155b0055a4SMatthias Springer return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), 2165b0055a4SMatthias Springer set.getNumSymbols(), flattenedExprs, 2175b0055a4SMatthias Springer localVarCst); 2185b0055a4SMatthias Springer } 2195b0055a4SMatthias Springer 2205b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 2215b0055a4SMatthias Springer // FlatLinearConstraints 2225b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 2235b0055a4SMatthias Springer 2245b0055a4SMatthias Springer // Similar to `composeMap` except that no Values need be associated with the 2255b0055a4SMatthias Springer // constraint system nor are they looked at -- the dimensions and symbols of 2265b0055a4SMatthias Springer // `other` are expected to correspond 1:1 to `this` system. 2275b0055a4SMatthias Springer LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) { 2285b0055a4SMatthias Springer assert(other.getNumDims() == getNumDimVars() && "dim mismatch"); 2295b0055a4SMatthias Springer assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"); 2305b0055a4SMatthias Springer 2315b0055a4SMatthias Springer std::vector<SmallVector<int64_t, 8>> flatExprs; 2325b0055a4SMatthias Springer if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs))) 2335b0055a4SMatthias Springer return failure(); 2345b0055a4SMatthias Springer assert(flatExprs.size() == other.getNumResults()); 2355b0055a4SMatthias Springer 2365b0055a4SMatthias Springer // Add dimensions corresponding to the map's results. 2375b0055a4SMatthias Springer insertDimVar(/*pos=*/0, /*num=*/other.getNumResults()); 2385b0055a4SMatthias Springer 2395b0055a4SMatthias Springer // We add one equality for each result connecting the result dim of the map to 2405b0055a4SMatthias Springer // the other variables. 2415b0055a4SMatthias Springer // E.g.: if the expression is 16*i0 + i1, and this is the r^th 2425b0055a4SMatthias Springer // iteration/result of the value map, we are adding the equality: 2435b0055a4SMatthias Springer // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we 2445b0055a4SMatthias Springer // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. 2455b0055a4SMatthias Springer for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { 2465b0055a4SMatthias Springer const auto &flatExpr = flatExprs[r]; 2475b0055a4SMatthias Springer assert(flatExpr.size() >= other.getNumInputs() + 1); 2485b0055a4SMatthias Springer 2495b0055a4SMatthias Springer SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); 2505b0055a4SMatthias Springer // Set the coefficient for this result to one. 2515b0055a4SMatthias Springer eqToAdd[r] = 1; 2525b0055a4SMatthias Springer 2535b0055a4SMatthias Springer // Dims and symbols. 2545b0055a4SMatthias Springer for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) { 2555b0055a4SMatthias Springer // Negate `eq[r]` since the newly added dimension will be set to this one. 2565b0055a4SMatthias Springer eqToAdd[e + i] = -flatExpr[i]; 2575b0055a4SMatthias Springer } 2585b0055a4SMatthias Springer // Local columns of `eq` are at the beginning. 2595b0055a4SMatthias Springer unsigned j = getNumDimVars() + getNumSymbolVars(); 2605b0055a4SMatthias Springer unsigned end = flatExpr.size() - 1; 2615b0055a4SMatthias Springer for (unsigned i = other.getNumInputs(); i < end; i++, j++) { 2625b0055a4SMatthias Springer eqToAdd[j] = -flatExpr[i]; 2635b0055a4SMatthias Springer } 2645b0055a4SMatthias Springer 2655b0055a4SMatthias Springer // Constant term. 2665b0055a4SMatthias Springer eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; 2675b0055a4SMatthias Springer 2685b0055a4SMatthias Springer // Add the equality connecting the result of the map to this constraint set. 2695b0055a4SMatthias Springer addEquality(eqToAdd); 2705b0055a4SMatthias Springer } 2715b0055a4SMatthias Springer 2725b0055a4SMatthias Springer return success(); 2735b0055a4SMatthias Springer } 2745b0055a4SMatthias Springer 2755b0055a4SMatthias Springer // Determine whether the variable at 'pos' (say var_r) can be expressed as 2765b0055a4SMatthias Springer // modulo of another known variable (say var_n) w.r.t a constant. For example, 2775b0055a4SMatthias Springer // if the following constraints hold true: 2785b0055a4SMatthias Springer // ``` 2795b0055a4SMatthias Springer // 0 <= var_r <= divisor - 1 2805b0055a4SMatthias Springer // var_n - (divisor * q_expr) = var_r 2815b0055a4SMatthias Springer // ``` 2825b0055a4SMatthias Springer // where `var_n` is a known variable (called dividend), and `q_expr` is an 2835b0055a4SMatthias Springer // `AffineExpr` (called the quotient expression), `var_r` can be written as: 2845b0055a4SMatthias Springer // 2855b0055a4SMatthias Springer // `var_r = var_n mod divisor`. 2865b0055a4SMatthias Springer // 2875b0055a4SMatthias Springer // Additionally, in a special case of the above constaints where `q_expr` is an 2885b0055a4SMatthias Springer // variable itself that is not yet known (say `var_q`), it can be written as a 2895b0055a4SMatthias Springer // floordiv in the following way: 2905b0055a4SMatthias Springer // 2915b0055a4SMatthias Springer // `var_q = var_n floordiv divisor`. 2925b0055a4SMatthias Springer // 293c45c9625SVinayaka Bandishti // First 'num' dimensional variables starting at 'offset' are 294c45c9625SVinayaka Bandishti // derived/to-be-derived in terms of the remaining variables. The remaining 295c45c9625SVinayaka Bandishti // variables are assigned trivial affine expressions in `memo`. For example, 296c45c9625SVinayaka Bandishti // memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2: 297c45c9625SVinayaka Bandishti // memo ==> d0 d1 . . d2 ... 298c45c9625SVinayaka Bandishti // cst ==> c0 c1 c2 c3 c4 ... 299c45c9625SVinayaka Bandishti // 3005b0055a4SMatthias Springer // Returns true if the above mod or floordiv are detected, updating 'memo' with 3015b0055a4SMatthias Springer // these new expressions. Returns false otherwise. 3025b0055a4SMatthias Springer static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos, 303c45c9625SVinayaka Bandishti unsigned offset, unsigned num, int64_t lbConst, 304c45c9625SVinayaka Bandishti int64_t ubConst, MLIRContext *context, 305c45c9625SVinayaka Bandishti SmallVectorImpl<AffineExpr> &memo) { 3065b0055a4SMatthias Springer assert(pos < cst.getNumVars() && "invalid position"); 3075b0055a4SMatthias Springer 3085b0055a4SMatthias Springer // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can 3095b0055a4SMatthias Springer // be determined. 3105b0055a4SMatthias Springer if (lbConst != 0 || ubConst < 1) 3115b0055a4SMatthias Springer return false; 3125b0055a4SMatthias Springer int64_t divisor = ubConst + 1; 3135b0055a4SMatthias Springer 3145b0055a4SMatthias Springer // Check for the aforementioned conditions in each equality. 3155b0055a4SMatthias Springer for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities(); 3165b0055a4SMatthias Springer curEquality < numEqualities; curEquality++) { 3175b0055a4SMatthias Springer int64_t coefficientAtPos = cst.atEq64(curEquality, pos); 3185b0055a4SMatthias Springer // If current equality does not involve `var_r`, continue to the next 3195b0055a4SMatthias Springer // equality. 3205b0055a4SMatthias Springer if (coefficientAtPos == 0) 3215b0055a4SMatthias Springer continue; 3225b0055a4SMatthias Springer 3235b0055a4SMatthias Springer // Constant term should be 0 in this equality. 3245b0055a4SMatthias Springer if (cst.atEq64(curEquality, cst.getNumCols() - 1) != 0) 3255b0055a4SMatthias Springer continue; 3265b0055a4SMatthias Springer 3275b0055a4SMatthias Springer // Traverse through the equality and construct the dividend expression 3285b0055a4SMatthias Springer // `dividendExpr`, to contain all the variables which are known and are 3295b0055a4SMatthias Springer // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the 3305b0055a4SMatthias Springer // `dividendExpr` gets simplified into a single variable `var_n` discussed 3315b0055a4SMatthias Springer // above. 3325b0055a4SMatthias Springer auto dividendExpr = getAffineConstantExpr(0, context); 3335b0055a4SMatthias Springer 3345b0055a4SMatthias Springer // Track the terms that go into quotient expression, later used to detect 3355b0055a4SMatthias Springer // additional floordiv. 3365b0055a4SMatthias Springer unsigned quotientCount = 0; 3375b0055a4SMatthias Springer int quotientPosition = -1; 3385b0055a4SMatthias Springer int quotientSign = 1; 3395b0055a4SMatthias Springer 3405b0055a4SMatthias Springer // Consider each term in the current equality. 3415b0055a4SMatthias Springer unsigned curVar, e; 3425b0055a4SMatthias Springer for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) { 3435b0055a4SMatthias Springer // Ignore var_r. 3445b0055a4SMatthias Springer if (curVar == pos) 3455b0055a4SMatthias Springer continue; 3465b0055a4SMatthias Springer int64_t coefficientOfCurVar = cst.atEq64(curEquality, curVar); 3475b0055a4SMatthias Springer // Ignore vars that do not contribute to the current equality. 3485b0055a4SMatthias Springer if (coefficientOfCurVar == 0) 3495b0055a4SMatthias Springer continue; 3505b0055a4SMatthias Springer // Check if the current var goes into the quotient expression. 3515b0055a4SMatthias Springer if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) { 3525b0055a4SMatthias Springer quotientCount++; 3535b0055a4SMatthias Springer quotientPosition = curVar; 3545b0055a4SMatthias Springer quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1; 3555b0055a4SMatthias Springer continue; 3565b0055a4SMatthias Springer } 3575b0055a4SMatthias Springer // Variables that are part of dividendExpr should be known. 3585b0055a4SMatthias Springer if (!memo[curVar]) 3595b0055a4SMatthias Springer break; 3605b0055a4SMatthias Springer // Append the current variable to the dividend expression. 3615b0055a4SMatthias Springer dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar; 3625b0055a4SMatthias Springer } 3635b0055a4SMatthias Springer 3645b0055a4SMatthias Springer // Can't construct expression as it depends on a yet uncomputed var. 3655b0055a4SMatthias Springer if (curVar < e) 3665b0055a4SMatthias Springer continue; 3675b0055a4SMatthias Springer 3685b0055a4SMatthias Springer // Express `var_r` in terms of the other vars collected so far. 3695b0055a4SMatthias Springer if (coefficientAtPos > 0) 3705b0055a4SMatthias Springer dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos); 3715b0055a4SMatthias Springer else 3725b0055a4SMatthias Springer dividendExpr = dividendExpr.floorDiv(-coefficientAtPos); 3735b0055a4SMatthias Springer 3745b0055a4SMatthias Springer // Simplify the expression. 3755b0055a4SMatthias Springer dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimVars(), 3765b0055a4SMatthias Springer cst.getNumSymbolVars()); 3775b0055a4SMatthias Springer // Only if the final dividend expression is just a single var (which we call 3785b0055a4SMatthias Springer // `var_n`), we can proceed. 3795b0055a4SMatthias Springer // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it 3805b0055a4SMatthias Springer // to dims themselves. 3811609f1c2Slong.chen auto dimExpr = dyn_cast<AffineDimExpr>(dividendExpr); 3825b0055a4SMatthias Springer if (!dimExpr) 3835b0055a4SMatthias Springer continue; 3845b0055a4SMatthias Springer 3855b0055a4SMatthias Springer // Express `var_r` as `var_n % divisor` and store the expression in `memo`. 3865b0055a4SMatthias Springer if (quotientCount >= 1) { 387c45c9625SVinayaka Bandishti // Find the column corresponding to `dimExpr`. `num` columns starting at 388c45c9625SVinayaka Bandishti // `offset` correspond to previously unknown variables. The column 389c45c9625SVinayaka Bandishti // corresponding to the trivially known `dimExpr` can be on either side 390c45c9625SVinayaka Bandishti // of these. 391c45c9625SVinayaka Bandishti unsigned dimExprPos = dimExpr.getPosition(); 392c45c9625SVinayaka Bandishti unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num; 393c45c9625SVinayaka Bandishti auto ub = cst.getConstantBound64(BoundType::UB, dimExprCol); 3945b0055a4SMatthias Springer // If `var_n` has an upperbound that is less than the divisor, mod can be 3955b0055a4SMatthias Springer // eliminated altogether. 3965b0055a4SMatthias Springer if (ub && *ub < divisor) 3975b0055a4SMatthias Springer memo[pos] = dimExpr; 3985b0055a4SMatthias Springer else 3995b0055a4SMatthias Springer memo[pos] = dimExpr % divisor; 4005b0055a4SMatthias Springer // If a unique quotient `var_q` was seen, it can be expressed as 4015b0055a4SMatthias Springer // `var_n floordiv divisor`. 4025b0055a4SMatthias Springer if (quotientCount == 1 && !memo[quotientPosition]) 4035b0055a4SMatthias Springer memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign; 4045b0055a4SMatthias Springer 4055b0055a4SMatthias Springer return true; 4065b0055a4SMatthias Springer } 4075b0055a4SMatthias Springer } 4085b0055a4SMatthias Springer return false; 4095b0055a4SMatthias Springer } 4105b0055a4SMatthias Springer 4115b0055a4SMatthias Springer /// Check if the pos^th variable can be expressed as a floordiv of an affine 4125b0055a4SMatthias Springer /// function of other variables (where the divisor is a positive constant) 4135b0055a4SMatthias Springer /// given the initial set of expressions in `exprs`. If it can be, the 4145b0055a4SMatthias Springer /// corresponding position in `exprs` is set as the detected affine expr. For 4155b0055a4SMatthias Springer /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can 4165b0055a4SMatthias Springer /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28 4175b0055a4SMatthias Springer /// <= i <= 32q + 31 => q = i floordiv 32. 4185b0055a4SMatthias Springer static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos, 4195b0055a4SMatthias Springer MLIRContext *context, 4205b0055a4SMatthias Springer SmallVectorImpl<AffineExpr> &exprs) { 4215b0055a4SMatthias Springer assert(pos < cst.getNumVars() && "invalid position"); 4225b0055a4SMatthias Springer 4235b0055a4SMatthias Springer // Get upper-lower bound pair for this variable. 4245b0055a4SMatthias Springer SmallVector<bool, 8> foundRepr(cst.getNumVars(), false); 4255b0055a4SMatthias Springer for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i) 4265b0055a4SMatthias Springer if (exprs[i]) 4275b0055a4SMatthias Springer foundRepr[i] = true; 4285b0055a4SMatthias Springer 4295b0055a4SMatthias Springer SmallVector<int64_t, 8> dividend(cst.getNumCols()); 4305b0055a4SMatthias Springer unsigned divisor; 4315b0055a4SMatthias Springer auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); 4325b0055a4SMatthias Springer 4335b0055a4SMatthias Springer // No upper-lower bound pair found for this var. 4345b0055a4SMatthias Springer if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality) 4355b0055a4SMatthias Springer return false; 4365b0055a4SMatthias Springer 4375b0055a4SMatthias Springer // Construct the dividend expression. 4385b0055a4SMatthias Springer auto dividendExpr = getAffineConstantExpr(dividend.back(), context); 4395b0055a4SMatthias Springer for (unsigned c = 0, f = cst.getNumVars(); c < f; c++) 4405b0055a4SMatthias Springer if (dividend[c] != 0) 4415b0055a4SMatthias Springer dividendExpr = dividendExpr + dividend[c] * exprs[c]; 4425b0055a4SMatthias Springer 4435b0055a4SMatthias Springer // Successfully detected the floordiv. 4445b0055a4SMatthias Springer exprs[pos] = dividendExpr.floorDiv(divisor); 4455b0055a4SMatthias Springer return true; 4465b0055a4SMatthias Springer } 4475b0055a4SMatthias Springer 4485b0055a4SMatthias Springer std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound( 4495b0055a4SMatthias Springer unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, 4505b0055a4SMatthias Springer ArrayRef<AffineExpr> localExprs, MLIRContext *context, 4515b0055a4SMatthias Springer bool closedUB) const { 4525b0055a4SMatthias Springer assert(pos + offset < getNumDimVars() && "invalid dim start pos"); 4535b0055a4SMatthias Springer assert(symStartPos >= (pos + offset) && "invalid sym start pos"); 4545b0055a4SMatthias Springer assert(getNumLocalVars() == localExprs.size() && 4555b0055a4SMatthias Springer "incorrect local exprs count"); 4565b0055a4SMatthias Springer 4575b0055a4SMatthias Springer SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices; 4585b0055a4SMatthias Springer getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices, 4595b0055a4SMatthias Springer offset, num); 4605b0055a4SMatthias Springer 4615b0055a4SMatthias Springer /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). 4625b0055a4SMatthias Springer auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) { 4635b0055a4SMatthias Springer b.clear(); 4645b0055a4SMatthias Springer for (unsigned i = 0, e = a.size(); i < e; ++i) { 4655b0055a4SMatthias Springer if (i < offset || i >= offset + num) 4665b0055a4SMatthias Springer b.push_back(a[i]); 4675b0055a4SMatthias Springer } 4685b0055a4SMatthias Springer }; 4695b0055a4SMatthias Springer 4705b0055a4SMatthias Springer SmallVector<int64_t, 8> lb, ub; 4715b0055a4SMatthias Springer SmallVector<AffineExpr, 4> lbExprs; 4725b0055a4SMatthias Springer unsigned dimCount = symStartPos - num; 4735b0055a4SMatthias Springer unsigned symCount = getNumDimAndSymbolVars() - symStartPos; 4745b0055a4SMatthias Springer lbExprs.reserve(lbIndices.size() + eqIndices.size()); 4755b0055a4SMatthias Springer // Lower bound expressions. 4765b0055a4SMatthias Springer for (auto idx : lbIndices) { 4775b0055a4SMatthias Springer auto ineq = getInequality64(idx); 4785b0055a4SMatthias Springer // Extract the lower bound (in terms of other coeff's + const), i.e., if 4795b0055a4SMatthias Springer // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j 4805b0055a4SMatthias Springer // - 1. 4815b0055a4SMatthias Springer addCoeffs(ineq, lb); 4825b0055a4SMatthias Springer std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>()); 4835b0055a4SMatthias Springer auto expr = 4845b0055a4SMatthias Springer getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context); 4855b0055a4SMatthias Springer // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor 4865b0055a4SMatthias Springer int64_t divisor = std::abs(ineq[pos + offset]); 4875b0055a4SMatthias Springer expr = (expr + divisor - 1).floorDiv(divisor); 4885b0055a4SMatthias Springer lbExprs.push_back(expr); 4895b0055a4SMatthias Springer } 4905b0055a4SMatthias Springer 4915b0055a4SMatthias Springer SmallVector<AffineExpr, 4> ubExprs; 4925b0055a4SMatthias Springer ubExprs.reserve(ubIndices.size() + eqIndices.size()); 4935b0055a4SMatthias Springer // Upper bound expressions. 4945b0055a4SMatthias Springer for (auto idx : ubIndices) { 4955b0055a4SMatthias Springer auto ineq = getInequality64(idx); 4965b0055a4SMatthias Springer // Extract the upper bound (in terms of other coeff's + const). 4975b0055a4SMatthias Springer addCoeffs(ineq, ub); 4985b0055a4SMatthias Springer auto expr = 4995b0055a4SMatthias Springer getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context); 5005b0055a4SMatthias Springer expr = expr.floorDiv(std::abs(ineq[pos + offset])); 5015b0055a4SMatthias Springer int64_t ubAdjustment = closedUB ? 0 : 1; 5025b0055a4SMatthias Springer ubExprs.push_back(expr + ubAdjustment); 5035b0055a4SMatthias Springer } 5045b0055a4SMatthias Springer 5055b0055a4SMatthias Springer // Equalities. It's both a lower and a upper bound. 5065b0055a4SMatthias Springer SmallVector<int64_t, 4> b; 5075b0055a4SMatthias Springer for (auto idx : eqIndices) { 5085b0055a4SMatthias Springer auto eq = getEquality64(idx); 5095b0055a4SMatthias Springer addCoeffs(eq, b); 5105b0055a4SMatthias Springer if (eq[pos + offset] > 0) 5115b0055a4SMatthias Springer std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>()); 5125b0055a4SMatthias Springer 5135b0055a4SMatthias Springer // Extract the upper bound (in terms of other coeff's + const). 5145b0055a4SMatthias Springer auto expr = 5155b0055a4SMatthias Springer getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context); 5165b0055a4SMatthias Springer expr = expr.floorDiv(std::abs(eq[pos + offset])); 5175b0055a4SMatthias Springer // Upper bound is exclusive. 5185b0055a4SMatthias Springer ubExprs.push_back(expr + 1); 5195b0055a4SMatthias Springer // Lower bound. 5205b0055a4SMatthias Springer expr = 5215b0055a4SMatthias Springer getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context); 5225b0055a4SMatthias Springer expr = expr.ceilDiv(std::abs(eq[pos + offset])); 5235b0055a4SMatthias Springer lbExprs.push_back(expr); 5245b0055a4SMatthias Springer } 5255b0055a4SMatthias Springer 5265b0055a4SMatthias Springer auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context); 5275b0055a4SMatthias Springer auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context); 5285b0055a4SMatthias Springer 5295b0055a4SMatthias Springer return {lbMap, ubMap}; 5305b0055a4SMatthias Springer } 5315b0055a4SMatthias Springer 5325b0055a4SMatthias Springer /// Computes the lower and upper bounds of the first 'num' dimensional 5335b0055a4SMatthias Springer /// variables (starting at 'offset') as affine maps of the remaining 5345b0055a4SMatthias Springer /// variables (dimensional and symbolic variables). Local variables are 5355b0055a4SMatthias Springer /// themselves explicitly computed as affine functions of other variables in 5365b0055a4SMatthias Springer /// this process if needed. 5375b0055a4SMatthias Springer void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num, 5385b0055a4SMatthias Springer MLIRContext *context, 5395b0055a4SMatthias Springer SmallVectorImpl<AffineMap> *lbMaps, 5405b0055a4SMatthias Springer SmallVectorImpl<AffineMap> *ubMaps, 5415b0055a4SMatthias Springer bool closedUB) { 5427e5d300dSMatthias Springer assert(offset + num <= getNumDimVars() && "invalid range"); 5435b0055a4SMatthias Springer 5445b0055a4SMatthias Springer // Basic simplification. 5455b0055a4SMatthias Springer normalizeConstraintsByGCD(); 5465b0055a4SMatthias Springer 5475b0055a4SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num 5485b0055a4SMatthias Springer << " variables\n"); 5495b0055a4SMatthias Springer LLVM_DEBUG(dump()); 5505b0055a4SMatthias Springer 5515b0055a4SMatthias Springer // Record computed/detected variables. 5525b0055a4SMatthias Springer SmallVector<AffineExpr, 8> memo(getNumVars()); 5535b0055a4SMatthias Springer // Initialize dimensional and symbolic variables. 5545b0055a4SMatthias Springer for (unsigned i = 0, e = getNumDimVars(); i < e; i++) { 5555b0055a4SMatthias Springer if (i < offset) 5565b0055a4SMatthias Springer memo[i] = getAffineDimExpr(i, context); 5575b0055a4SMatthias Springer else if (i >= offset + num) 5585b0055a4SMatthias Springer memo[i] = getAffineDimExpr(i - num, context); 5595b0055a4SMatthias Springer } 5605b0055a4SMatthias Springer for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++) 5615b0055a4SMatthias Springer memo[i] = getAffineSymbolExpr(i - getNumDimVars(), context); 5625b0055a4SMatthias Springer 5635b0055a4SMatthias Springer bool changed; 5645b0055a4SMatthias Springer do { 5655b0055a4SMatthias Springer changed = false; 5665b0055a4SMatthias Springer // Identify yet unknown variables as constants or mod's / floordiv's of 5675b0055a4SMatthias Springer // other variables if possible. 5685b0055a4SMatthias Springer for (unsigned pos = 0; pos < getNumVars(); pos++) { 5695b0055a4SMatthias Springer if (memo[pos]) 5705b0055a4SMatthias Springer continue; 5715b0055a4SMatthias Springer 5725b0055a4SMatthias Springer auto lbConst = getConstantBound64(BoundType::LB, pos); 5735b0055a4SMatthias Springer auto ubConst = getConstantBound64(BoundType::UB, pos); 5745b0055a4SMatthias Springer if (lbConst.has_value() && ubConst.has_value()) { 5755b0055a4SMatthias Springer // Detect equality to a constant. 5765b0055a4SMatthias Springer if (*lbConst == *ubConst) { 5775b0055a4SMatthias Springer memo[pos] = getAffineConstantExpr(*lbConst, context); 5785b0055a4SMatthias Springer changed = true; 5795b0055a4SMatthias Springer continue; 5805b0055a4SMatthias Springer } 5815b0055a4SMatthias Springer 5825b0055a4SMatthias Springer // Detect a variable as modulo of another variable w.r.t a 5835b0055a4SMatthias Springer // constant. 584c45c9625SVinayaka Bandishti if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context, 585c45c9625SVinayaka Bandishti memo)) { 5865b0055a4SMatthias Springer changed = true; 5875b0055a4SMatthias Springer continue; 5885b0055a4SMatthias Springer } 5895b0055a4SMatthias Springer } 5905b0055a4SMatthias Springer 5915b0055a4SMatthias Springer // Detect a variable as a floordiv of an affine function of other 5925b0055a4SMatthias Springer // variables (divisor is a positive constant). 5935b0055a4SMatthias Springer if (detectAsFloorDiv(*this, pos, context, memo)) { 5945b0055a4SMatthias Springer changed = true; 5955b0055a4SMatthias Springer continue; 5965b0055a4SMatthias Springer } 5975b0055a4SMatthias Springer 5985b0055a4SMatthias Springer // Detect a variable as an expression of other variables. 5995b0055a4SMatthias Springer unsigned idx; 6005b0055a4SMatthias Springer if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) { 6015b0055a4SMatthias Springer continue; 6025b0055a4SMatthias Springer } 6035b0055a4SMatthias Springer 6045b0055a4SMatthias Springer // Build AffineExpr solving for variable 'pos' in terms of all others. 6055b0055a4SMatthias Springer auto expr = getAffineConstantExpr(0, context); 6065b0055a4SMatthias Springer unsigned j, e; 6075b0055a4SMatthias Springer for (j = 0, e = getNumVars(); j < e; ++j) { 6085b0055a4SMatthias Springer if (j == pos) 6095b0055a4SMatthias Springer continue; 6105b0055a4SMatthias Springer int64_t c = atEq64(idx, j); 6115b0055a4SMatthias Springer if (c == 0) 6125b0055a4SMatthias Springer continue; 6135b0055a4SMatthias Springer // If any of the involved IDs hasn't been found yet, we can't proceed. 6145b0055a4SMatthias Springer if (!memo[j]) 6155b0055a4SMatthias Springer break; 6165b0055a4SMatthias Springer expr = expr + memo[j] * c; 6175b0055a4SMatthias Springer } 6185b0055a4SMatthias Springer if (j < e) 6195b0055a4SMatthias Springer // Can't construct expression as it depends on a yet uncomputed 6205b0055a4SMatthias Springer // variable. 6215b0055a4SMatthias Springer continue; 6225b0055a4SMatthias Springer 6235b0055a4SMatthias Springer // Add constant term to AffineExpr. 6245b0055a4SMatthias Springer expr = expr + atEq64(idx, getNumVars()); 6255b0055a4SMatthias Springer int64_t vPos = atEq64(idx, pos); 6265b0055a4SMatthias Springer assert(vPos != 0 && "expected non-zero here"); 6275b0055a4SMatthias Springer if (vPos > 0) 6285b0055a4SMatthias Springer expr = (-expr).floorDiv(vPos); 6295b0055a4SMatthias Springer else 6305b0055a4SMatthias Springer // vPos < 0. 6315b0055a4SMatthias Springer expr = expr.floorDiv(-vPos); 6325b0055a4SMatthias Springer // Successfully constructed expression. 6335b0055a4SMatthias Springer memo[pos] = expr; 6345b0055a4SMatthias Springer changed = true; 6355b0055a4SMatthias Springer } 6365b0055a4SMatthias Springer // This loop is guaranteed to reach a fixed point - since once an 6375b0055a4SMatthias Springer // variable's explicit form is computed (in memo[pos]), it's not updated 6385b0055a4SMatthias Springer // again. 6395b0055a4SMatthias Springer } while (changed); 6405b0055a4SMatthias Springer 6415b0055a4SMatthias Springer int64_t ubAdjustment = closedUB ? 0 : 1; 6425b0055a4SMatthias Springer 6435b0055a4SMatthias Springer // Set the lower and upper bound maps for all the variables that were 6445b0055a4SMatthias Springer // computed as affine expressions of the rest as the "detected expr" and 6455b0055a4SMatthias Springer // "detected expr + 1" respectively; set the undetected ones to null. 6465b0055a4SMatthias Springer std::optional<FlatLinearConstraints> tmpClone; 6475b0055a4SMatthias Springer for (unsigned pos = 0; pos < num; pos++) { 6485b0055a4SMatthias Springer unsigned numMapDims = getNumDimVars() - num; 6495b0055a4SMatthias Springer unsigned numMapSymbols = getNumSymbolVars(); 6505b0055a4SMatthias Springer AffineExpr expr = memo[pos + offset]; 6515b0055a4SMatthias Springer if (expr) 6525b0055a4SMatthias Springer expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); 6535b0055a4SMatthias Springer 6545b0055a4SMatthias Springer AffineMap &lbMap = (*lbMaps)[pos]; 6555b0055a4SMatthias Springer AffineMap &ubMap = (*ubMaps)[pos]; 6565b0055a4SMatthias Springer 6575b0055a4SMatthias Springer if (expr) { 6585b0055a4SMatthias Springer lbMap = AffineMap::get(numMapDims, numMapSymbols, expr); 6595b0055a4SMatthias Springer ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment); 6605b0055a4SMatthias Springer } else { 6615b0055a4SMatthias Springer // TODO: Whenever there are local variables in the dependence 6625b0055a4SMatthias Springer // constraints, we'll conservatively over-approximate, since we don't 6635b0055a4SMatthias Springer // always explicitly compute them above (in the while loop). 6645b0055a4SMatthias Springer if (getNumLocalVars() == 0) { 6655b0055a4SMatthias Springer // Work on a copy so that we don't update this constraint system. 6665b0055a4SMatthias Springer if (!tmpClone) { 6675b0055a4SMatthias Springer tmpClone.emplace(FlatLinearConstraints(*this)); 6685b0055a4SMatthias Springer // Removing redundant inequalities is necessary so that we don't get 6695b0055a4SMatthias Springer // redundant loop bounds. 6705b0055a4SMatthias Springer tmpClone->removeRedundantInequalities(); 6715b0055a4SMatthias Springer } 6725b0055a4SMatthias Springer std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound( 6735b0055a4SMatthias Springer pos, offset, num, getNumDimVars(), /*localExprs=*/{}, context, 6745b0055a4SMatthias Springer closedUB); 6755b0055a4SMatthias Springer } 6765b0055a4SMatthias Springer 6775b0055a4SMatthias Springer // If the above fails, we'll just use the constant lower bound and the 6785b0055a4SMatthias Springer // constant upper bound (if they exist) as the slice bounds. 6795b0055a4SMatthias Springer // TODO: being conservative for the moment in cases that 6805b0055a4SMatthias Springer // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is 6815b0055a4SMatthias Springer // fixed (b/126426796). 6825b0055a4SMatthias Springer if (!lbMap || lbMap.getNumResults() > 1) { 6835b0055a4SMatthias Springer LLVM_DEBUG(llvm::dbgs() 6845b0055a4SMatthias Springer << "WARNING: Potentially over-approximating slice lb\n"); 6855b0055a4SMatthias Springer auto lbConst = getConstantBound64(BoundType::LB, pos + offset); 6865b0055a4SMatthias Springer if (lbConst.has_value()) { 6875b0055a4SMatthias Springer lbMap = AffineMap::get(numMapDims, numMapSymbols, 6885b0055a4SMatthias Springer getAffineConstantExpr(*lbConst, context)); 6895b0055a4SMatthias Springer } 6905b0055a4SMatthias Springer } 6915b0055a4SMatthias Springer if (!ubMap || ubMap.getNumResults() > 1) { 6925b0055a4SMatthias Springer LLVM_DEBUG(llvm::dbgs() 6935b0055a4SMatthias Springer << "WARNING: Potentially over-approximating slice ub\n"); 6945b0055a4SMatthias Springer auto ubConst = getConstantBound64(BoundType::UB, pos + offset); 6955b0055a4SMatthias Springer if (ubConst.has_value()) { 6965b0055a4SMatthias Springer ubMap = AffineMap::get( 6975b0055a4SMatthias Springer numMapDims, numMapSymbols, 6985b0055a4SMatthias Springer getAffineConstantExpr(*ubConst + ubAdjustment, context)); 6995b0055a4SMatthias Springer } 7005b0055a4SMatthias Springer } 7015b0055a4SMatthias Springer } 7025b0055a4SMatthias Springer LLVM_DEBUG(llvm::dbgs() 7035b0055a4SMatthias Springer << "lb map for pos = " << Twine(pos + offset) << ", expr: "); 7045b0055a4SMatthias Springer LLVM_DEBUG(lbMap.dump();); 7055b0055a4SMatthias Springer LLVM_DEBUG(llvm::dbgs() 7065b0055a4SMatthias Springer << "ub map for pos = " << Twine(pos + offset) << ", expr: "); 7075b0055a4SMatthias Springer LLVM_DEBUG(ubMap.dump();); 7085b0055a4SMatthias Springer } 7095b0055a4SMatthias Springer } 7105b0055a4SMatthias Springer 7115b0055a4SMatthias Springer LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals( 71229a925abSBenjamin Maxwell AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 71329a925abSBenjamin Maxwell bool addConservativeSemiAffineBounds) { 7145b0055a4SMatthias Springer FlatLinearConstraints localCst; 71529a925abSBenjamin Maxwell if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst, 71629a925abSBenjamin Maxwell addConservativeSemiAffineBounds))) { 7175b0055a4SMatthias Springer LLVM_DEBUG(llvm::dbgs() 7185b0055a4SMatthias Springer << "composition unimplemented for semi-affine maps\n"); 7195b0055a4SMatthias Springer return failure(); 7205b0055a4SMatthias Springer } 7215b0055a4SMatthias Springer 7225b0055a4SMatthias Springer // Add localCst information. 7235b0055a4SMatthias Springer if (localCst.getNumLocalVars() > 0) { 7245b0055a4SMatthias Springer unsigned numLocalVars = getNumLocalVars(); 7255b0055a4SMatthias Springer // Insert local dims of localCst at the beginning. 7265b0055a4SMatthias Springer insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars()); 7275b0055a4SMatthias Springer // Insert local dims of `this` at the end of localCst. 7285b0055a4SMatthias Springer localCst.appendLocalVar(/*num=*/numLocalVars); 7295b0055a4SMatthias Springer // Dimensions of localCst and this constraint set match. Append localCst to 7305b0055a4SMatthias Springer // this constraint set. 7315b0055a4SMatthias Springer append(localCst); 7325b0055a4SMatthias Springer } 7335b0055a4SMatthias Springer 7345b0055a4SMatthias Springer return success(); 7355b0055a4SMatthias Springer } 7365b0055a4SMatthias Springer 73729a925abSBenjamin Maxwell LogicalResult FlatLinearConstraints::addBound( 73829a925abSBenjamin Maxwell BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound, 73929a925abSBenjamin Maxwell AddConservativeSemiAffineBounds addSemiAffineBounds) { 7405b0055a4SMatthias Springer assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch"); 7415b0055a4SMatthias Springer assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"); 7425b0055a4SMatthias Springer assert(pos < getNumDimAndSymbolVars() && "invalid position"); 7435b0055a4SMatthias Springer assert((type != BoundType::EQ || isClosedBound) && 7445b0055a4SMatthias Springer "EQ bound must be closed."); 7455b0055a4SMatthias Springer 7465b0055a4SMatthias Springer // Equality follows the logic of lower bound except that we add an equality 7475b0055a4SMatthias Springer // instead of an inequality. 7485b0055a4SMatthias Springer assert((type != BoundType::EQ || boundMap.getNumResults() == 1) && 7495b0055a4SMatthias Springer "single result expected"); 7505b0055a4SMatthias Springer bool lower = type == BoundType::LB || type == BoundType::EQ; 7515b0055a4SMatthias Springer 7525b0055a4SMatthias Springer std::vector<SmallVector<int64_t, 8>> flatExprs; 75329a925abSBenjamin Maxwell if (failed(flattenAlignedMapAndMergeLocals( 75429a925abSBenjamin Maxwell boundMap, &flatExprs, 75529a925abSBenjamin Maxwell addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes))) 7565b0055a4SMatthias Springer return failure(); 7575b0055a4SMatthias Springer assert(flatExprs.size() == boundMap.getNumResults()); 7585b0055a4SMatthias Springer 7595b0055a4SMatthias Springer // Add one (in)equality for each result. 7605b0055a4SMatthias Springer for (const auto &flatExpr : flatExprs) { 7615b0055a4SMatthias Springer SmallVector<int64_t> ineq(getNumCols(), 0); 7625b0055a4SMatthias Springer // Dims and symbols. 7635b0055a4SMatthias Springer for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { 7645b0055a4SMatthias Springer ineq[j] = lower ? -flatExpr[j] : flatExpr[j]; 7655b0055a4SMatthias Springer } 7665b0055a4SMatthias Springer // Invalid bound: pos appears in `boundMap`. 7675b0055a4SMatthias Springer // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or 7685b0055a4SMatthias Springer // its callers to prevent invalid bounds from being added. 7695b0055a4SMatthias Springer if (ineq[pos] != 0) 7705b0055a4SMatthias Springer continue; 7715b0055a4SMatthias Springer ineq[pos] = lower ? 1 : -1; 7725b0055a4SMatthias Springer // Local columns of `ineq` are at the beginning. 7735b0055a4SMatthias Springer unsigned j = getNumDimVars() + getNumSymbolVars(); 7745b0055a4SMatthias Springer unsigned end = flatExpr.size() - 1; 7755b0055a4SMatthias Springer for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) { 7765b0055a4SMatthias Springer ineq[j] = lower ? -flatExpr[i] : flatExpr[i]; 7775b0055a4SMatthias Springer } 7785b0055a4SMatthias Springer // Make the bound closed in if flatExpr is open. The inequality is always 7795b0055a4SMatthias Springer // created in the upper bound form, so the adjustment is -1. 7805b0055a4SMatthias Springer int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1; 7815b0055a4SMatthias Springer // Constant term. 7825b0055a4SMatthias Springer ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1] 7835b0055a4SMatthias Springer : flatExpr[flatExpr.size() - 1]) + 7845b0055a4SMatthias Springer boundAdjustment; 7855b0055a4SMatthias Springer type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq); 7865b0055a4SMatthias Springer } 7875b0055a4SMatthias Springer 7885b0055a4SMatthias Springer return success(); 7895b0055a4SMatthias Springer } 7905b0055a4SMatthias Springer 79129a925abSBenjamin Maxwell LogicalResult FlatLinearConstraints::addBound( 79229a925abSBenjamin Maxwell BoundType type, unsigned pos, AffineMap boundMap, 79329a925abSBenjamin Maxwell AddConservativeSemiAffineBounds addSemiAffineBounds) { 79429a925abSBenjamin Maxwell return addBound(type, pos, boundMap, 79529a925abSBenjamin Maxwell /*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds); 7965b0055a4SMatthias Springer } 7975b0055a4SMatthias Springer 7985b0055a4SMatthias Springer /// Compute an explicit representation for local vars. For all systems coming 7995b0055a4SMatthias Springer /// from MLIR integer sets, maps, or expressions where local vars were 8005b0055a4SMatthias Springer /// introduced to model floordivs and mods, this always succeeds. 8015b0055a4SMatthias Springer LogicalResult 8025b0055a4SMatthias Springer FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo, 8035b0055a4SMatthias Springer MLIRContext *context) const { 8045b0055a4SMatthias Springer unsigned numDims = getNumDimVars(); 8055b0055a4SMatthias Springer unsigned numSyms = getNumSymbolVars(); 8065b0055a4SMatthias Springer 8075b0055a4SMatthias Springer // Initialize dimensional and symbolic variables. 8085b0055a4SMatthias Springer for (unsigned i = 0; i < numDims; i++) 8095b0055a4SMatthias Springer memo[i] = getAffineDimExpr(i, context); 8105b0055a4SMatthias Springer for (unsigned i = numDims, e = numDims + numSyms; i < e; i++) 8115b0055a4SMatthias Springer memo[i] = getAffineSymbolExpr(i - numDims, context); 8125b0055a4SMatthias Springer 8135b0055a4SMatthias Springer bool changed; 8145b0055a4SMatthias Springer do { 8155b0055a4SMatthias Springer // Each time `changed` is true at the end of this iteration, one or more 8165b0055a4SMatthias Springer // local vars would have been detected as floordivs and set in memo; so the 8175b0055a4SMatthias Springer // number of null entries in memo[...] strictly reduces; so this converges. 8185b0055a4SMatthias Springer changed = false; 8195b0055a4SMatthias Springer for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) 8205b0055a4SMatthias Springer if (!memo[numDims + numSyms + i] && 8215b0055a4SMatthias Springer detectAsFloorDiv(*this, /*pos=*/numDims + numSyms + i, context, memo)) 8225b0055a4SMatthias Springer changed = true; 8235b0055a4SMatthias Springer } while (changed); 8245b0055a4SMatthias Springer 8255b0055a4SMatthias Springer ArrayRef<AffineExpr> localExprs = 8265b0055a4SMatthias Springer ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars()); 8275b0055a4SMatthias Springer return success( 8285b0055a4SMatthias Springer llvm::all_of(localExprs, [](AffineExpr expr) { return expr; })); 8295b0055a4SMatthias Springer } 8305b0055a4SMatthias Springer 8315b0055a4SMatthias Springer IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const { 8325b0055a4SMatthias Springer if (getNumConstraints() == 0) 8335b0055a4SMatthias Springer // Return universal set (always true): 0 == 0. 8345b0055a4SMatthias Springer return IntegerSet::get(getNumDimVars(), getNumSymbolVars(), 8355b0055a4SMatthias Springer getAffineConstantExpr(/*constant=*/0, context), 8365b0055a4SMatthias Springer /*eqFlags=*/true); 8375b0055a4SMatthias Springer 8385b0055a4SMatthias Springer // Construct local references. 8395b0055a4SMatthias Springer SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); 8405b0055a4SMatthias Springer 8415b0055a4SMatthias Springer if (failed(computeLocalVars(memo, context))) { 8425b0055a4SMatthias Springer // Check if the local variables without an explicit representation have 8435b0055a4SMatthias Springer // zero coefficients everywhere. 8445b0055a4SMatthias Springer SmallVector<unsigned> noLocalRepVars; 8455b0055a4SMatthias Springer unsigned numDimsSymbols = getNumDimAndSymbolVars(); 8465b0055a4SMatthias Springer for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) { 8475b0055a4SMatthias Springer if (!memo[i] && !isColZero(/*pos=*/i)) 8485b0055a4SMatthias Springer noLocalRepVars.push_back(i - numDimsSymbols); 8495b0055a4SMatthias Springer } 8505b0055a4SMatthias Springer if (!noLocalRepVars.empty()) { 8515b0055a4SMatthias Springer LLVM_DEBUG({ 8525b0055a4SMatthias Springer llvm::dbgs() << "local variables at position(s) "; 8535b0055a4SMatthias Springer llvm::interleaveComma(noLocalRepVars, llvm::dbgs()); 8545b0055a4SMatthias Springer llvm::dbgs() << " do not have an explicit representation in:\n"; 8555b0055a4SMatthias Springer this->dump(); 8565b0055a4SMatthias Springer }); 8575b0055a4SMatthias Springer return IntegerSet(); 8585b0055a4SMatthias Springer } 8595b0055a4SMatthias Springer } 8605b0055a4SMatthias Springer 8615b0055a4SMatthias Springer ArrayRef<AffineExpr> localExprs = 8625b0055a4SMatthias Springer ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars()); 8635b0055a4SMatthias Springer 8645b0055a4SMatthias Springer // Construct the IntegerSet from the equalities/inequalities. 8655b0055a4SMatthias Springer unsigned numDims = getNumDimVars(); 8665b0055a4SMatthias Springer unsigned numSyms = getNumSymbolVars(); 8675b0055a4SMatthias Springer 8685b0055a4SMatthias Springer SmallVector<bool, 16> eqFlags(getNumConstraints()); 8695b0055a4SMatthias Springer std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true); 8705b0055a4SMatthias Springer std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false); 8715b0055a4SMatthias Springer 8725b0055a4SMatthias Springer SmallVector<AffineExpr, 8> exprs; 8735b0055a4SMatthias Springer exprs.reserve(getNumConstraints()); 8745b0055a4SMatthias Springer 8755b0055a4SMatthias Springer for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 8765b0055a4SMatthias Springer exprs.push_back(getAffineExprFromFlatForm(getEquality64(i), numDims, 8775b0055a4SMatthias Springer numSyms, localExprs, context)); 8785b0055a4SMatthias Springer for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) 8795b0055a4SMatthias Springer exprs.push_back(getAffineExprFromFlatForm(getInequality64(i), numDims, 8805b0055a4SMatthias Springer numSyms, localExprs, context)); 8815b0055a4SMatthias Springer return IntegerSet::get(numDims, numSyms, exprs, eqFlags); 8825b0055a4SMatthias Springer } 8835b0055a4SMatthias Springer 8845b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 8855b0055a4SMatthias Springer // FlatLinearValueConstraints 8865b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 8875b0055a4SMatthias Springer 8885b0055a4SMatthias Springer // Construct from an IntegerSet. 8895b0055a4SMatthias Springer FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set, 8905b0055a4SMatthias Springer ValueRange operands) 8915b0055a4SMatthias Springer : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(), 8925b0055a4SMatthias Springer set.getNumDims() + set.getNumSymbols() + 1, 8935b0055a4SMatthias Springer set.getNumDims(), set.getNumSymbols(), 8945b0055a4SMatthias Springer /*numLocals=*/0) { 895*a24c4687SAlexander Pivovarov assert((operands.empty() || set.getNumInputs() == operands.size()) && 896*a24c4687SAlexander Pivovarov "operand count mismatch"); 89724da7fa0SBharathi Ramana Joshi // Set the values for the non-local variables. 89824da7fa0SBharathi Ramana Joshi for (unsigned i = 0, e = operands.size(); i < e; ++i) 89924da7fa0SBharathi Ramana Joshi setValue(i, operands[i]); 9005b0055a4SMatthias Springer 9015b0055a4SMatthias Springer // Flatten expressions and add them to the constraint system. 9025b0055a4SMatthias Springer std::vector<SmallVector<int64_t, 8>> flatExprs; 9035b0055a4SMatthias Springer FlatLinearConstraints localVarCst; 9045b0055a4SMatthias Springer if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) { 9055b0055a4SMatthias Springer assert(false && "flattening unimplemented for semi-affine integer sets"); 9065b0055a4SMatthias Springer return; 9075b0055a4SMatthias Springer } 9085b0055a4SMatthias Springer assert(flatExprs.size() == set.getNumConstraints()); 9095b0055a4SMatthias Springer insertVar(VarKind::Local, getNumVarKind(VarKind::Local), 9105b0055a4SMatthias Springer /*num=*/localVarCst.getNumLocalVars()); 9115b0055a4SMatthias Springer 9125b0055a4SMatthias Springer for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { 9135b0055a4SMatthias Springer const auto &flatExpr = flatExprs[i]; 9145b0055a4SMatthias Springer assert(flatExpr.size() == getNumCols()); 9155b0055a4SMatthias Springer if (set.getEqFlags()[i]) { 9165b0055a4SMatthias Springer addEquality(flatExpr); 9175b0055a4SMatthias Springer } else { 9185b0055a4SMatthias Springer addInequality(flatExpr); 9195b0055a4SMatthias Springer } 9205b0055a4SMatthias Springer } 9215b0055a4SMatthias Springer // Add the other constraints involving local vars from flattening. 9225b0055a4SMatthias Springer append(localVarCst); 9235b0055a4SMatthias Springer } 9245b0055a4SMatthias Springer 9255b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) { 9265b0055a4SMatthias Springer unsigned pos = getNumDimVars(); 9275b0055a4SMatthias Springer return insertVar(VarKind::SetDim, pos, vals); 9285b0055a4SMatthias Springer } 9295b0055a4SMatthias Springer 9305b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) { 9315b0055a4SMatthias Springer unsigned pos = getNumSymbolVars(); 9325b0055a4SMatthias Springer return insertVar(VarKind::Symbol, pos, vals); 9335b0055a4SMatthias Springer } 9345b0055a4SMatthias Springer 9355b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos, 9365b0055a4SMatthias Springer ValueRange vals) { 9375b0055a4SMatthias Springer return insertVar(VarKind::SetDim, pos, vals); 9385b0055a4SMatthias Springer } 9395b0055a4SMatthias Springer 9405b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos, 9415b0055a4SMatthias Springer ValueRange vals) { 9425b0055a4SMatthias Springer return insertVar(VarKind::Symbol, pos, vals); 9435b0055a4SMatthias Springer } 9445b0055a4SMatthias Springer 9455b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos, 9465b0055a4SMatthias Springer unsigned num) { 9475b0055a4SMatthias Springer unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); 9485b0055a4SMatthias Springer 9495b0055a4SMatthias Springer return absolutePos; 9505b0055a4SMatthias Springer } 9515b0055a4SMatthias Springer 9525b0055a4SMatthias Springer unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos, 9535b0055a4SMatthias Springer ValueRange vals) { 9545b0055a4SMatthias Springer assert(!vals.empty() && "expected ValueRange with Values."); 9555b0055a4SMatthias Springer assert(kind != VarKind::Local && 9565b0055a4SMatthias Springer "values cannot be attached to local variables."); 9575b0055a4SMatthias Springer unsigned num = vals.size(); 9585b0055a4SMatthias Springer unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); 9595b0055a4SMatthias Springer 9601b2247d9SKazu Hirata // If a Value is provided, insert it; otherwise use std::nullopt. 96124da7fa0SBharathi Ramana Joshi for (unsigned i = 0, e = vals.size(); i < e; ++i) 96224da7fa0SBharathi Ramana Joshi if (vals[i]) 96324da7fa0SBharathi Ramana Joshi setValue(absolutePos + i, vals[i]); 9645b0055a4SMatthias Springer 9655b0055a4SMatthias Springer return absolutePos; 9665b0055a4SMatthias Springer } 9675b0055a4SMatthias Springer 9685b0055a4SMatthias Springer /// Checks if two constraint systems are in the same space, i.e., if they are 9695b0055a4SMatthias Springer /// associated with the same set of variables, appearing in the same order. 9705b0055a4SMatthias Springer static bool areVarsAligned(const FlatLinearValueConstraints &a, 9715b0055a4SMatthias Springer const FlatLinearValueConstraints &b) { 97224da7fa0SBharathi Ramana Joshi if (a.getNumDomainVars() != b.getNumDomainVars() || 97324da7fa0SBharathi Ramana Joshi a.getNumRangeVars() != b.getNumRangeVars() || 97424da7fa0SBharathi Ramana Joshi a.getNumSymbolVars() != b.getNumSymbolVars()) 97524da7fa0SBharathi Ramana Joshi return false; 97624da7fa0SBharathi Ramana Joshi SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(), 97724da7fa0SBharathi Ramana Joshi bMaybeValues = b.getMaybeValues(); 97824da7fa0SBharathi Ramana Joshi return std::equal(aMaybeValues.begin(), aMaybeValues.end(), 97924da7fa0SBharathi Ramana Joshi bMaybeValues.begin(), bMaybeValues.end()); 9805b0055a4SMatthias Springer } 9815b0055a4SMatthias Springer 9825b0055a4SMatthias Springer /// Calls areVarsAligned to check if two constraint systems have the same set 9835b0055a4SMatthias Springer /// of variables in the same order. 9845b0055a4SMatthias Springer bool FlatLinearValueConstraints::areVarsAlignedWithOther( 9855b0055a4SMatthias Springer const FlatLinearConstraints &other) { 9865b0055a4SMatthias Springer return areVarsAligned(*this, other); 9875b0055a4SMatthias Springer } 9885b0055a4SMatthias Springer 9895b0055a4SMatthias Springer /// Checks if the SSA values associated with `cst`'s variables in range 9905b0055a4SMatthias Springer /// [start, end) are unique. 9915b0055a4SMatthias Springer static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( 9925b0055a4SMatthias Springer const FlatLinearValueConstraints &cst, unsigned start, unsigned end) { 9935b0055a4SMatthias Springer 9945b0055a4SMatthias Springer assert(start <= cst.getNumDimAndSymbolVars() && 9955b0055a4SMatthias Springer "Start position out of bounds"); 9965b0055a4SMatthias Springer assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds"); 9975b0055a4SMatthias Springer 9985b0055a4SMatthias Springer if (start >= end) 9995b0055a4SMatthias Springer return true; 10005b0055a4SMatthias Springer 10015b0055a4SMatthias Springer SmallPtrSet<Value, 8> uniqueVars; 100224da7fa0SBharathi Ramana Joshi SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues(); 100324da7fa0SBharathi Ramana Joshi ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start, 100424da7fa0SBharathi Ramana Joshi maybeValuesAll.data() + end}; 100524da7fa0SBharathi Ramana Joshi 100624da7fa0SBharathi Ramana Joshi for (std::optional<Value> val : maybeValues) 10075b0055a4SMatthias Springer if (val && !uniqueVars.insert(*val).second) 10085b0055a4SMatthias Springer return false; 100924da7fa0SBharathi Ramana Joshi 10105b0055a4SMatthias Springer return true; 10115b0055a4SMatthias Springer } 10125b0055a4SMatthias Springer 10135b0055a4SMatthias Springer /// Checks if the SSA values associated with `cst`'s variables are unique. 10145b0055a4SMatthias Springer static bool LLVM_ATTRIBUTE_UNUSED 10155b0055a4SMatthias Springer areVarsUnique(const FlatLinearValueConstraints &cst) { 10165b0055a4SMatthias Springer return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars()); 10175b0055a4SMatthias Springer } 10185b0055a4SMatthias Springer 10195b0055a4SMatthias Springer /// Checks if the SSA values associated with `cst`'s variables of kind `kind` 10205b0055a4SMatthias Springer /// are unique. 10215b0055a4SMatthias Springer static bool LLVM_ATTRIBUTE_UNUSED 10225b0055a4SMatthias Springer areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) { 10235b0055a4SMatthias Springer 10245b0055a4SMatthias Springer if (kind == VarKind::SetDim) 10255b0055a4SMatthias Springer return areVarsUnique(cst, 0, cst.getNumDimVars()); 10265b0055a4SMatthias Springer if (kind == VarKind::Symbol) 10275b0055a4SMatthias Springer return areVarsUnique(cst, cst.getNumDimVars(), 10285b0055a4SMatthias Springer cst.getNumDimAndSymbolVars()); 10295b0055a4SMatthias Springer llvm_unreachable("Unexpected VarKind"); 10305b0055a4SMatthias Springer } 10315b0055a4SMatthias Springer 10325b0055a4SMatthias Springer /// Merge and align the variables of A and B starting at 'offset', so that 10335b0055a4SMatthias Springer /// both constraint systems get the union of the contained variables that is 10345b0055a4SMatthias Springer /// dimension-wise and symbol-wise unique; both constraint systems are updated 10355b0055a4SMatthias Springer /// so that they have the union of all variables, with A's original 10365b0055a4SMatthias Springer /// variables appearing first followed by any of B's variables that didn't 10375b0055a4SMatthias Springer /// appear in A. Local variables in B that have the same division 1038c79ffb02SUday Bondhugula /// representation as local variables in A are merged into one. We allow A 1039c79ffb02SUday Bondhugula /// and B to have non-unique values for their variables; in such cases, they are 1040c79ffb02SUday Bondhugula /// still aligned with the variables appearing first aligned with those 1041c79ffb02SUday Bondhugula /// appearing first in the other system from left to right. 10425b0055a4SMatthias Springer // E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M]) 10435b0055a4SMatthias Springer // Output: both A, B have (%i, %j, %k) [%M, %N, %P] 10445b0055a4SMatthias Springer static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a, 10455b0055a4SMatthias Springer FlatLinearValueConstraints *b) { 10465b0055a4SMatthias Springer assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars()); 10475b0055a4SMatthias Springer 10485b0055a4SMatthias Springer assert(llvm::all_of( 10495b0055a4SMatthias Springer llvm::drop_begin(a->getMaybeValues(), offset), 10505b0055a4SMatthias Springer [](const std::optional<Value> &var) { return var.has_value(); })); 10515b0055a4SMatthias Springer 10525b0055a4SMatthias Springer assert(llvm::all_of( 10535b0055a4SMatthias Springer llvm::drop_begin(b->getMaybeValues(), offset), 10545b0055a4SMatthias Springer [](const std::optional<Value> &var) { return var.has_value(); })); 10555b0055a4SMatthias Springer 10565b0055a4SMatthias Springer SmallVector<Value, 4> aDimValues; 10575b0055a4SMatthias Springer a->getValues(offset, a->getNumDimVars(), &aDimValues); 10585b0055a4SMatthias Springer 10595b0055a4SMatthias Springer { 10605b0055a4SMatthias Springer // Merge dims from A into B. 10615b0055a4SMatthias Springer unsigned d = offset; 1062c79ffb02SUday Bondhugula for (Value aDimValue : aDimValues) { 10635b0055a4SMatthias Springer unsigned loc; 1064c79ffb02SUday Bondhugula // Find from the position `d` since we'd like to also consider the 1065c79ffb02SUday Bondhugula // possibility of multiple variables with the same `Value`. We align with 1066c79ffb02SUday Bondhugula // the next appearing one. 1067c79ffb02SUday Bondhugula if (b->findVar(aDimValue, &loc, d)) { 10685b0055a4SMatthias Springer assert(loc >= offset && "A's dim appears in B's aligned range"); 10695b0055a4SMatthias Springer assert(loc < b->getNumDimVars() && 10705b0055a4SMatthias Springer "A's dim appears in B's non-dim position"); 10715b0055a4SMatthias Springer b->swapVar(d, loc); 10725b0055a4SMatthias Springer } else { 10735b0055a4SMatthias Springer b->insertDimVar(d, aDimValue); 10745b0055a4SMatthias Springer } 10755b0055a4SMatthias Springer d++; 10765b0055a4SMatthias Springer } 10775b0055a4SMatthias Springer // Dimensions that are in B, but not in A, are added at the end. 10785b0055a4SMatthias Springer for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) { 10795b0055a4SMatthias Springer a->appendDimVar(b->getValue(t)); 10805b0055a4SMatthias Springer } 10815b0055a4SMatthias Springer assert(a->getNumDimVars() == b->getNumDimVars() && 10825b0055a4SMatthias Springer "expected same number of dims"); 10835b0055a4SMatthias Springer } 10845b0055a4SMatthias Springer 10855b0055a4SMatthias Springer // Merge and align symbols of A and B 10865b0055a4SMatthias Springer a->mergeSymbolVars(*b); 10875b0055a4SMatthias Springer // Merge and align locals of A and B 10885b0055a4SMatthias Springer a->mergeLocalVars(*b); 10895b0055a4SMatthias Springer 10905b0055a4SMatthias Springer assert(areVarsAligned(*a, *b) && "IDs expected to be aligned"); 10915b0055a4SMatthias Springer } 10925b0055a4SMatthias Springer 10935b0055a4SMatthias Springer // Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'. 10945b0055a4SMatthias Springer void FlatLinearValueConstraints::mergeAndAlignVarsWithOther( 10955b0055a4SMatthias Springer unsigned offset, FlatLinearValueConstraints *other) { 10965b0055a4SMatthias Springer mergeAndAlignVars(offset, this, other); 10975b0055a4SMatthias Springer } 10985b0055a4SMatthias Springer 10995b0055a4SMatthias Springer /// Merge and align symbols of `this` and `other` such that both get union of 1100c79ffb02SUday Bondhugula /// of symbols. Existing symbols need not be unique; they will be aligned from 1101c79ffb02SUday Bondhugula /// left to right with duplicates aligned in the same order. Symbols with Value 1102c79ffb02SUday Bondhugula /// as `None` are considered to be inequal to all other symbols. 11035b0055a4SMatthias Springer void FlatLinearValueConstraints::mergeSymbolVars( 11045b0055a4SMatthias Springer FlatLinearValueConstraints &other) { 11055b0055a4SMatthias Springer 11065b0055a4SMatthias Springer SmallVector<Value, 4> aSymValues; 11075b0055a4SMatthias Springer getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues); 11085b0055a4SMatthias Springer 11095b0055a4SMatthias Springer // Merge symbols: merge symbols into `other` first from `this`. 11105b0055a4SMatthias Springer unsigned s = other.getNumDimVars(); 11115b0055a4SMatthias Springer for (Value aSymValue : aSymValues) { 11125b0055a4SMatthias Springer unsigned loc; 11135b0055a4SMatthias Springer // If the var is a symbol in `other`, then align it, otherwise assume that 1114c79ffb02SUday Bondhugula // it is a new symbol. Search in `other` starting at position `s` since the 1115c79ffb02SUday Bondhugula // left of it is aligned. 1116c79ffb02SUday Bondhugula if (other.findVar(aSymValue, &loc, s) && loc >= other.getNumDimVars() && 11175b0055a4SMatthias Springer loc < other.getNumDimAndSymbolVars()) 11185b0055a4SMatthias Springer other.swapVar(s, loc); 11195b0055a4SMatthias Springer else 11205b0055a4SMatthias Springer other.insertSymbolVar(s - other.getNumDimVars(), aSymValue); 11215b0055a4SMatthias Springer s++; 11225b0055a4SMatthias Springer } 11235b0055a4SMatthias Springer 11245b0055a4SMatthias Springer // Symbols that are in other, but not in this, are added at the end. 11255b0055a4SMatthias Springer for (unsigned t = other.getNumDimVars() + getNumSymbolVars(), 11265b0055a4SMatthias Springer e = other.getNumDimAndSymbolVars(); 11275b0055a4SMatthias Springer t < e; t++) 11285b0055a4SMatthias Springer insertSymbolVar(getNumSymbolVars(), other.getValue(t)); 11295b0055a4SMatthias Springer 11305b0055a4SMatthias Springer assert(getNumSymbolVars() == other.getNumSymbolVars() && 11315b0055a4SMatthias Springer "expected same number of symbols"); 11325b0055a4SMatthias Springer } 11335b0055a4SMatthias Springer 11345b0055a4SMatthias Springer void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart, 11355b0055a4SMatthias Springer unsigned varLimit) { 11365b0055a4SMatthias Springer IntegerPolyhedron::removeVarRange(kind, varStart, varLimit); 11375b0055a4SMatthias Springer } 11385b0055a4SMatthias Springer 11395b0055a4SMatthias Springer AffineMap 11405b0055a4SMatthias Springer FlatLinearValueConstraints::computeAlignedMap(AffineMap map, 11415b0055a4SMatthias Springer ValueRange operands) const { 11425b0055a4SMatthias Springer assert(map.getNumInputs() == operands.size() && "number of inputs mismatch"); 11435b0055a4SMatthias Springer 11445b0055a4SMatthias Springer SmallVector<Value> dims, syms; 11455b0055a4SMatthias Springer #ifndef NDEBUG 11465b0055a4SMatthias Springer SmallVector<Value> newSyms; 11475b0055a4SMatthias Springer SmallVector<Value> *newSymsPtr = &newSyms; 11485b0055a4SMatthias Springer #else 11495b0055a4SMatthias Springer SmallVector<Value> *newSymsPtr = nullptr; 11505b0055a4SMatthias Springer #endif // NDEBUG 11515b0055a4SMatthias Springer 11525b0055a4SMatthias Springer dims.reserve(getNumDimVars()); 11535b0055a4SMatthias Springer syms.reserve(getNumSymbolVars()); 115424da7fa0SBharathi Ramana Joshi for (unsigned i = 0, e = getNumVarKind(VarKind::SetDim); i < e; ++i) { 115524da7fa0SBharathi Ramana Joshi Identifier id = space.getId(VarKind::SetDim, i); 115624da7fa0SBharathi Ramana Joshi dims.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value()); 115724da7fa0SBharathi Ramana Joshi } 115824da7fa0SBharathi Ramana Joshi for (unsigned i = 0, e = getNumVarKind(VarKind::Symbol); i < e; ++i) { 115924da7fa0SBharathi Ramana Joshi Identifier id = space.getId(VarKind::Symbol, i); 116024da7fa0SBharathi Ramana Joshi syms.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value()); 116124da7fa0SBharathi Ramana Joshi } 11625b0055a4SMatthias Springer 11635b0055a4SMatthias Springer AffineMap alignedMap = 11645b0055a4SMatthias Springer alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr); 11655b0055a4SMatthias Springer // All symbols are already part of this FlatAffineValueConstraints. 11665b0055a4SMatthias Springer assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols"); 11675b0055a4SMatthias Springer assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) && 11685b0055a4SMatthias Springer "unexpected new/missing symbols"); 11695b0055a4SMatthias Springer return alignedMap; 11705b0055a4SMatthias Springer } 11715b0055a4SMatthias Springer 1172c79ffb02SUday Bondhugula bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos, 1173c79ffb02SUday Bondhugula unsigned offset) const { 117424da7fa0SBharathi Ramana Joshi SmallVector<std::optional<Value>> maybeValues = getMaybeValues(); 117524da7fa0SBharathi Ramana Joshi for (unsigned i = offset, e = maybeValues.size(); i < e; ++i) 117624da7fa0SBharathi Ramana Joshi if (maybeValues[i] && maybeValues[i].value() == val) { 11775b0055a4SMatthias Springer *pos = i; 11785b0055a4SMatthias Springer return true; 11795b0055a4SMatthias Springer } 11805b0055a4SMatthias Springer return false; 11815b0055a4SMatthias Springer } 11825b0055a4SMatthias Springer 11835b0055a4SMatthias Springer bool FlatLinearValueConstraints::containsVar(Value val) const { 118424da7fa0SBharathi Ramana Joshi unsigned pos; 118524da7fa0SBharathi Ramana Joshi return findVar(val, &pos, 0); 11865b0055a4SMatthias Springer } 11875b0055a4SMatthias Springer 11885b0055a4SMatthias Springer void FlatLinearValueConstraints::addBound(BoundType type, Value val, 11895b0055a4SMatthias Springer int64_t value) { 11905b0055a4SMatthias Springer unsigned pos; 11915b0055a4SMatthias Springer if (!findVar(val, &pos)) 11925b0055a4SMatthias Springer // This is a pre-condition for this method. 11935b0055a4SMatthias Springer assert(0 && "var not found"); 11945b0055a4SMatthias Springer addBound(type, pos, value); 11955b0055a4SMatthias Springer } 11965b0055a4SMatthias Springer 11975b0055a4SMatthias Springer void FlatLinearConstraints::printSpace(raw_ostream &os) const { 11985b0055a4SMatthias Springer IntegerPolyhedron::printSpace(os); 11995b0055a4SMatthias Springer os << "("; 12005b0055a4SMatthias Springer for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) 12015b0055a4SMatthias Springer os << "None\t"; 12025b0055a4SMatthias Springer for (unsigned i = getVarKindOffset(VarKind::Local), 12035b0055a4SMatthias Springer e = getVarKindEnd(VarKind::Local); 12045b0055a4SMatthias Springer i < e; ++i) 12055b0055a4SMatthias Springer os << "Local\t"; 12065b0055a4SMatthias Springer os << "const)\n"; 12075b0055a4SMatthias Springer } 12085b0055a4SMatthias Springer 12095b0055a4SMatthias Springer void FlatLinearValueConstraints::printSpace(raw_ostream &os) const { 12105b0055a4SMatthias Springer IntegerPolyhedron::printSpace(os); 12115b0055a4SMatthias Springer os << "("; 12125b0055a4SMatthias Springer for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) { 12135b0055a4SMatthias Springer if (hasValue(i)) 12145b0055a4SMatthias Springer os << "Value\t"; 12155b0055a4SMatthias Springer else 12165b0055a4SMatthias Springer os << "None\t"; 12175b0055a4SMatthias Springer } 12185b0055a4SMatthias Springer for (unsigned i = getVarKindOffset(VarKind::Local), 12195b0055a4SMatthias Springer e = getVarKindEnd(VarKind::Local); 12205b0055a4SMatthias Springer i < e; ++i) 12215b0055a4SMatthias Springer os << "Local\t"; 12225b0055a4SMatthias Springer os << "const)\n"; 12235b0055a4SMatthias Springer } 12245b0055a4SMatthias Springer 12255b0055a4SMatthias Springer void FlatLinearValueConstraints::projectOut(Value val) { 12265b0055a4SMatthias Springer unsigned pos; 12275b0055a4SMatthias Springer bool ret = findVar(val, &pos); 12285b0055a4SMatthias Springer assert(ret); 12295b0055a4SMatthias Springer (void)ret; 12305b0055a4SMatthias Springer fourierMotzkinEliminate(pos); 12315b0055a4SMatthias Springer } 12325b0055a4SMatthias Springer 12335b0055a4SMatthias Springer LogicalResult FlatLinearValueConstraints::unionBoundingBox( 12345b0055a4SMatthias Springer const FlatLinearValueConstraints &otherCst) { 12355b0055a4SMatthias Springer assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch"); 123624da7fa0SBharathi Ramana Joshi SmallVector<std::optional<Value>> maybeValues = getMaybeValues(), 123724da7fa0SBharathi Ramana Joshi otherMaybeValues = 123824da7fa0SBharathi Ramana Joshi otherCst.getMaybeValues(); 123924da7fa0SBharathi Ramana Joshi assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(), 124024da7fa0SBharathi Ramana Joshi otherMaybeValues.begin(), 124124da7fa0SBharathi Ramana Joshi otherMaybeValues.begin() + getNumDimVars()) && 12425b0055a4SMatthias Springer "dim values mismatch"); 12435b0055a4SMatthias Springer assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here"); 12445b0055a4SMatthias Springer assert(getNumLocalVars() == 0 && "local vars not supported yet here"); 12455b0055a4SMatthias Springer 12465b0055a4SMatthias Springer // Align `other` to this. 12475b0055a4SMatthias Springer if (!areVarsAligned(*this, otherCst)) { 12485b0055a4SMatthias Springer FlatLinearValueConstraints otherCopy(otherCst); 12495b0055a4SMatthias Springer mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy); 1250f819302aSRamkumar Ramachandra return IntegerPolyhedron::unionBoundingBox(otherCopy); 12515b0055a4SMatthias Springer } 12525b0055a4SMatthias Springer 1253f819302aSRamkumar Ramachandra return IntegerPolyhedron::unionBoundingBox(otherCst); 12545b0055a4SMatthias Springer } 12555b0055a4SMatthias Springer 12565b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 12575b0055a4SMatthias Springer // Helper functions 12585b0055a4SMatthias Springer //===----------------------------------------------------------------------===// 12595b0055a4SMatthias Springer 12605b0055a4SMatthias Springer AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands, 12615b0055a4SMatthias Springer ValueRange dims, ValueRange syms, 12625b0055a4SMatthias Springer SmallVector<Value> *newSyms) { 12635b0055a4SMatthias Springer assert(operands.size() == map.getNumInputs() && 12645b0055a4SMatthias Springer "expected same number of operands and map inputs"); 12655b0055a4SMatthias Springer MLIRContext *ctx = map.getContext(); 12665b0055a4SMatthias Springer Builder builder(ctx); 12675b0055a4SMatthias Springer SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {}); 12685b0055a4SMatthias Springer unsigned numSymbols = syms.size(); 12695b0055a4SMatthias Springer SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {}); 12705b0055a4SMatthias Springer if (newSyms) { 12715b0055a4SMatthias Springer newSyms->clear(); 12725b0055a4SMatthias Springer newSyms->append(syms.begin(), syms.end()); 12735b0055a4SMatthias Springer } 12745b0055a4SMatthias Springer 12755b0055a4SMatthias Springer for (const auto &operand : llvm::enumerate(operands)) { 12765b0055a4SMatthias Springer // Compute replacement dim/sym of operand. 12775b0055a4SMatthias Springer AffineExpr replacement; 12783f0bddb5SKazu Hirata auto dimIt = llvm::find(dims, operand.value()); 12793f0bddb5SKazu Hirata auto symIt = llvm::find(syms, operand.value()); 12805b0055a4SMatthias Springer if (dimIt != dims.end()) { 12815b0055a4SMatthias Springer replacement = 12825b0055a4SMatthias Springer builder.getAffineDimExpr(std::distance(dims.begin(), dimIt)); 12835b0055a4SMatthias Springer } else if (symIt != syms.end()) { 12845b0055a4SMatthias Springer replacement = 12855b0055a4SMatthias Springer builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt)); 12865b0055a4SMatthias Springer } else { 12875b0055a4SMatthias Springer // This operand is neither a dimension nor a symbol. Add it as a new 12885b0055a4SMatthias Springer // symbol. 12895b0055a4SMatthias Springer replacement = builder.getAffineSymbolExpr(numSymbols++); 12905b0055a4SMatthias Springer if (newSyms) 12915b0055a4SMatthias Springer newSyms->push_back(operand.value()); 12925b0055a4SMatthias Springer } 12935b0055a4SMatthias Springer // Add to corresponding replacements vector. 12945b0055a4SMatthias Springer if (operand.index() < map.getNumDims()) { 12955b0055a4SMatthias Springer dimReplacements[operand.index()] = replacement; 12965b0055a4SMatthias Springer } else { 12975b0055a4SMatthias Springer symReplacements[operand.index() - map.getNumDims()] = replacement; 12985b0055a4SMatthias Springer } 12995b0055a4SMatthias Springer } 13005b0055a4SMatthias Springer 13015b0055a4SMatthias Springer return map.replaceDimsAndSymbols(dimReplacements, symReplacements, 13025b0055a4SMatthias Springer dims.size(), numSymbols); 13035b0055a4SMatthias Springer } 13045b0055a4SMatthias Springer 13055b0055a4SMatthias Springer LogicalResult 13065b0055a4SMatthias Springer mlir::getMultiAffineFunctionFromMap(AffineMap map, 13075b0055a4SMatthias Springer MultiAffineFunction &multiAff) { 13085b0055a4SMatthias Springer FlatLinearConstraints cst; 13095b0055a4SMatthias Springer std::vector<SmallVector<int64_t, 8>> flattenedExprs; 13105b0055a4SMatthias Springer LogicalResult result = getFlattenedAffineExprs(map, &flattenedExprs, &cst); 13115b0055a4SMatthias Springer 13125b0055a4SMatthias Springer if (result.failed()) 13135b0055a4SMatthias Springer return failure(); 13145b0055a4SMatthias Springer 13155b0055a4SMatthias Springer DivisionRepr divs = cst.getLocalReprs(); 13165b0055a4SMatthias Springer assert(divs.hasAllReprs() && 13175b0055a4SMatthias Springer "AffineMap cannot produce divs without local representation"); 13185b0055a4SMatthias Springer 13195b0055a4SMatthias Springer // TODO: We shouldn't have to do this conversion. 13201a0e67d7SRamkumar Ramachandra Matrix<DynamicAPInt> mat(map.getNumResults(), 13211a0e67d7SRamkumar Ramachandra map.getNumInputs() + divs.getNumDivs() + 1); 13225b0055a4SMatthias Springer for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i) 13235b0055a4SMatthias Springer for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j) 13245b0055a4SMatthias Springer mat(i, j) = flattenedExprs[i][j]; 13255b0055a4SMatthias Springer 13265b0055a4SMatthias Springer multiAff = MultiAffineFunction( 13275b0055a4SMatthias Springer PresburgerSpace::getRelationSpace(map.getNumDims(), map.getNumResults(), 13285b0055a4SMatthias Springer map.getNumSymbols(), divs.getNumDivs()), 13295b0055a4SMatthias Springer mat, divs); 13305b0055a4SMatthias Springer 13315b0055a4SMatthias Springer return success(); 13325b0055a4SMatthias Springer } 1333