1 //===- FlatLinearValueConstraints.cpp - Linear Constraint -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/FlatLinearValueConstraints.h" 10 11 #include "mlir/Analysis/Presburger/LinearTransform.h" 12 #include "mlir/Analysis/Presburger/PresburgerSpace.h" 13 #include "mlir/Analysis/Presburger/Simplex.h" 14 #include "mlir/Analysis/Presburger/Utils.h" 15 #include "mlir/IR/AffineExprVisitor.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/IntegerSet.h" 18 #include "mlir/Support/LLVM.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/ADT/SmallPtrSet.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/raw_ostream.h" 24 #include <optional> 25 26 #define DEBUG_TYPE "flat-value-constraints" 27 28 using namespace mlir; 29 using namespace presburger; 30 31 //===----------------------------------------------------------------------===// 32 // AffineExprFlattener 33 //===----------------------------------------------------------------------===// 34 35 namespace { 36 37 // See comments for SimpleAffineExprFlattener. 38 // An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by 39 // recording constraint information associated with mod's, floordiv's, and 40 // ceildiv's in FlatLinearConstraints 'localVarCst'. 41 struct AffineExprFlattener : public SimpleAffineExprFlattener { 42 using SimpleAffineExprFlattener::SimpleAffineExprFlattener; 43 44 // Constraints connecting newly introduced local variables (for mod's and 45 // div's) to existing (dimensional and symbolic) ones. These are always 46 // inequalities. 
47 IntegerPolyhedron localVarCst; 48 49 AffineExprFlattener(unsigned nDims, unsigned nSymbols) 50 : SimpleAffineExprFlattener(nDims, nSymbols), 51 localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}; 52 53 private: 54 // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr). 55 // The local variable added is always a floordiv of a pure add/mul affine 56 // function of other variables, coefficients of which are specified in 57 // `dividend' and with respect to the positive constant `divisor'. localExpr 58 // is the simplified tree expression (AffineExpr) corresponding to the 59 // quantifier. 60 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, 61 AffineExpr localExpr) override { 62 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); 63 // Update localVarCst. 64 localVarCst.addLocalFloorDiv(dividend, divisor); 65 } 66 67 LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, 68 ArrayRef<int64_t> rhs, 69 AffineExpr localExpr) override { 70 // AffineExprFlattener does not support semi-affine expressions. 71 return failure(); 72 } 73 }; 74 75 // A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds 76 // conservative bounds for semi-affine expressions (given assumptions hold). If 77 // the assumptions required to add the semi-affine bounds are found not to hold 78 // the final constraints set will be empty/inconsistent. If the assumptions are 79 // never contradicted the final bounds still only will be correct if the 80 // assumptions hold. 
81 struct SemiAffineExprFlattener : public AffineExprFlattener { 82 using AffineExprFlattener::AffineExprFlattener; 83 84 LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, 85 ArrayRef<int64_t> rhs, 86 AffineExpr localExpr) override { 87 auto result = 88 SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr); 89 assert(succeeded(result) && 90 "unexpected failure in SimpleAffineExprFlattener"); 91 (void)result; 92 93 if (localExpr.getKind() == AffineExprKind::Mod) { 94 // Given two numbers a and b, division is defined as: 95 // 96 // a = bq + r 97 // 0 <= r < |b| (where |x| is the absolute value of x) 98 // 99 // q = a floordiv b 100 // r = a mod b 101 102 // Add a new local variable (r) to represent the mod. 103 unsigned rPos = localVarCst.appendVar(VarKind::Local); 104 105 // r >= 0 (Can ALWAYS be added) 106 localVarCst.addBound(BoundType::LB, rPos, 0); 107 108 // r < b (Can be added if b > 0, which we assume here) 109 ArrayRef<int64_t> b = rhs; 110 SmallVector<int64_t> bSubR(b); 111 bSubR.insert(bSubR.begin() + rPos, -1); 112 // Note: bSubR = b - r 113 // So this adds the bound b - r >= 1 (equivalent to r < b) 114 localVarCst.addBound(BoundType::LB, bSubR, 1); 115 116 // Note: The assumption of b > 0 is based on the affine expression docs, 117 // which state "RHS of mod is always a constant or a symbolic expression 118 // with a positive value." (see AffineExprKind in AffineExpr.h). If this 119 // assumption does not hold constraints (added above) are a contradiction. 120 121 return success(); 122 } 123 124 // TODO: Support other semi-affine expressions. 125 return failure(); 126 } 127 }; 128 129 } // namespace 130 131 // Flattens the expressions in map. Returns failure if 'expr' was unable to be 132 // flattened. For example two specific cases: 133 // 1. an unhandled semi-affine expressions is found. 134 // 2. has poison expression (i.e., division by zero). 
135 static LogicalResult 136 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, 137 unsigned numSymbols, 138 std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 139 FlatLinearConstraints *localVarCst, 140 bool addConservativeSemiAffineBounds = false) { 141 if (exprs.empty()) { 142 if (localVarCst) 143 *localVarCst = FlatLinearConstraints(numDims, numSymbols); 144 return success(); 145 } 146 147 auto flattenExprs = [&](AffineExprFlattener &flattener) -> LogicalResult { 148 // Use the same flattener to simplify each expression successively. This way 149 // local variables / expressions are shared. 150 for (auto expr : exprs) { 151 auto flattenResult = flattener.walkPostOrder(expr); 152 if (failed(flattenResult)) 153 return failure(); 154 } 155 156 assert(flattener.operandExprStack.size() == exprs.size()); 157 flattenedExprs->clear(); 158 flattenedExprs->assign(flattener.operandExprStack.begin(), 159 flattener.operandExprStack.end()); 160 161 if (localVarCst) 162 localVarCst->clearAndCopyFrom(flattener.localVarCst); 163 164 return success(); 165 }; 166 167 if (addConservativeSemiAffineBounds) { 168 SemiAffineExprFlattener flattener(numDims, numSymbols); 169 return flattenExprs(flattener); 170 } 171 172 AffineExprFlattener flattener(numDims, numSymbols); 173 return flattenExprs(flattener); 174 } 175 176 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to 177 // be flattened (an unhandled semi-affine was found). 
178 LogicalResult mlir::getFlattenedAffineExpr( 179 AffineExpr expr, unsigned numDims, unsigned numSymbols, 180 SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst, 181 bool addConservativeSemiAffineBounds) { 182 std::vector<SmallVector<int64_t, 8>> flattenedExprs; 183 LogicalResult ret = 184 ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs, 185 localVarCst, addConservativeSemiAffineBounds); 186 *flattenedExpr = flattenedExprs[0]; 187 return ret; 188 } 189 190 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be 191 /// flattened (i.e., an unhandled semi-affine was found). 192 LogicalResult mlir::getFlattenedAffineExprs( 193 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 194 FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) { 195 if (map.getNumResults() == 0) { 196 if (localVarCst) 197 *localVarCst = 198 FlatLinearConstraints(map.getNumDims(), map.getNumSymbols()); 199 return success(); 200 } 201 return ::getFlattenedAffineExprs( 202 map.getResults(), map.getNumDims(), map.getNumSymbols(), flattenedExprs, 203 localVarCst, addConservativeSemiAffineBounds); 204 } 205 206 LogicalResult mlir::getFlattenedAffineExprs( 207 IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 208 FlatLinearConstraints *localVarCst) { 209 if (set.getNumConstraints() == 0) { 210 if (localVarCst) 211 *localVarCst = 212 FlatLinearConstraints(set.getNumDims(), set.getNumSymbols()); 213 return success(); 214 } 215 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), 216 set.getNumSymbols(), flattenedExprs, 217 localVarCst); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // FlatLinearConstraints 222 //===----------------------------------------------------------------------===// 223 224 // Similar to `composeMap` except that no Values need be associated with the 225 // constraint system nor 
are they looked at -- the dimensions and symbols of 226 // `other` are expected to correspond 1:1 to `this` system. 227 LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) { 228 assert(other.getNumDims() == getNumDimVars() && "dim mismatch"); 229 assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"); 230 231 std::vector<SmallVector<int64_t, 8>> flatExprs; 232 if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs))) 233 return failure(); 234 assert(flatExprs.size() == other.getNumResults()); 235 236 // Add dimensions corresponding to the map's results. 237 insertDimVar(/*pos=*/0, /*num=*/other.getNumResults()); 238 239 // We add one equality for each result connecting the result dim of the map to 240 // the other variables. 241 // E.g.: if the expression is 16*i0 + i1, and this is the r^th 242 // iteration/result of the value map, we are adding the equality: 243 // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we 244 // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. 245 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { 246 const auto &flatExpr = flatExprs[r]; 247 assert(flatExpr.size() >= other.getNumInputs() + 1); 248 249 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); 250 // Set the coefficient for this result to one. 251 eqToAdd[r] = 1; 252 253 // Dims and symbols. 254 for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) { 255 // Negate `eq[r]` since the newly added dimension will be set to this one. 256 eqToAdd[e + i] = -flatExpr[i]; 257 } 258 // Local columns of `eq` are at the beginning. 259 unsigned j = getNumDimVars() + getNumSymbolVars(); 260 unsigned end = flatExpr.size() - 1; 261 for (unsigned i = other.getNumInputs(); i < end; i++, j++) { 262 eqToAdd[j] = -flatExpr[i]; 263 } 264 265 // Constant term. 266 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; 267 268 // Add the equality connecting the result of the map to this constraint set. 
269 addEquality(eqToAdd); 270 } 271 272 return success(); 273 } 274 275 // Determine whether the variable at 'pos' (say var_r) can be expressed as 276 // modulo of another known variable (say var_n) w.r.t a constant. For example, 277 // if the following constraints hold true: 278 // ``` 279 // 0 <= var_r <= divisor - 1 280 // var_n - (divisor * q_expr) = var_r 281 // ``` 282 // where `var_n` is a known variable (called dividend), and `q_expr` is an 283 // `AffineExpr` (called the quotient expression), `var_r` can be written as: 284 // 285 // `var_r = var_n mod divisor`. 286 // 287 // Additionally, in a special case of the above constaints where `q_expr` is an 288 // variable itself that is not yet known (say `var_q`), it can be written as a 289 // floordiv in the following way: 290 // 291 // `var_q = var_n floordiv divisor`. 292 // 293 // First 'num' dimensional variables starting at 'offset' are 294 // derived/to-be-derived in terms of the remaining variables. The remaining 295 // variables are assigned trivial affine expressions in `memo`. For example, 296 // memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2: 297 // memo ==> d0 d1 . . d2 ... 298 // cst ==> c0 c1 c2 c3 c4 ... 299 // 300 // Returns true if the above mod or floordiv are detected, updating 'memo' with 301 // these new expressions. Returns false otherwise. 302 static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos, 303 unsigned offset, unsigned num, int64_t lbConst, 304 int64_t ubConst, MLIRContext *context, 305 SmallVectorImpl<AffineExpr> &memo) { 306 assert(pos < cst.getNumVars() && "invalid position"); 307 308 // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can 309 // be determined. 310 if (lbConst != 0 || ubConst < 1) 311 return false; 312 int64_t divisor = ubConst + 1; 313 314 // Check for the aforementioned conditions in each equality. 
315 for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities(); 316 curEquality < numEqualities; curEquality++) { 317 int64_t coefficientAtPos = cst.atEq64(curEquality, pos); 318 // If current equality does not involve `var_r`, continue to the next 319 // equality. 320 if (coefficientAtPos == 0) 321 continue; 322 323 // Constant term should be 0 in this equality. 324 if (cst.atEq64(curEquality, cst.getNumCols() - 1) != 0) 325 continue; 326 327 // Traverse through the equality and construct the dividend expression 328 // `dividendExpr`, to contain all the variables which are known and are 329 // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the 330 // `dividendExpr` gets simplified into a single variable `var_n` discussed 331 // above. 332 auto dividendExpr = getAffineConstantExpr(0, context); 333 334 // Track the terms that go into quotient expression, later used to detect 335 // additional floordiv. 336 unsigned quotientCount = 0; 337 int quotientPosition = -1; 338 int quotientSign = 1; 339 340 // Consider each term in the current equality. 341 unsigned curVar, e; 342 for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) { 343 // Ignore var_r. 344 if (curVar == pos) 345 continue; 346 int64_t coefficientOfCurVar = cst.atEq64(curEquality, curVar); 347 // Ignore vars that do not contribute to the current equality. 348 if (coefficientOfCurVar == 0) 349 continue; 350 // Check if the current var goes into the quotient expression. 351 if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) { 352 quotientCount++; 353 quotientPosition = curVar; 354 quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1; 355 continue; 356 } 357 // Variables that are part of dividendExpr should be known. 358 if (!memo[curVar]) 359 break; 360 // Append the current variable to the dividend expression. 
361 dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar; 362 } 363 364 // Can't construct expression as it depends on a yet uncomputed var. 365 if (curVar < e) 366 continue; 367 368 // Express `var_r` in terms of the other vars collected so far. 369 if (coefficientAtPos > 0) 370 dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos); 371 else 372 dividendExpr = dividendExpr.floorDiv(-coefficientAtPos); 373 374 // Simplify the expression. 375 dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimVars(), 376 cst.getNumSymbolVars()); 377 // Only if the final dividend expression is just a single var (which we call 378 // `var_n`), we can proceed. 379 // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it 380 // to dims themselves. 381 auto dimExpr = dyn_cast<AffineDimExpr>(dividendExpr); 382 if (!dimExpr) 383 continue; 384 385 // Express `var_r` as `var_n % divisor` and store the expression in `memo`. 386 if (quotientCount >= 1) { 387 // Find the column corresponding to `dimExpr`. `num` columns starting at 388 // `offset` correspond to previously unknown variables. The column 389 // corresponding to the trivially known `dimExpr` can be on either side 390 // of these. 391 unsigned dimExprPos = dimExpr.getPosition(); 392 unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num; 393 auto ub = cst.getConstantBound64(BoundType::UB, dimExprCol); 394 // If `var_n` has an upperbound that is less than the divisor, mod can be 395 // eliminated altogether. 396 if (ub && *ub < divisor) 397 memo[pos] = dimExpr; 398 else 399 memo[pos] = dimExpr % divisor; 400 // If a unique quotient `var_q` was seen, it can be expressed as 401 // `var_n floordiv divisor`. 
402 if (quotientCount == 1 && !memo[quotientPosition]) 403 memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign; 404 405 return true; 406 } 407 } 408 return false; 409 } 410 411 /// Check if the pos^th variable can be expressed as a floordiv of an affine 412 /// function of other variables (where the divisor is a positive constant) 413 /// given the initial set of expressions in `exprs`. If it can be, the 414 /// corresponding position in `exprs` is set as the detected affine expr. For 415 /// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can 416 /// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28 417 /// <= i <= 32q + 31 => q = i floordiv 32. 418 static bool detectAsFloorDiv(const FlatLinearConstraints &cst, unsigned pos, 419 MLIRContext *context, 420 SmallVectorImpl<AffineExpr> &exprs) { 421 assert(pos < cst.getNumVars() && "invalid position"); 422 423 // Get upper-lower bound pair for this variable. 424 SmallVector<bool, 8> foundRepr(cst.getNumVars(), false); 425 for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i) 426 if (exprs[i]) 427 foundRepr[i] = true; 428 429 SmallVector<int64_t, 8> dividend(cst.getNumCols()); 430 unsigned divisor; 431 auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); 432 433 // No upper-lower bound pair found for this var. 434 if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality) 435 return false; 436 437 // Construct the dividend expression. 438 auto dividendExpr = getAffineConstantExpr(dividend.back(), context); 439 for (unsigned c = 0, f = cst.getNumVars(); c < f; c++) 440 if (dividend[c] != 0) 441 dividendExpr = dividendExpr + dividend[c] * exprs[c]; 442 443 // Successfully detected the floordiv. 
444 exprs[pos] = dividendExpr.floorDiv(divisor); 445 return true; 446 } 447 448 std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound( 449 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, 450 ArrayRef<AffineExpr> localExprs, MLIRContext *context, 451 bool closedUB) const { 452 assert(pos + offset < getNumDimVars() && "invalid dim start pos"); 453 assert(symStartPos >= (pos + offset) && "invalid sym start pos"); 454 assert(getNumLocalVars() == localExprs.size() && 455 "incorrect local exprs count"); 456 457 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices; 458 getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices, 459 offset, num); 460 461 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). 462 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) { 463 b.clear(); 464 for (unsigned i = 0, e = a.size(); i < e; ++i) { 465 if (i < offset || i >= offset + num) 466 b.push_back(a[i]); 467 } 468 }; 469 470 SmallVector<int64_t, 8> lb, ub; 471 SmallVector<AffineExpr, 4> lbExprs; 472 unsigned dimCount = symStartPos - num; 473 unsigned symCount = getNumDimAndSymbolVars() - symStartPos; 474 lbExprs.reserve(lbIndices.size() + eqIndices.size()); 475 // Lower bound expressions. 476 for (auto idx : lbIndices) { 477 auto ineq = getInequality64(idx); 478 // Extract the lower bound (in terms of other coeff's + const), i.e., if 479 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j 480 // - 1. 
481 addCoeffs(ineq, lb); 482 std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>()); 483 auto expr = 484 getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context); 485 // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor 486 int64_t divisor = std::abs(ineq[pos + offset]); 487 expr = (expr + divisor - 1).floorDiv(divisor); 488 lbExprs.push_back(expr); 489 } 490 491 SmallVector<AffineExpr, 4> ubExprs; 492 ubExprs.reserve(ubIndices.size() + eqIndices.size()); 493 // Upper bound expressions. 494 for (auto idx : ubIndices) { 495 auto ineq = getInequality64(idx); 496 // Extract the upper bound (in terms of other coeff's + const). 497 addCoeffs(ineq, ub); 498 auto expr = 499 getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context); 500 expr = expr.floorDiv(std::abs(ineq[pos + offset])); 501 int64_t ubAdjustment = closedUB ? 0 : 1; 502 ubExprs.push_back(expr + ubAdjustment); 503 } 504 505 // Equalities. It's both a lower and a upper bound. 506 SmallVector<int64_t, 4> b; 507 for (auto idx : eqIndices) { 508 auto eq = getEquality64(idx); 509 addCoeffs(eq, b); 510 if (eq[pos + offset] > 0) 511 std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>()); 512 513 // Extract the upper bound (in terms of other coeff's + const). 514 auto expr = 515 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context); 516 expr = expr.floorDiv(std::abs(eq[pos + offset])); 517 // Upper bound is exclusive. 518 ubExprs.push_back(expr + 1); 519 // Lower bound. 
520 expr = 521 getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context); 522 expr = expr.ceilDiv(std::abs(eq[pos + offset])); 523 lbExprs.push_back(expr); 524 } 525 526 auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context); 527 auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context); 528 529 return {lbMap, ubMap}; 530 } 531 532 /// Computes the lower and upper bounds of the first 'num' dimensional 533 /// variables (starting at 'offset') as affine maps of the remaining 534 /// variables (dimensional and symbolic variables). Local variables are 535 /// themselves explicitly computed as affine functions of other variables in 536 /// this process if needed. 537 void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num, 538 MLIRContext *context, 539 SmallVectorImpl<AffineMap> *lbMaps, 540 SmallVectorImpl<AffineMap> *ubMaps, 541 bool closedUB) { 542 assert(offset + num <= getNumDimVars() && "invalid range"); 543 544 // Basic simplification. 545 normalizeConstraintsByGCD(); 546 547 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num 548 << " variables\n"); 549 LLVM_DEBUG(dump()); 550 551 // Record computed/detected variables. 552 SmallVector<AffineExpr, 8> memo(getNumVars()); 553 // Initialize dimensional and symbolic variables. 554 for (unsigned i = 0, e = getNumDimVars(); i < e; i++) { 555 if (i < offset) 556 memo[i] = getAffineDimExpr(i, context); 557 else if (i >= offset + num) 558 memo[i] = getAffineDimExpr(i - num, context); 559 } 560 for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++) 561 memo[i] = getAffineSymbolExpr(i - getNumDimVars(), context); 562 563 bool changed; 564 do { 565 changed = false; 566 // Identify yet unknown variables as constants or mod's / floordiv's of 567 // other variables if possible. 
568 for (unsigned pos = 0; pos < getNumVars(); pos++) { 569 if (memo[pos]) 570 continue; 571 572 auto lbConst = getConstantBound64(BoundType::LB, pos); 573 auto ubConst = getConstantBound64(BoundType::UB, pos); 574 if (lbConst.has_value() && ubConst.has_value()) { 575 // Detect equality to a constant. 576 if (*lbConst == *ubConst) { 577 memo[pos] = getAffineConstantExpr(*lbConst, context); 578 changed = true; 579 continue; 580 } 581 582 // Detect a variable as modulo of another variable w.r.t a 583 // constant. 584 if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context, 585 memo)) { 586 changed = true; 587 continue; 588 } 589 } 590 591 // Detect a variable as a floordiv of an affine function of other 592 // variables (divisor is a positive constant). 593 if (detectAsFloorDiv(*this, pos, context, memo)) { 594 changed = true; 595 continue; 596 } 597 598 // Detect a variable as an expression of other variables. 599 unsigned idx; 600 if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) { 601 continue; 602 } 603 604 // Build AffineExpr solving for variable 'pos' in terms of all others. 605 auto expr = getAffineConstantExpr(0, context); 606 unsigned j, e; 607 for (j = 0, e = getNumVars(); j < e; ++j) { 608 if (j == pos) 609 continue; 610 int64_t c = atEq64(idx, j); 611 if (c == 0) 612 continue; 613 // If any of the involved IDs hasn't been found yet, we can't proceed. 614 if (!memo[j]) 615 break; 616 expr = expr + memo[j] * c; 617 } 618 if (j < e) 619 // Can't construct expression as it depends on a yet uncomputed 620 // variable. 621 continue; 622 623 // Add constant term to AffineExpr. 624 expr = expr + atEq64(idx, getNumVars()); 625 int64_t vPos = atEq64(idx, pos); 626 assert(vPos != 0 && "expected non-zero here"); 627 if (vPos > 0) 628 expr = (-expr).floorDiv(vPos); 629 else 630 // vPos < 0. 631 expr = expr.floorDiv(-vPos); 632 // Successfully constructed expression. 
633 memo[pos] = expr; 634 changed = true; 635 } 636 // This loop is guaranteed to reach a fixed point - since once an 637 // variable's explicit form is computed (in memo[pos]), it's not updated 638 // again. 639 } while (changed); 640 641 int64_t ubAdjustment = closedUB ? 0 : 1; 642 643 // Set the lower and upper bound maps for all the variables that were 644 // computed as affine expressions of the rest as the "detected expr" and 645 // "detected expr + 1" respectively; set the undetected ones to null. 646 std::optional<FlatLinearConstraints> tmpClone; 647 for (unsigned pos = 0; pos < num; pos++) { 648 unsigned numMapDims = getNumDimVars() - num; 649 unsigned numMapSymbols = getNumSymbolVars(); 650 AffineExpr expr = memo[pos + offset]; 651 if (expr) 652 expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); 653 654 AffineMap &lbMap = (*lbMaps)[pos]; 655 AffineMap &ubMap = (*ubMaps)[pos]; 656 657 if (expr) { 658 lbMap = AffineMap::get(numMapDims, numMapSymbols, expr); 659 ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment); 660 } else { 661 // TODO: Whenever there are local variables in the dependence 662 // constraints, we'll conservatively over-approximate, since we don't 663 // always explicitly compute them above (in the while loop). 664 if (getNumLocalVars() == 0) { 665 // Work on a copy so that we don't update this constraint system. 666 if (!tmpClone) { 667 tmpClone.emplace(FlatLinearConstraints(*this)); 668 // Removing redundant inequalities is necessary so that we don't get 669 // redundant loop bounds. 670 tmpClone->removeRedundantInequalities(); 671 } 672 std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound( 673 pos, offset, num, getNumDimVars(), /*localExprs=*/{}, context, 674 closedUB); 675 } 676 677 // If the above fails, we'll just use the constant lower bound and the 678 // constant upper bound (if they exist) as the slice bounds. 
679 // TODO: being conservative for the moment in cases that 680 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is 681 // fixed (b/126426796). 682 if (!lbMap || lbMap.getNumResults() > 1) { 683 LLVM_DEBUG(llvm::dbgs() 684 << "WARNING: Potentially over-approximating slice lb\n"); 685 auto lbConst = getConstantBound64(BoundType::LB, pos + offset); 686 if (lbConst.has_value()) { 687 lbMap = AffineMap::get(numMapDims, numMapSymbols, 688 getAffineConstantExpr(*lbConst, context)); 689 } 690 } 691 if (!ubMap || ubMap.getNumResults() > 1) { 692 LLVM_DEBUG(llvm::dbgs() 693 << "WARNING: Potentially over-approximating slice ub\n"); 694 auto ubConst = getConstantBound64(BoundType::UB, pos + offset); 695 if (ubConst.has_value()) { 696 ubMap = AffineMap::get( 697 numMapDims, numMapSymbols, 698 getAffineConstantExpr(*ubConst + ubAdjustment, context)); 699 } 700 } 701 } 702 LLVM_DEBUG(llvm::dbgs() 703 << "lb map for pos = " << Twine(pos + offset) << ", expr: "); 704 LLVM_DEBUG(lbMap.dump();); 705 LLVM_DEBUG(llvm::dbgs() 706 << "ub map for pos = " << Twine(pos + offset) << ", expr: "); 707 LLVM_DEBUG(ubMap.dump();); 708 } 709 } 710 711 LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals( 712 AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, 713 bool addConservativeSemiAffineBounds) { 714 FlatLinearConstraints localCst; 715 if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst, 716 addConservativeSemiAffineBounds))) { 717 LLVM_DEBUG(llvm::dbgs() 718 << "composition unimplemented for semi-affine maps\n"); 719 return failure(); 720 } 721 722 // Add localCst information. 723 if (localCst.getNumLocalVars() > 0) { 724 unsigned numLocalVars = getNumLocalVars(); 725 // Insert local dims of localCst at the beginning. 726 insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars()); 727 // Insert local dims of `this` at the end of localCst. 
728 localCst.appendLocalVar(/*num=*/numLocalVars); 729 // Dimensions of localCst and this constraint set match. Append localCst to 730 // this constraint set. 731 append(localCst); 732 } 733 734 return success(); 735 } 736 737 LogicalResult FlatLinearConstraints::addBound( 738 BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound, 739 AddConservativeSemiAffineBounds addSemiAffineBounds) { 740 assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch"); 741 assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"); 742 assert(pos < getNumDimAndSymbolVars() && "invalid position"); 743 assert((type != BoundType::EQ || isClosedBound) && 744 "EQ bound must be closed."); 745 746 // Equality follows the logic of lower bound except that we add an equality 747 // instead of an inequality. 748 assert((type != BoundType::EQ || boundMap.getNumResults() == 1) && 749 "single result expected"); 750 bool lower = type == BoundType::LB || type == BoundType::EQ; 751 752 std::vector<SmallVector<int64_t, 8>> flatExprs; 753 if (failed(flattenAlignedMapAndMergeLocals( 754 boundMap, &flatExprs, 755 addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes))) 756 return failure(); 757 assert(flatExprs.size() == boundMap.getNumResults()); 758 759 // Add one (in)equality for each result. 760 for (const auto &flatExpr : flatExprs) { 761 SmallVector<int64_t> ineq(getNumCols(), 0); 762 // Dims and symbols. 763 for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { 764 ineq[j] = lower ? -flatExpr[j] : flatExpr[j]; 765 } 766 // Invalid bound: pos appears in `boundMap`. 767 // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or 768 // its callers to prevent invalid bounds from being added. 769 if (ineq[pos] != 0) 770 continue; 771 ineq[pos] = lower ? 1 : -1; 772 // Local columns of `ineq` are at the beginning. 
773 unsigned j = getNumDimVars() + getNumSymbolVars(); 774 unsigned end = flatExpr.size() - 1; 775 for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) { 776 ineq[j] = lower ? -flatExpr[i] : flatExpr[i]; 777 } 778 // Make the bound closed in if flatExpr is open. The inequality is always 779 // created in the upper bound form, so the adjustment is -1. 780 int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1; 781 // Constant term. 782 ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1] 783 : flatExpr[flatExpr.size() - 1]) + 784 boundAdjustment; 785 type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq); 786 } 787 788 return success(); 789 } 790 791 LogicalResult FlatLinearConstraints::addBound( 792 BoundType type, unsigned pos, AffineMap boundMap, 793 AddConservativeSemiAffineBounds addSemiAffineBounds) { 794 return addBound(type, pos, boundMap, 795 /*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds); 796 } 797 798 /// Compute an explicit representation for local vars. For all systems coming 799 /// from MLIR integer sets, maps, or expressions where local vars were 800 /// introduced to model floordivs and mods, this always succeeds. 801 LogicalResult 802 FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo, 803 MLIRContext *context) const { 804 unsigned numDims = getNumDimVars(); 805 unsigned numSyms = getNumSymbolVars(); 806 807 // Initialize dimensional and symbolic variables. 808 for (unsigned i = 0; i < numDims; i++) 809 memo[i] = getAffineDimExpr(i, context); 810 for (unsigned i = numDims, e = numDims + numSyms; i < e; i++) 811 memo[i] = getAffineSymbolExpr(i - numDims, context); 812 813 bool changed; 814 do { 815 // Each time `changed` is true at the end of this iteration, one or more 816 // local vars would have been detected as floordivs and set in memo; so the 817 // number of null entries in memo[...] strictly reduces; so this converges. 
818 changed = false; 819 for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) 820 if (!memo[numDims + numSyms + i] && 821 detectAsFloorDiv(*this, /*pos=*/numDims + numSyms + i, context, memo)) 822 changed = true; 823 } while (changed); 824 825 ArrayRef<AffineExpr> localExprs = 826 ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars()); 827 return success( 828 llvm::all_of(localExprs, [](AffineExpr expr) { return expr; })); 829 } 830 831 IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const { 832 if (getNumConstraints() == 0) 833 // Return universal set (always true): 0 == 0. 834 return IntegerSet::get(getNumDimVars(), getNumSymbolVars(), 835 getAffineConstantExpr(/*constant=*/0, context), 836 /*eqFlags=*/true); 837 838 // Construct local references. 839 SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); 840 841 if (failed(computeLocalVars(memo, context))) { 842 // Check if the local variables without an explicit representation have 843 // zero coefficients everywhere. 844 SmallVector<unsigned> noLocalRepVars; 845 unsigned numDimsSymbols = getNumDimAndSymbolVars(); 846 for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) { 847 if (!memo[i] && !isColZero(/*pos=*/i)) 848 noLocalRepVars.push_back(i - numDimsSymbols); 849 } 850 if (!noLocalRepVars.empty()) { 851 LLVM_DEBUG({ 852 llvm::dbgs() << "local variables at position(s) "; 853 llvm::interleaveComma(noLocalRepVars, llvm::dbgs()); 854 llvm::dbgs() << " do not have an explicit representation in:\n"; 855 this->dump(); 856 }); 857 return IntegerSet(); 858 } 859 } 860 861 ArrayRef<AffineExpr> localExprs = 862 ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars()); 863 864 // Construct the IntegerSet from the equalities/inequalities. 
865 unsigned numDims = getNumDimVars(); 866 unsigned numSyms = getNumSymbolVars(); 867 868 SmallVector<bool, 16> eqFlags(getNumConstraints()); 869 std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true); 870 std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false); 871 872 SmallVector<AffineExpr, 8> exprs; 873 exprs.reserve(getNumConstraints()); 874 875 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) 876 exprs.push_back(getAffineExprFromFlatForm(getEquality64(i), numDims, 877 numSyms, localExprs, context)); 878 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) 879 exprs.push_back(getAffineExprFromFlatForm(getInequality64(i), numDims, 880 numSyms, localExprs, context)); 881 return IntegerSet::get(numDims, numSyms, exprs, eqFlags); 882 } 883 884 //===----------------------------------------------------------------------===// 885 // FlatLinearValueConstraints 886 //===----------------------------------------------------------------------===// 887 888 // Construct from an IntegerSet. 889 FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set, 890 ValueRange operands) 891 : FlatLinearConstraints(set.getNumInequalities(), set.getNumEqualities(), 892 set.getNumDims() + set.getNumSymbols() + 1, 893 set.getNumDims(), set.getNumSymbols(), 894 /*numLocals=*/0) { 895 assert((operands.empty() || set.getNumInputs() == operands.size()) && 896 "operand count mismatch"); 897 // Set the values for the non-local variables. 898 for (unsigned i = 0, e = operands.size(); i < e; ++i) 899 setValue(i, operands[i]); 900 901 // Flatten expressions and add them to the constraint system. 
902 std::vector<SmallVector<int64_t, 8>> flatExprs; 903 FlatLinearConstraints localVarCst; 904 if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) { 905 assert(false && "flattening unimplemented for semi-affine integer sets"); 906 return; 907 } 908 assert(flatExprs.size() == set.getNumConstraints()); 909 insertVar(VarKind::Local, getNumVarKind(VarKind::Local), 910 /*num=*/localVarCst.getNumLocalVars()); 911 912 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { 913 const auto &flatExpr = flatExprs[i]; 914 assert(flatExpr.size() == getNumCols()); 915 if (set.getEqFlags()[i]) { 916 addEquality(flatExpr); 917 } else { 918 addInequality(flatExpr); 919 } 920 } 921 // Add the other constraints involving local vars from flattening. 922 append(localVarCst); 923 } 924 925 unsigned FlatLinearValueConstraints::appendDimVar(ValueRange vals) { 926 unsigned pos = getNumDimVars(); 927 return insertVar(VarKind::SetDim, pos, vals); 928 } 929 930 unsigned FlatLinearValueConstraints::appendSymbolVar(ValueRange vals) { 931 unsigned pos = getNumSymbolVars(); 932 return insertVar(VarKind::Symbol, pos, vals); 933 } 934 935 unsigned FlatLinearValueConstraints::insertDimVar(unsigned pos, 936 ValueRange vals) { 937 return insertVar(VarKind::SetDim, pos, vals); 938 } 939 940 unsigned FlatLinearValueConstraints::insertSymbolVar(unsigned pos, 941 ValueRange vals) { 942 return insertVar(VarKind::Symbol, pos, vals); 943 } 944 945 unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos, 946 unsigned num) { 947 unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num); 948 949 return absolutePos; 950 } 951 952 unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos, 953 ValueRange vals) { 954 assert(!vals.empty() && "expected ValueRange with Values."); 955 assert(kind != VarKind::Local && 956 "values cannot be attached to local variables."); 957 unsigned num = vals.size(); 958 unsigned absolutePos = 
IntegerPolyhedron::insertVar(kind, pos, num); 959 960 // If a Value is provided, insert it; otherwise use std::nullopt. 961 for (unsigned i = 0, e = vals.size(); i < e; ++i) 962 if (vals[i]) 963 setValue(absolutePos + i, vals[i]); 964 965 return absolutePos; 966 } 967 968 /// Checks if two constraint systems are in the same space, i.e., if they are 969 /// associated with the same set of variables, appearing in the same order. 970 static bool areVarsAligned(const FlatLinearValueConstraints &a, 971 const FlatLinearValueConstraints &b) { 972 if (a.getNumDomainVars() != b.getNumDomainVars() || 973 a.getNumRangeVars() != b.getNumRangeVars() || 974 a.getNumSymbolVars() != b.getNumSymbolVars()) 975 return false; 976 SmallVector<std::optional<Value>> aMaybeValues = a.getMaybeValues(), 977 bMaybeValues = b.getMaybeValues(); 978 return std::equal(aMaybeValues.begin(), aMaybeValues.end(), 979 bMaybeValues.begin(), bMaybeValues.end()); 980 } 981 982 /// Calls areVarsAligned to check if two constraint systems have the same set 983 /// of variables in the same order. 984 bool FlatLinearValueConstraints::areVarsAlignedWithOther( 985 const FlatLinearConstraints &other) { 986 return areVarsAligned(*this, other); 987 } 988 989 /// Checks if the SSA values associated with `cst`'s variables in range 990 /// [start, end) are unique. 
991 static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( 992 const FlatLinearValueConstraints &cst, unsigned start, unsigned end) { 993 994 assert(start <= cst.getNumDimAndSymbolVars() && 995 "Start position out of bounds"); 996 assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds"); 997 998 if (start >= end) 999 return true; 1000 1001 SmallPtrSet<Value, 8> uniqueVars; 1002 SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues(); 1003 ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start, 1004 maybeValuesAll.data() + end}; 1005 1006 for (std::optional<Value> val : maybeValues) 1007 if (val && !uniqueVars.insert(*val).second) 1008 return false; 1009 1010 return true; 1011 } 1012 1013 /// Checks if the SSA values associated with `cst`'s variables are unique. 1014 static bool LLVM_ATTRIBUTE_UNUSED 1015 areVarsUnique(const FlatLinearValueConstraints &cst) { 1016 return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars()); 1017 } 1018 1019 /// Checks if the SSA values associated with `cst`'s variables of kind `kind` 1020 /// are unique. 1021 static bool LLVM_ATTRIBUTE_UNUSED 1022 areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) { 1023 1024 if (kind == VarKind::SetDim) 1025 return areVarsUnique(cst, 0, cst.getNumDimVars()); 1026 if (kind == VarKind::Symbol) 1027 return areVarsUnique(cst, cst.getNumDimVars(), 1028 cst.getNumDimAndSymbolVars()); 1029 llvm_unreachable("Unexpected VarKind"); 1030 } 1031 1032 /// Merge and align the variables of A and B starting at 'offset', so that 1033 /// both constraint systems get the union of the contained variables that is 1034 /// dimension-wise and symbol-wise unique; both constraint systems are updated 1035 /// so that they have the union of all variables, with A's original 1036 /// variables appearing first followed by any of B's variables that didn't 1037 /// appear in A. 
/// Local variables in B that have the same division
/// representation as local variables in A are merged into one (this is done by
/// `mergeLocalVars`). We allow A and B to have non-unique values for their
/// variables; in such cases, they are still aligned with the variables
/// appearing first aligned with those appearing first in the other system from
/// left to right.
///
/// Dimensions at positions [0, offset) are treated as already aligned: the
/// dimension-merging loop below only touches positions from `offset` on.
/// Preconditions (asserted): `offset` does not exceed either system's number
/// of dimensions, and every entry of getMaybeValues() from position `offset`
/// on has an attached Value in both A and B.
// E.g.: Input:  A has (%i, %j) [%M, %N] and B has (%k, %j) [%P, %N, %M]
//       Output: both A, B have (%i, %j, %k) [%M, %N, %P]
static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
                              FlatLinearValueConstraints *b) {
  assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());

  assert(llvm::all_of(
      llvm::drop_begin(a->getMaybeValues(), offset),
      [](const std::optional<Value> &var) { return var.has_value(); }));

  assert(llvm::all_of(
      llvm::drop_begin(b->getMaybeValues(), offset),
      [](const std::optional<Value> &var) { return var.has_value(); }));

  // A's dimension Values at positions [offset, numDims).
  SmallVector<Value, 4> aDimValues;
  a->getValues(offset, a->getNumDimVars(), &aDimValues);

  {
    // Merge dims from A into B.
    unsigned d = offset;
    for (Value aDimValue : aDimValues) {
      unsigned loc;
      // Find from the position `d` since we'd like to also consider the
      // possibility of multiple variables with the same `Value`. We align with
      // the next appearing one.
      if (b->findVar(aDimValue, &loc, d)) {
        // The search starts at `d` >= offset, so a match can never lie in the
        // pre-aligned prefix; it must also be a dimension of B.
        assert(loc >= offset && "A's dim appears in B's aligned range");
        assert(loc < b->getNumDimVars() &&
               "A's dim appears in B's non-dim position");
        b->swapVar(d, loc);
      } else {
        // Not present in B: insert it into B at the same position as in A.
        b->insertDimVar(d, aDimValue);
      }
      d++;
    }
    // Dimensions that are in B, but not in A, are added at the end.
    // `e` is read once; only A's dimension count grows inside this loop.
    for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) {
      a->appendDimVar(b->getValue(t));
    }
    assert(a->getNumDimVars() == b->getNumDimVars() &&
           "expected same number of dims");
  }

  // Merge and align symbols of A and B
  a->mergeSymbolVars(*b);
  // Merge and align locals of A and B
  a->mergeLocalVars(*b);

  assert(areVarsAligned(*a, *b) && "IDs expected to be aligned");
}

// Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'.
// Both `this` and `*other` may be modified.
void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
    unsigned offset, FlatLinearValueConstraints *other) {
  mergeAndAlignVars(offset, this, other);
}

/// Merge and align symbols of `this` and `other` such that both get the union
/// of symbols. Existing symbols need not be unique; they will be aligned from
/// left to right with duplicates aligned in the same order. Symbols with Value
/// as `None` are considered to be unequal to all other symbols.
/// NOTE(review): the symbol Values of `this` are collected through getValues()
/// below; confirm that it tolerates symbols without an attached Value before
/// relying on the `None` behavior described above.
void FlatLinearValueConstraints::mergeSymbolVars(
    FlatLinearValueConstraints &other) {

  SmallVector<Value, 4> aSymValues;
  getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);

  // Merge symbols: merge symbols into `other` first from `this`.
  unsigned s = other.getNumDimVars();
  for (Value aSymValue : aSymValues) {
    unsigned loc;
    // If the var is a symbol in `other`, then align it, otherwise assume that
    // it is a new symbol. Search in `other` starting at position `s` since the
    // left of it is aligned.
    if (other.findVar(aSymValue, &loc, s) && loc >= other.getNumDimVars() &&
        loc < other.getNumDimAndSymbolVars())
      other.swapVar(s, loc);
    else
      other.insertSymbolVar(s - other.getNumDimVars(), aSymValue);
    s++;
  }

  // Symbols that are in other, but not in this, are added at the end.
  // After the loop above, `other`'s first getNumSymbolVars() symbols match
  // this system's symbols, so the unmatched ones start right after them. The
  // start of `t` is computed once, while getNumSymbolVars() in the body is
  // re-read each iteration so every new symbol is appended at the end.
  for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
                e = other.getNumDimAndSymbolVars();
       t < e; t++)
    insertSymbolVar(getNumSymbolVars(), other.getValue(t));

  assert(getNumSymbolVars() == other.getNumSymbolVars() &&
         "expected same number of symbols");
}

/// Removes variables of kind `kind` in the range [varStart, varLimit) by
/// forwarding to IntegerPolyhedron::removeVarRange.
void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
                                                unsigned varLimit) {
  IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
}

/// Rewrites `map`, whose inputs are `operands`, so that its dims and symbols
/// correspond to this system's dimension and symbol variables. The alignment
/// is done via alignAffineMapWithValues. Variables without an attached Value
/// are passed as a null Value. In debug builds this asserts that no operand
/// had to be introduced as a new symbol.
AffineMap
FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
                                              ValueRange operands) const {
  assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");

  SmallVector<Value> dims, syms;
  // The list of resulting symbols is only needed for the assertions below, so
  // it is only collected in debug builds.
#ifndef NDEBUG
  SmallVector<Value> newSyms;
  SmallVector<Value> *newSymsPtr = &newSyms;
#else
  SmallVector<Value> *newSymsPtr = nullptr;
#endif // NDEBUG

  dims.reserve(getNumDimVars());
  syms.reserve(getNumSymbolVars());
  for (unsigned i = 0, e = getNumVarKind(VarKind::SetDim); i < e; ++i) {
    Identifier id = space.getId(VarKind::SetDim, i);
    dims.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
  }
  for (unsigned i = 0, e = getNumVarKind(VarKind::Symbol); i < e; ++i) {
    Identifier id = space.getId(VarKind::Symbol, i);
    syms.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
  }

  AffineMap alignedMap =
      alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
  // All symbols are already part of this FlatLinearValueConstraints.
  assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
  assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
         "unexpected new/missing symbols");
  return alignedMap;
}

/// Searches positions [offset, getMaybeValues().size()) for the first
/// variable whose attached Value equals `val`. On success, stores that
/// position in `*pos` and returns true. Otherwise returns false and leaves
/// `*pos` unchanged.
bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
                                         unsigned offset) const {
  SmallVector<std::optional<Value>> maybeValues = getMaybeValues();
  for (unsigned i = offset, e = maybeValues.size(); i < e; ++i)
    if (maybeValues[i] && maybeValues[i].value() == val) {
      *pos = i;
      return true;
    }
  return false;
}

/// Returns true if some variable of this system has `val` attached.
bool FlatLinearValueConstraints::containsVar(Value val) const {
  unsigned pos;
  return findVar(val, &pos, 0);
}

/// Adds a constant bound of kind `type` on the variable that has `val`
/// attached. That variable must exist.
/// NOTE(review): the existence check is only an assert. In NDEBUG builds a
/// missing variable leaves `pos` uninitialized before it is used.
void FlatLinearValueConstraints::addBound(BoundType type, Value val,
                                          int64_t value) {
  unsigned pos;
  if (!findVar(val, &pos))
    // This is a pre-condition for this method.
    assert(0 && "var not found");
  addBound(type, pos, value);
}

/// Prints the space, then one column tag per variable. Dimension and symbol
/// variables are printed as "None", local variables as "Local", and the
/// constant column as "const".
void FlatLinearConstraints::printSpace(raw_ostream &os) const {
  IntegerPolyhedron::printSpace(os);
  os << "(";
  for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++)
    os << "None\t";
  for (unsigned i = getVarKindOffset(VarKind::Local),
                e = getVarKindEnd(VarKind::Local);
       i < e; ++i)
    os << "Local\t";
  os << "const)\n";
}

/// Like FlatLinearConstraints::printSpace, but dimension and symbol variables
/// are tagged "Value" when they have an attached Value and "None" otherwise.
void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
  IntegerPolyhedron::printSpace(os);
  os << "(";
  for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
    if (hasValue(i))
      os << "Value\t";
    else
      os << "None\t";
  }
  for (unsigned i = getVarKindOffset(VarKind::Local),
                e = getVarKindEnd(VarKind::Local);
       i < e; ++i)
    os << "Local\t";
  os << "const)\n";
}

/// Eliminates the variable that has `val` attached, using Fourier-Motzkin
/// elimination. The variable must exist; this is asserted in debug builds
/// only.
void FlatLinearValueConstraints::projectOut(Value val) {
  unsigned pos;
  bool ret = findVar(val, &pos);
  assert(ret);
  (void)ret;
  fourierMotzkinEliminate(pos);
}

/// Replaces this system with the bounding-box union of itself and `otherCst`.
/// Both systems must share the same dimension Values, and neither may have
/// local variables (asserted). If the variables are not already aligned,
/// `otherCst` is copied and aligned against `this`. That alignment also
/// merges the copy's symbols into `this`.
LogicalResult FlatLinearValueConstraints::unionBoundingBox(
    const FlatLinearValueConstraints &otherCst) {
  assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
  SmallVector<std::optional<Value>> maybeValues = getMaybeValues(),
                                    otherMaybeValues =
                                        otherCst.getMaybeValues();
  assert(std::equal(maybeValues.begin(), maybeValues.begin() + getNumDimVars(),
                    otherMaybeValues.begin(),
                    otherMaybeValues.begin() + getNumDimVars()) &&
         "dim values mismatch");
  assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
  assert(getNumLocalVars() == 0 && "local vars not supported yet here");

  // Align `other` to this.
  if (!areVarsAligned(*this, otherCst)) {
    // `otherCst` is const, so align a copy instead.
    FlatLinearValueConstraints otherCopy(otherCst);
    mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
    return IntegerPolyhedron::unionBoundingBox(otherCopy);
  }

  return IntegerPolyhedron::unionBoundingBox(otherCst);
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

/// Rewrites `map`, whose inputs are `operands`, over the dimension list `dims`
/// and the symbol list `syms`. Each operand is resolved as follows:
///   * If it appears in `dims`, it maps to the dim of its first occurrence.
///   * Otherwise, if it appears in `syms`, it maps to the symbol of its first
///     occurrence.
///   * Otherwise, it becomes a fresh symbol numbered after `syms`.
/// If `newSyms` is non-null, it is reset to `syms` followed by every operand
/// that became a fresh symbol.
/// NOTE(review): an operand that occurs more than once without appearing in
/// `dims` or `syms` gets a separate fresh symbol at each occurrence.
AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
                                         ValueRange dims, ValueRange syms,
                                         SmallVector<Value> *newSyms) {
  assert(operands.size() == map.getNumInputs() &&
         "expected same number of operands and map inputs");
  MLIRContext *ctx = map.getContext();
  Builder builder(ctx);
  SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
  unsigned numSymbols = syms.size();
  SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
  if (newSyms) {
    newSyms->clear();
    newSyms->append(syms.begin(), syms.end());
  }

  for (const auto &operand : llvm::enumerate(operands)) {
    // Compute replacement dim/sym of operand.
    AffineExpr replacement;
    auto dimIt = llvm::find(dims, operand.value());
    auto symIt = llvm::find(syms, operand.value());
    if (dimIt != dims.end()) {
      replacement =
          builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
    } else if (symIt != syms.end()) {
      replacement =
          builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
    } else {
      // This operand is neither a dimension nor a symbol. Add it as a new
      // symbol.
      replacement = builder.getAffineSymbolExpr(numSymbols++);
      if (newSyms)
        newSyms->push_back(operand.value());
    }
    // Add to corresponding replacements vector. The first map.getNumDims()
    // operands feed the map's dims; the rest feed its symbols.
    if (operand.index() < map.getNumDims()) {
      dimReplacements[operand.index()] = replacement;
    } else {
      symReplacements[operand.index() - map.getNumDims()] = replacement;
    }
  }

  return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                   dims.size(), numSymbols);
}

/// Converts `map` into a MultiAffineFunction stored in `multiAff`. The map is
/// flattened, and local variables introduced by the flattening become
/// divisions. Returns failure if flattening fails.
LogicalResult
mlir::getMultiAffineFunctionFromMap(AffineMap map,
                                    MultiAffineFunction &multiAff) {
  FlatLinearConstraints cst;
  std::vector<SmallVector<int64_t, 8>> flattenedExprs;
  LogicalResult result = getFlattenedAffineExprs(map, &flattenedExprs, &cst);

  if (result.failed())
    return failure();

  DivisionRepr divs = cst.getLocalReprs();
  assert(divs.hasAllReprs() &&
         "AffineMap cannot produce divs without local representation");

  // TODO: We shouldn't have to do this conversion.
  // One row per map result. The columns are the map inputs, then the divs,
  // then the constant.
  Matrix<DynamicAPInt> mat(map.getNumResults(),
                           map.getNumInputs() + divs.getNumDivs() + 1);
  for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
    for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
      mat(i, j) = flattenedExprs[i][j];

  multiAff = MultiAffineFunction(
      PresburgerSpace::getRelationSpace(map.getNumDims(), map.getNumResults(),
                                        map.getNumSymbols(), divs.getNumDivs()),
      mat, divs);

  return success();
}