1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <cmath> 10 #include <cstdint> 11 #include <limits> 12 #include <utility> 13 14 #include "AffineExprDetail.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineExprVisitor.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/IntegerSet.h" 19 #include "mlir/Support/TypeID.h" 20 #include "llvm/ADT/STLExtras.h" 21 #include "llvm/Support/MathExtras.h" 22 #include <numeric> 23 #include <optional> 24 25 using namespace mlir; 26 using namespace mlir::detail; 27 28 using llvm::divideCeilSigned; 29 using llvm::divideFloorSigned; 30 using llvm::divideSignedWouldOverflow; 31 using llvm::mod; 32 33 MLIRContext *AffineExpr::getContext() const { return expr->context; } 34 35 AffineExprKind AffineExpr::getKind() const { return expr->kind; } 36 37 /// Walk all of the AffineExprs in `e` in postorder. This is a private factory 38 /// method to help handle lambda walk functions. Users should use the regular 39 /// (non-static) `walk` method. 40 template <typename WalkRetTy> 41 WalkRetTy mlir::AffineExpr::walk(AffineExpr e, 42 function_ref<WalkRetTy(AffineExpr)> callback) { 43 struct AffineExprWalker 44 : public AffineExprVisitor<AffineExprWalker, WalkRetTy> { 45 function_ref<WalkRetTy(AffineExpr)> callback; 46 47 AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback) 48 : callback(callback) {} 49 50 WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { 51 return callback(expr); 52 } 53 WalkRetTy visitConstantExpr(AffineConstantExpr expr) { 54 return callback(expr); 55 } 56 WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); } 57 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); } 58 }; 59 60 return AffineExprWalker(callback).walkPostOrder(e); 61 } 62 // Explicitly instantiate for the two supported return types. 63 template void mlir::AffineExpr::walk(AffineExpr e, 64 function_ref<void(AffineExpr)> callback); 65 template WalkResult 66 mlir::AffineExpr::walk(AffineExpr e, 67 function_ref<WalkResult(AffineExpr)> callback); 68 69 // Dispatch affine expression construction based on kind. 70 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, 71 AffineExpr rhs) { 72 if (kind == AffineExprKind::Add) 73 return lhs + rhs; 74 if (kind == AffineExprKind::Mul) 75 return lhs * rhs; 76 if (kind == AffineExprKind::FloorDiv) 77 return lhs.floorDiv(rhs); 78 if (kind == AffineExprKind::CeilDiv) 79 return lhs.ceilDiv(rhs); 80 if (kind == AffineExprKind::Mod) 81 return lhs % rhs; 82 83 llvm_unreachable("unknown binary operation on affine expressions"); 84 } 85 86 /// This method substitutes any uses of dimensions and symbols (e.g. 87 /// dim#0 with dimReplacements[0]) and returns the modified expression tree. 88 AffineExpr 89 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 90 ArrayRef<AffineExpr> symReplacements) const { 91 switch (getKind()) { 92 case AffineExprKind::Constant: 93 return *this; 94 case AffineExprKind::DimId: { 95 unsigned dimId = llvm::cast<AffineDimExpr>(*this).getPosition(); 96 if (dimId >= dimReplacements.size()) 97 return *this; 98 return dimReplacements[dimId]; 99 } 100 case AffineExprKind::SymbolId: { 101 unsigned symId = llvm::cast<AffineSymbolExpr>(*this).getPosition(); 102 if (symId >= symReplacements.size()) 103 return *this; 104 return symReplacements[symId]; 105 } 106 case AffineExprKind::Add: 107 case AffineExprKind::Mul: 108 case AffineExprKind::FloorDiv: 109 case AffineExprKind::CeilDiv: 110 case AffineExprKind::Mod: 111 auto binOp = llvm::cast<AffineBinaryOpExpr>(*this); 112 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 113 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 114 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 115 if (newLHS == lhs && newRHS == rhs) 116 return *this; 117 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 118 } 119 llvm_unreachable("Unknown AffineExpr"); 120 } 121 122 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const { 123 return replaceDimsAndSymbols(dimReplacements, {}); 124 } 125 126 AffineExpr 127 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const { 128 return replaceDimsAndSymbols({}, symReplacements); 129 } 130 131 /// Replace dims[offset ... numDims) 132 /// by dims[offset + shift ... shift + numDims). 133 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift, 134 unsigned offset) const { 135 SmallVector<AffineExpr, 4> dims; 136 for (unsigned idx = 0; idx < offset; ++idx) 137 dims.push_back(getAffineDimExpr(idx, getContext())); 138 for (unsigned idx = offset; idx < numDims; ++idx) 139 dims.push_back(getAffineDimExpr(idx + shift, getContext())); 140 return replaceDimsAndSymbols(dims, {}); 141 } 142 143 /// Replace symbols[offset ... numSymbols) 144 /// by symbols[offset + shift ... shift + numSymbols). 145 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift, 146 unsigned offset) const { 147 SmallVector<AffineExpr, 4> symbols; 148 for (unsigned idx = 0; idx < offset; ++idx) 149 symbols.push_back(getAffineSymbolExpr(idx, getContext())); 150 for (unsigned idx = offset; idx < numSymbols; ++idx) 151 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext())); 152 return replaceDimsAndSymbols({}, symbols); 153 } 154 155 /// Sparse replace method. Return the modified expression tree. 156 AffineExpr 157 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 158 auto it = map.find(*this); 159 if (it != map.end()) 160 return it->second; 161 switch (getKind()) { 162 default: 163 return *this; 164 case AffineExprKind::Add: 165 case AffineExprKind::Mul: 166 case AffineExprKind::FloorDiv: 167 case AffineExprKind::CeilDiv: 168 case AffineExprKind::Mod: 169 auto binOp = llvm::cast<AffineBinaryOpExpr>(*this); 170 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 171 auto newLHS = lhs.replace(map); 172 auto newRHS = rhs.replace(map); 173 if (newLHS == lhs && newRHS == rhs) 174 return *this; 175 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 176 } 177 llvm_unreachable("Unknown AffineExpr"); 178 } 179 180 /// Sparse replace method. Return the modified expression tree. 181 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const { 182 DenseMap<AffineExpr, AffineExpr> map; 183 map.insert(std::make_pair(expr, replacement)); 184 return replace(map); 185 } 186 /// Returns true if this expression is made out of only symbols and 187 /// constants (no dimensional identifiers). 188 bool AffineExpr::isSymbolicOrConstant() const { 189 switch (getKind()) { 190 case AffineExprKind::Constant: 191 return true; 192 case AffineExprKind::DimId: 193 return false; 194 case AffineExprKind::SymbolId: 195 return true; 196 197 case AffineExprKind::Add: 198 case AffineExprKind::Mul: 199 case AffineExprKind::FloorDiv: 200 case AffineExprKind::CeilDiv: 201 case AffineExprKind::Mod: { 202 auto expr = llvm::cast<AffineBinaryOpExpr>(*this); 203 return expr.getLHS().isSymbolicOrConstant() && 204 expr.getRHS().isSymbolicOrConstant(); 205 } 206 } 207 llvm_unreachable("Unknown AffineExpr"); 208 } 209 210 /// Returns true if this is a pure affine expression, i.e., multiplication, 211 /// floordiv, ceildiv, and mod is only allowed w.r.t constants. 212 bool AffineExpr::isPureAffine() const { 213 switch (getKind()) { 214 case AffineExprKind::SymbolId: 215 case AffineExprKind::DimId: 216 case AffineExprKind::Constant: 217 return true; 218 case AffineExprKind::Add: { 219 auto op = llvm::cast<AffineBinaryOpExpr>(*this); 220 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine(); 221 } 222 223 case AffineExprKind::Mul: { 224 // TODO: Canonicalize the constants in binary operators to the RHS when 225 // possible, allowing this to merge into the next case. 226 auto op = llvm::cast<AffineBinaryOpExpr>(*this); 227 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() && 228 (llvm::isa<AffineConstantExpr>(op.getLHS()) || 229 llvm::isa<AffineConstantExpr>(op.getRHS())); 230 } 231 case AffineExprKind::FloorDiv: 232 case AffineExprKind::CeilDiv: 233 case AffineExprKind::Mod: { 234 auto op = llvm::cast<AffineBinaryOpExpr>(*this); 235 return op.getLHS().isPureAffine() && 236 llvm::isa<AffineConstantExpr>(op.getRHS()); 237 } 238 } 239 llvm_unreachable("Unknown AffineExpr"); 240 } 241 242 // Returns the greatest known integral divisor of this affine expression. 243 int64_t AffineExpr::getLargestKnownDivisor() const { 244 AffineBinaryOpExpr binExpr(nullptr); 245 switch (getKind()) { 246 case AffineExprKind::DimId: 247 [[fallthrough]]; 248 case AffineExprKind::SymbolId: 249 return 1; 250 case AffineExprKind::CeilDiv: 251 [[fallthrough]]; 252 case AffineExprKind::FloorDiv: { 253 // If the RHS is a constant and divides the known divisor on the LHS, the 254 // quotient is a known divisor of the expression. 255 binExpr = llvm::cast<AffineBinaryOpExpr>(*this); 256 auto rhs = llvm::dyn_cast<AffineConstantExpr>(binExpr.getRHS()); 257 // Leave alone undefined expressions. 258 if (rhs && rhs.getValue() != 0) { 259 int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor(); 260 if (lhsDiv % rhs.getValue() == 0) 261 return std::abs(lhsDiv / rhs.getValue()); 262 } 263 return 1; 264 } 265 case AffineExprKind::Constant: 266 return std::abs(llvm::cast<AffineConstantExpr>(*this).getValue()); 267 case AffineExprKind::Mul: { 268 binExpr = llvm::cast<AffineBinaryOpExpr>(*this); 269 return binExpr.getLHS().getLargestKnownDivisor() * 270 binExpr.getRHS().getLargestKnownDivisor(); 271 } 272 case AffineExprKind::Add: 273 [[fallthrough]]; 274 case AffineExprKind::Mod: { 275 binExpr = llvm::cast<AffineBinaryOpExpr>(*this); 276 return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(), 277 (uint64_t)binExpr.getRHS().getLargestKnownDivisor()); 278 } 279 } 280 llvm_unreachable("Unknown AffineExpr"); 281 } 282 283 bool AffineExpr::isMultipleOf(int64_t factor) const { 284 AffineBinaryOpExpr binExpr(nullptr); 285 uint64_t l, u; 286 switch (getKind()) { 287 case AffineExprKind::SymbolId: 288 [[fallthrough]]; 289 case AffineExprKind::DimId: 290 return factor * factor == 1; 291 case AffineExprKind::Constant: 292 return llvm::cast<AffineConstantExpr>(*this).getValue() % factor == 0; 293 case AffineExprKind::Mul: { 294 binExpr = llvm::cast<AffineBinaryOpExpr>(*this); 295 // It's probably not worth optimizing this further (to not traverse the 296 // whole sub-tree under - it that would require a version of isMultipleOf 297 // that on a 'false' return also returns the largest known divisor). 298 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 || 299 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 || 300 (l * u) % factor == 0; 301 } 302 case AffineExprKind::Add: 303 case AffineExprKind::FloorDiv: 304 case AffineExprKind::CeilDiv: 305 case AffineExprKind::Mod: { 306 binExpr = llvm::cast<AffineBinaryOpExpr>(*this); 307 return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(), 308 (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) % 309 factor == 310 0; 311 } 312 } 313 llvm_unreachable("Unknown AffineExpr"); 314 } 315 316 bool AffineExpr::isFunctionOfDim(unsigned position) const { 317 if (getKind() == AffineExprKind::DimId) { 318 return *this == mlir::getAffineDimExpr(position, getContext()); 319 } 320 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) { 321 return expr.getLHS().isFunctionOfDim(position) || 322 expr.getRHS().isFunctionOfDim(position); 323 } 324 return false; 325 } 326 327 bool AffineExpr::isFunctionOfSymbol(unsigned position) const { 328 if (getKind() == AffineExprKind::SymbolId) { 329 return *this == mlir::getAffineSymbolExpr(position, getContext()); 330 } 331 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) { 332 return expr.getLHS().isFunctionOfSymbol(position) || 333 expr.getRHS().isFunctionOfSymbol(position); 334 } 335 return false; 336 } 337 338 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr) 339 : AffineExpr(ptr) {} 340 AffineExpr AffineBinaryOpExpr::getLHS() const { 341 return static_cast<ImplType *>(expr)->lhs; 342 } 343 AffineExpr AffineBinaryOpExpr::getRHS() const { 344 return static_cast<ImplType *>(expr)->rhs; 345 } 346 347 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {} 348 unsigned AffineDimExpr::getPosition() const { 349 return static_cast<ImplType *>(expr)->position; 350 } 351 352 /// Returns true if the expression is divisible by the given symbol with 353 /// position `symbolPos`. The argument `opKind` specifies here what kind of 354 /// division or mod operation called this division. It helps in implementing the 355 /// commutative property of the floordiv and ceildiv operations. If the argument 356 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv 357 /// operation, then the commutative property can be used otherwise, the floordiv 358 /// operation is not divisible. The same argument holds for ceildiv operation. 359 static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos, 360 AffineExprKind opKind, 361 bool fromMul = false) { 362 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 363 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 364 opKind == AffineExprKind::CeilDiv) && 365 "unexpected opKind"); 366 switch (expr.getKind()) { 367 case AffineExprKind::Constant: 368 return cast<AffineConstantExpr>(expr).getValue() == 0; 369 case AffineExprKind::DimId: 370 return false; 371 case AffineExprKind::SymbolId: 372 return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos); 373 // Checks divisibility by the given symbol for both operands. 374 case AffineExprKind::Add: { 375 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 376 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, 377 opKind) && 378 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 379 } 380 // Checks divisibility by the given symbol for both operands. Consider the 381 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`, 382 // this is a division by s1 and both the operands of modulo are divisible by 383 // s1 but it is not divisible by s1 always. The third argument is 384 // `AffineExprKind::Mod` for this reason. 385 case AffineExprKind::Mod: { 386 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 387 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, 388 AffineExprKind::Mod) && 389 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, 390 AffineExprKind::Mod); 391 } 392 // Checks if any of the operand divisible by the given symbol. 393 case AffineExprKind::Mul: { 394 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 395 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind, 396 true) || 397 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind, 398 true); 399 } 400 // Floordiv and ceildiv are divisible by the given symbol when the first 401 // operand is divisible, and the affine expression kind of the argument expr 402 // is same as the argument `opKind`. This can be inferred from commutative 403 // property of floordiv and ceildiv operations and are as follow: 404 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 405 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 406 // It will fail 1.if operations are not same. For example: 407 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a 408 // multiplication operation in the expression. For example: 409 // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified. 410 case AffineExprKind::FloorDiv: 411 case AffineExprKind::CeilDiv: { 412 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 413 if (opKind != expr.getKind()) 414 return false; 415 if (fromMul) 416 return false; 417 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, 418 expr.getKind()); 419 } 420 } 421 llvm_unreachable("Unknown AffineExpr"); 422 } 423 424 /// Divides the given expression by the given symbol at position `symbolPos`. It 425 /// considers the divisibility condition is checked before calling itself. A 426 /// null expression is returned whenever the divisibility condition fails. 427 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, 428 AffineExprKind opKind) { 429 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 430 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 431 opKind == AffineExprKind::CeilDiv) && 432 "unexpected opKind"); 433 switch (expr.getKind()) { 434 case AffineExprKind::Constant: 435 if (cast<AffineConstantExpr>(expr).getValue() != 0) 436 return nullptr; 437 return getAffineConstantExpr(0, expr.getContext()); 438 case AffineExprKind::DimId: 439 return nullptr; 440 case AffineExprKind::SymbolId: 441 return getAffineConstantExpr(1, expr.getContext()); 442 // Dividing both operands by the given symbol. 443 case AffineExprKind::Add: { 444 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 445 return getAffineBinaryOpExpr( 446 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind), 447 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind)); 448 } 449 // Dividing both operands by the given symbol. 450 case AffineExprKind::Mod: { 451 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 452 return getAffineBinaryOpExpr( 453 expr.getKind(), 454 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 455 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind())); 456 } 457 // Dividing any of the operand by the given symbol. 458 case AffineExprKind::Mul: { 459 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 460 if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind)) 461 return binaryExpr.getLHS() * 462 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind); 463 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) * 464 binaryExpr.getRHS(); 465 } 466 // Dividing first operand only by the given symbol. 467 case AffineExprKind::FloorDiv: 468 case AffineExprKind::CeilDiv: { 469 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 470 return getAffineBinaryOpExpr( 471 expr.getKind(), 472 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 473 binaryExpr.getRHS()); 474 } 475 } 476 llvm_unreachable("Unknown AffineExpr"); 477 } 478 479 /// Populate `result` with all summand operands of given (potentially nested) 480 /// addition. If the given expression is not an addition, just populate the 481 /// expression itself. 482 /// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)]. 483 static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) { 484 auto addExpr = dyn_cast<AffineBinaryOpExpr>(expr); 485 if (!addExpr || addExpr.getKind() != AffineExprKind::Add) { 486 result.push_back(expr); 487 return; 488 } 489 getSummandExprs(addExpr.getLHS(), result); 490 getSummandExprs(addExpr.getRHS(), result); 491 } 492 493 /// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr). 494 /// If so, also return the non-negated expression via `expr`. 495 static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) { 496 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(candidate); 497 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) 498 return false; 499 if (auto lhs = dyn_cast<AffineConstantExpr>(mulExpr.getLHS())) { 500 if (lhs.getValue() == -1) { 501 expr = mulExpr.getRHS(); 502 return true; 503 } 504 } 505 if (auto rhs = dyn_cast<AffineConstantExpr>(mulExpr.getRHS())) { 506 if (rhs.getValue() == -1) { 507 expr = mulExpr.getLHS(); 508 return true; 509 } 510 } 511 return false; 512 } 513 514 /// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on 515 /// the fact that `lhs` contains another modulo expression that ensures that 516 /// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR 517 /// after loop peeling. 518 /// 519 /// Example: lhs = ub - ub % step 520 /// rhs = step 521 /// => (ub - ub % step) % step is guaranteed to evaluate to 0. 522 static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs, 523 unsigned numDims, unsigned numSymbols) { 524 // TODO: Try to unify this function with `getBoundForAffineExpr`. 525 // Collect all summands in lhs. 526 SmallVector<AffineExpr> summands; 527 getSummandExprs(lhs, summands); 528 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the 529 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0. 530 for (int64_t i = 0, e = summands.size(); i < e; ++i) { 531 AffineExpr current = summands[i]; 532 AffineExpr beforeNegation; 533 if (!isNegatedAffineExpr(current, beforeNegation)) 534 continue; 535 AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(beforeNegation); 536 if (!innerMod || innerMod.getKind() != AffineExprKind::Mod) 537 continue; 538 if (innerMod.getRHS() != rhs) 539 continue; 540 // Sum all remaining summands and subtract x. If that expression can be 541 // simplified to zero, then the remaining summands and x are equal. 542 AffineExpr diff = getAffineConstantExpr(0, lhs.getContext()); 543 for (int64_t j = 0; j < e; ++j) 544 if (i != j) 545 diff = diff + summands[j]; 546 diff = diff - innerMod.getLHS(); 547 diff = simplifyAffineExpr(diff, numDims, numSymbols); 548 auto constExpr = dyn_cast<AffineConstantExpr>(diff); 549 if (constExpr && constExpr.getValue() == 0) 550 return true; 551 } 552 return false; 553 } 554 555 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv 556 /// operations when the second operand simplifies to a symbol and the first 557 /// operand is divisible by that symbol. It can be applied to any semi-affine 558 /// expression. Returned expression can either be a semi-affine or pure affine 559 /// expression. 560 static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims, 561 unsigned numSymbols) { 562 switch (expr.getKind()) { 563 case AffineExprKind::Constant: 564 case AffineExprKind::DimId: 565 case AffineExprKind::SymbolId: 566 return expr; 567 case AffineExprKind::Add: 568 case AffineExprKind::Mul: { 569 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 570 return getAffineBinaryOpExpr( 571 expr.getKind(), 572 simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols), 573 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols)); 574 } 575 // Check if the simplification of the second operand is a symbol, and the 576 // first operand is divisible by it. If the operation is a modulo, a constant 577 // zero expression is returned. In the case of floordiv and ceildiv, the 578 // symbol from the simplification of the second operand divides the first 579 // operand. Otherwise, simplification is not possible. 580 case AffineExprKind::FloorDiv: 581 case AffineExprKind::CeilDiv: 582 case AffineExprKind::Mod: { 583 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr); 584 AffineExpr sLHS = 585 simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols); 586 AffineExpr sRHS = 587 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols); 588 if (isModOfModSubtraction(sLHS, sRHS, numDims, numSymbols)) 589 return getAffineConstantExpr(0, expr.getContext()); 590 AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>( 591 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols)); 592 if (!symbolExpr) 593 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 594 unsigned symbolPos = symbolExpr.getPosition(); 595 if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, 596 expr.getKind())) 597 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 598 if (expr.getKind() == AffineExprKind::Mod) 599 return getAffineConstantExpr(0, expr.getContext()); 600 return symbolicDivide(sLHS, symbolPos, expr.getKind()); 601 } 602 } 603 llvm_unreachable("Unknown AffineExpr"); 604 } 605 606 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, 607 MLIRContext *context) { 608 auto assignCtx = [context](AffineDimExprStorage *storage) { 609 storage->context = context; 610 }; 611 612 StorageUniquer &uniquer = context->getAffineUniquer(); 613 return uniquer.get<AffineDimExprStorage>( 614 assignCtx, static_cast<unsigned>(kind), position); 615 } 616 617 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { 618 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context); 619 } 620 621 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr) 622 : AffineExpr(ptr) {} 623 unsigned AffineSymbolExpr::getPosition() const { 624 return static_cast<ImplType *>(expr)->position; 625 } 626 627 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { 628 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context); 629 } 630 631 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr) 632 : AffineExpr(ptr) {} 633 int64_t AffineConstantExpr::getValue() const { 634 return static_cast<ImplType *>(expr)->constant; 635 } 636 637 bool AffineExpr::operator==(int64_t v) const { 638 return *this == getAffineConstantExpr(v, getContext()); 639 } 640 641 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { 642 auto assignCtx = [context](AffineConstantExprStorage *storage) { 643 storage->context = context; 644 }; 645 646 StorageUniquer &uniquer = context->getAffineUniquer(); 647 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant); 648 } 649 650 SmallVector<AffineExpr> 651 mlir::getAffineConstantExprs(ArrayRef<int64_t> constants, 652 MLIRContext *context) { 653 return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) { 654 return getAffineConstantExpr(constant, context); 655 })); 656 } 657 658 /// Simplify add expression. Return nullptr if it can't be simplified. 659 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { 660 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs); 661 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs); 662 // Fold if both LHS, RHS are a constant and the sum does not overflow. 663 if (lhsConst && rhsConst) { 664 int64_t sum; 665 if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) { 666 return nullptr; 667 } 668 return getAffineConstantExpr(sum, lhs.getContext()); 669 } 670 671 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). 672 // If only one of them is a symbolic expressions, make it the RHS. 673 if (isa<AffineConstantExpr>(lhs) || 674 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) { 675 return rhs + lhs; 676 } 677 678 // At this point, if there was a constant, it would be on the right. 679 680 // Addition with a zero is a noop, return the other input. 681 if (rhsConst) { 682 if (rhsConst.getValue() == 0) 683 return lhs; 684 } 685 // Fold successive additions like (d0 + 2) + 3 into d0 + 5. 686 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs); 687 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) { 688 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) 689 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue()); 690 } 691 692 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr". 693 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their 694 // respective multiplicands. 695 std::optional<int64_t> rLhsConst, rRhsConst; 696 AffineExpr firstExpr, secondExpr; 697 AffineConstantExpr rLhsConstExpr; 698 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lhs); 699 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul && 700 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(lBinOpExpr.getRHS()))) { 701 rLhsConst = rLhsConstExpr.getValue(); 702 firstExpr = lBinOpExpr.getLHS(); 703 } else { 704 rLhsConst = 1; 705 firstExpr = lhs; 706 } 707 708 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(rhs); 709 AffineConstantExpr rRhsConstExpr; 710 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul && 711 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(rBinOpExpr.getRHS()))) { 712 rRhsConst = rRhsConstExpr.getValue(); 713 secondExpr = rBinOpExpr.getLHS(); 714 } else { 715 rRhsConst = 1; 716 secondExpr = rhs; 717 } 718 719 if (rLhsConst && rRhsConst && firstExpr == secondExpr) 720 return getAffineBinaryOpExpr( 721 AffineExprKind::Mul, firstExpr, 722 getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext())); 723 724 // When doing successive additions, bring constant to the right: turn (d0 + 2) 725 // + d1 into (d0 + d1) + 2. 726 if (lBin && lBin.getKind() == AffineExprKind::Add) { 727 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) { 728 return lBin.getLHS() + rhs + lrhs; 729 } 730 } 731 732 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where 733 // q may be a constant or symbolic expression. This leads to a much more 734 // efficient form when 'c' is a power of two, and in general a more compact 735 // and readable form. 736 737 // Process '(expr floordiv c) * (-c)'. 738 if (!rBinOpExpr) 739 return nullptr; 740 741 auto lrhs = rBinOpExpr.getLHS(); 742 auto rrhs = rBinOpExpr.getRHS(); 743 744 AffineExpr llrhs, rlrhs; 745 746 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a 747 // symbolic expression. 748 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs); 749 // Check rrhsConstOpExpr = -1. 750 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rrhs); 751 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr && 752 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) { 753 // Check llrhs = expr floordiv q. 754 llrhs = lrhsBinOpExpr.getLHS(); 755 // Check rlrhs = q. 756 rlrhs = lrhsBinOpExpr.getRHS(); 757 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(llrhs); 758 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv) 759 return nullptr; 760 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS()) 761 return lhs % rlrhs; 762 } 763 764 // Process lrhs, which is 'expr floordiv c'. 765 // expr + (expr // c * -c) = expr % c 766 AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs); 767 if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul || 768 lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) 769 return nullptr; 770 771 llrhs = lrBinOpExpr.getLHS(); 772 rlrhs = lrBinOpExpr.getRHS(); 773 auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs); 774 // We don't support modulo with a negative RHS. 775 bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0; 776 777 if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) { 778 return lhs % rlrhs; 779 } 780 return nullptr; 781 } 782 783 AffineExpr AffineExpr::operator+(int64_t v) const { 784 return *this + getAffineConstantExpr(v, getContext()); 785 } 786 AffineExpr AffineExpr::operator+(AffineExpr other) const { 787 if (auto simplified = simplifyAdd(*this, other)) 788 return simplified; 789 790 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 791 return uniquer.get<AffineBinaryOpExprStorage>( 792 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other); 793 } 794 795 /// Simplify a multiply expression. Return nullptr if it can't be simplified. 796 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { 797 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs); 798 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs); 799 800 if (lhsConst && rhsConst) { 801 int64_t product; 802 if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) { 803 return nullptr; 804 } 805 return getAffineConstantExpr(product, lhs.getContext()); 806 } 807 808 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) 809 return nullptr; 810 811 // Canonicalize the mul expression so that the constant/symbolic term is the 812 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a 813 // constant. (Note that a constant is trivially symbolic). 814 if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) { 815 // At least one of them has to be symbolic. 816 return rhs * lhs; 817 } 818 819 // At this point, if there was a constant, it would be on the right. 820 821 // Multiplication with a one is a noop, return the other input. 822 if (rhsConst) { 823 if (rhsConst.getValue() == 1) 824 return lhs; 825 // Multiplication with zero. 826 if (rhsConst.getValue() == 0) 827 return rhsConst; 828 } 829 830 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6. 831 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs); 832 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) { 833 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) 834 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue()); 835 } 836 837 // When doing successive multiplication, bring constant to the right: turn (d0 838 // * 2) * d1 into (d0 * d1) * 2. 839 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 840 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) { 841 return (lBin.getLHS() * rhs) * lrhs; 842 } 843 } 844 845 return nullptr; 846 } 847 848 AffineExpr AffineExpr::operator*(int64_t v) const { 849 return *this * getAffineConstantExpr(v, getContext()); 850 } 851 AffineExpr AffineExpr::operator*(AffineExpr other) const { 852 if (auto simplified = simplifyMul(*this, other)) 853 return simplified; 854 855 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 856 return uniquer.get<AffineBinaryOpExprStorage>( 857 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other); 858 } 859 860 // Unary minus, delegate to operator*. 861 AffineExpr AffineExpr::operator-() const { 862 return *this * getAffineConstantExpr(-1, getContext()); 863 } 864 865 // Delegate to operator+. 866 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); } 867 AffineExpr AffineExpr::operator-(AffineExpr other) const { 868 return *this + (-other); 869 } 870 871 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { 872 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs); 873 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs); 874 875 if (!rhsConst || rhsConst.getValue() == 0) 876 return nullptr; 877 878 if (lhsConst) { 879 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue())) 880 return nullptr; 881 return getAffineConstantExpr( 882 divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()), 883 lhs.getContext()); 884 } 885 886 // Fold floordiv of a multiply with a constant that is a multiple of the 887 // divisor. Eg: (i * 128) floordiv 64 = i * 2. 888 if (rhsConst == 1) 889 return lhs; 890 891 // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a 892 // multiple of `rhsConst`. 893 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs); 894 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 895 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) { 896 // `rhsConst` is known to be a nonzero constant. 897 if (lrhs.getValue() % rhsConst.getValue() == 0) 898 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); 899 } 900 } 901 902 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is 903 // known to be a multiple of divConst. 904 if (lBin && lBin.getKind() == AffineExprKind::Add) { 905 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); 906 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); 907 // rhsConst is known to be a nonzero constant. 908 if (llhsDiv % rhsConst.getValue() == 0 || 909 lrhsDiv % rhsConst.getValue() == 0) 910 return lBin.getLHS().floorDiv(rhsConst.getValue()) + 911 lBin.getRHS().floorDiv(rhsConst.getValue()); 912 } 913 914 return nullptr; 915 } 916 917 AffineExpr AffineExpr::floorDiv(uint64_t v) const { 918 return floorDiv(getAffineConstantExpr(v, getContext())); 919 } 920 AffineExpr AffineExpr::floorDiv(AffineExpr other) const { 921 if (auto simplified = simplifyFloorDiv(*this, other)) 922 return simplified; 923 924 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 925 return uniquer.get<AffineBinaryOpExprStorage>( 926 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this, 927 other); 928 } 929 930 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { 931 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs); 932 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs); 933 934 if (!rhsConst || rhsConst.getValue() == 0) 935 return nullptr; 936 937 if (lhsConst) { 938 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue())) 939 return nullptr; 940 return getAffineConstantExpr( 941 divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()), 942 lhs.getContext()); 943 } 944 945 // Fold ceildiv of a multiply with a constant that is a multiple of the 946 // divisor. Eg: (i * 128) ceildiv 64 = i * 2. 947 if (rhsConst.getValue() == 1) 948 return lhs; 949 950 // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a 951 // multiple of `rhsConst`. 952 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs); 953 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 954 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) { 955 // `rhsConst` is known to be a nonzero constant. 956 if (lrhs.getValue() % rhsConst.getValue() == 0) 957 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); 958 } 959 } 960 961 return nullptr; 962 } 963 964 AffineExpr AffineExpr::ceilDiv(uint64_t v) const { 965 return ceilDiv(getAffineConstantExpr(v, getContext())); 966 } 967 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const { 968 if (auto simplified = simplifyCeilDiv(*this, other)) 969 return simplified; 970 971 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 972 return uniquer.get<AffineBinaryOpExprStorage>( 973 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this, 974 other); 975 } 976 977 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { 978 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs); 979 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs); 980 981 // mod w.r.t zero or negative numbers is undefined and preserved as is. 982 if (!rhsConst || rhsConst.getValue() < 1) 983 return nullptr; 984 985 if (lhsConst) { 986 // mod never overflows. 987 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()), 988 lhs.getContext()); 989 } 990 991 // Fold modulo of an expression that is known to be a multiple of a constant 992 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) 993 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0. 994 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0) 995 return getAffineConstantExpr(0, lhs.getContext()); 996 997 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is 998 // known to be a multiple of divConst. 999 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs); 1000 if (lBin && lBin.getKind() == AffineExprKind::Add) { 1001 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); 1002 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); 1003 // rhsConst is known to be a positive constant. 1004 if (llhsDiv % rhsConst.getValue() == 0) 1005 return lBin.getRHS() % rhsConst.getValue(); 1006 if (lrhsDiv % rhsConst.getValue() == 0) 1007 return lBin.getLHS() % rhsConst.getValue(); 1008 } 1009 1010 // Simplify (e % a) % b to e % b when b evenly divides a 1011 if (lBin && lBin.getKind() == AffineExprKind::Mod) { 1012 auto intermediate = dyn_cast<AffineConstantExpr>(lBin.getRHS()); 1013 if (intermediate && intermediate.getValue() >= 1 && 1014 mod(intermediate.getValue(), rhsConst.getValue()) == 0) { 1015 return lBin.getLHS() % rhsConst.getValue(); 1016 } 1017 } 1018 1019 return nullptr; 1020 } 1021 1022 AffineExpr AffineExpr::operator%(uint64_t v) const { 1023 return *this % getAffineConstantExpr(v, getContext()); 1024 } 1025 AffineExpr AffineExpr::operator%(AffineExpr other) const { 1026 if (auto simplified = simplifyMod(*this, other)) 1027 return simplified; 1028 1029 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 1030 return uniquer.get<AffineBinaryOpExprStorage>( 1031 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other); 1032 } 1033 1034 AffineExpr AffineExpr::compose(AffineMap map) const { 1035 SmallVector<AffineExpr, 8> dimReplacements(map.getResults()); 1036 return replaceDimsAndSymbols(dimReplacements, {}); 1037 } 1038 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) { 1039 expr.print(os); 1040 return os; 1041 } 1042 1043 /// Constructs an affine expression from a flat ArrayRef. If there are local 1044 /// identifiers (neither dimensional nor symbolic) that appear in the sum of 1045 /// products expression, `localExprs` is expected to have the AffineExpr 1046 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be 1047 /// in the format [dims, symbols, locals, constant term]. 1048 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 1049 unsigned numDims, 1050 unsigned numSymbols, 1051 ArrayRef<AffineExpr> localExprs, 1052 MLIRContext *context) { 1053 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 1054 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 1055 "unexpected number of local expressions"); 1056 1057 auto expr = getAffineConstantExpr(0, context); 1058 // Dimensions and symbols. 1059 for (unsigned j = 0; j < numDims + numSymbols; j++) { 1060 if (flatExprs[j] == 0) 1061 continue; 1062 auto id = j < numDims ? getAffineDimExpr(j, context) 1063 : getAffineSymbolExpr(j - numDims, context); 1064 expr = expr + id * flatExprs[j]; 1065 } 1066 1067 // Local identifiers. 1068 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 1069 j++) { 1070 if (flatExprs[j] == 0) 1071 continue; 1072 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 1073 expr = expr + term; 1074 } 1075 1076 // Constant term. 1077 int64_t constTerm = flatExprs[flatExprs.size() - 1]; 1078 if (constTerm != 0) 1079 expr = expr + constTerm; 1080 return expr; 1081 } 1082 1083 /// Constructs a semi-affine expression from a flat ArrayRef. If there are 1084 /// local identifiers (neither dimensional nor symbolic) that appear in the sum 1085 /// of products expression, `localExprs` is expected to have the AffineExprs for 1086 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in 1087 /// the format [dims, symbols, locals, constant term]. The semi-affine 1088 /// expression is constructed in the sorted order of dimension and symbol 1089 /// position numbers. Note: local expressions/ids are used for mod, div as well 1090 /// as symbolic RHS terms for terms that are not pure affine. 1091 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 1092 unsigned numDims, 1093 unsigned numSymbols, 1094 ArrayRef<AffineExpr> localExprs, 1095 MLIRContext *context) { 1096 assert(!flatExprs.empty() && "flatExprs cannot be empty"); 1097 1098 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 1099 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 1100 "unexpected number of local expressions"); 1101 1102 AffineExpr expr = getAffineConstantExpr(0, context); 1103 1104 // We design indices as a pair which help us present the semi-affine map as 1105 // sum of product where terms are sorted based on dimension or symbol 1106 // position: <keyA, keyB> for expressions of the form dimension * symbol, 1107 // where keyA is the position number of the dimension and keyB is the 1108 // position number of the symbol. For dimensional expressions we set the index 1109 // as (position number of the dimension, -1), as we want dimensional 1110 // expressions to appear before symbolic and product of dimensional and 1111 // symbolic expressions having the dimension with the same position number. 1112 // For symbolic expression set the index as (position number of the symbol, 1113 // maximum of last dimension and symbol position) number. For example, we want 1114 // the expression we are constructing to look something like: d0 + d0 * s0 + 1115 // s0 + d1*s1 + s1. 1116 1117 // Stores the affine expression corresponding to a given index. 1118 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap; 1119 // Stores the constant coefficient value corresponding to a given 1120 // dimension, symbol or a non-pure affine expression stored in `localExprs`. 1121 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients; 1122 // Stores the indices as defined above, and later sorted to produce 1123 // the semi-affine expression in the desired form. 1124 SmallVector<std::pair<unsigned, signed>, 8> indices; 1125 1126 // Example: expression = d0 + d0 * s0 + 2 * s0. 1127 // indices = [{0,-1}, {0, 0}, {0, 1}] 1128 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}] 1129 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}] 1130 1131 // Adds entries to `indexToExprMap`, `coefficients` and `indices`. 1132 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient, 1133 AffineExpr expr) { 1134 assert(!llvm::is_contained(indices, index) && 1135 "Key is already present in indices vector and overwriting will " 1136 "happen in `indexToExprMap` and `coefficients`!"); 1137 1138 indices.push_back(index); 1139 coefficients.insert({index, coefficient}); 1140 indexToExprMap.insert({index, expr}); 1141 }; 1142 1143 // Design indices for dimensional or symbolic terms, and store the indices, 1144 // constant coefficient corresponding to the indices in `coefficients` map, 1145 // and affine expression corresponding to indices in `indexToExprMap` map. 1146 1147 // Ensure we do not have duplicate keys in `indexToExpr` map. 1148 unsigned offsetSym = 0; 1149 signed offsetDim = -1; 1150 for (unsigned j = numDims; j < numDims + numSymbols; ++j) { 1151 if (flatExprs[j] == 0) 1152 continue; 1153 // For symbolic expression set the index as <position number 1154 // of the symbol, max(dimCount, symCount)> number, 1155 // as we want symbolic expressions with the same positional number to 1156 // appear after dimensional expressions having the same positional number. 1157 std::pair<unsigned, signed> indexEntry( 1158 j - numDims, std::max(numDims, numSymbols) + offsetSym++); 1159 addEntry(indexEntry, flatExprs[j], 1160 getAffineSymbolExpr(j - numDims, context)); 1161 } 1162 1163 // Denotes semi-affine product, modulo or division terms, which has been added 1164 // to the `indexToExpr` map. 1165 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1, 1166 false); 1167 unsigned lhsPos, rhsPos; 1168 // Construct indices for product terms involving dimension, symbol or constant 1169 // as lhs/rhs, and store the indices, constant coefficient corresponding to 1170 // the indices in `coefficients` map, and affine expression corresponding to 1171 // in indices in `indexToExprMap` map. 1172 for (const auto &it : llvm::enumerate(localExprs)) { 1173 AffineExpr expr = it.value(); 1174 if (flatExprs[numDims + numSymbols + it.index()] == 0) 1175 continue; 1176 AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS(); 1177 AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS(); 1178 if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) && 1179 (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) || 1180 isa<AffineConstantExpr>(rhs)))) { 1181 continue; 1182 } 1183 if (isa<AffineConstantExpr>(rhs)) { 1184 // For product/modulo/division expressions, when rhs of modulo/division 1185 // expression is constant, we put 0 in place of keyB, because we want 1186 // them to appear earlier in the semi-affine expression we are 1187 // constructing. When rhs is constant, we place 0 in place of keyB. 1188 if (isa<AffineDimExpr>(lhs)) { 1189 lhsPos = cast<AffineDimExpr>(lhs).getPosition(); 1190 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--); 1191 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], 1192 expr); 1193 } else { 1194 lhsPos = cast<AffineSymbolExpr>(lhs).getPosition(); 1195 std::pair<unsigned, signed> indexEntry( 1196 lhsPos, std::max(numDims, numSymbols) + offsetSym++); 1197 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], 1198 expr); 1199 } 1200 } else if (isa<AffineDimExpr>(lhs)) { 1201 // For product/modulo/division expressions having lhs as dimension and rhs 1202 // as symbol, we order the terms in the semi-affine expression based on 1203 // the pair: <keyA, keyB> for expressions of the form dimension * symbol, 1204 // where keyA is the position number of the dimension and keyB is the 1205 // position number of the symbol. 1206 lhsPos = cast<AffineDimExpr>(lhs).getPosition(); 1207 rhsPos = cast<AffineSymbolExpr>(rhs).getPosition(); 1208 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos); 1209 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); 1210 } else { 1211 // For product/modulo/division expressions having both lhs and rhs as 1212 // symbol, we design indices as a pair: <keyA, keyB> for expressions 1213 // of the form dimension * symbol, where keyA is the position number of 1214 // the dimension and keyB is the position number of the symbol. 1215 lhsPos = cast<AffineSymbolExpr>(lhs).getPosition(); 1216 rhsPos = cast<AffineSymbolExpr>(rhs).getPosition(); 1217 std::pair<unsigned, signed> indexEntry( 1218 lhsPos, std::max(numDims, numSymbols) + offsetSym++); 1219 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); 1220 } 1221 addedToMap[it.index()] = true; 1222 } 1223 1224 for (unsigned j = 0; j < numDims; ++j) { 1225 if (flatExprs[j] == 0) 1226 continue; 1227 // For dimensional expressions we set the index as <position number of the 1228 // dimension, 0>, as we want dimensional expressions to appear before 1229 // symbolic ones and products of dimensional and symbolic expressions 1230 // having the dimension with the same position number. 1231 std::pair<unsigned, signed> indexEntry(j, offsetDim--); 1232 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context)); 1233 } 1234 1235 // Constructing the simplified semi-affine sum of product/division/mod 1236 // expression from the flattened form in the desired sorted order of indices 1237 // of the various individual product/division/mod expressions. 1238 llvm::sort(indices); 1239 for (const std::pair<unsigned, unsigned> index : indices) { 1240 assert(indexToExprMap.lookup(index) && 1241 "cannot find key in `indexToExprMap` map"); 1242 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index); 1243 } 1244 1245 // Local identifiers. 1246 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 1247 j++) { 1248 // If the coefficient of the local expression is 0, continue as we need not 1249 // add it in out final expression. 1250 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols]) 1251 continue; 1252 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 1253 expr = expr + term; 1254 } 1255 1256 // Constant term. 1257 int64_t constTerm = flatExprs.back(); 1258 if (constTerm != 0) 1259 expr = expr + constTerm; 1260 return expr; 1261 } 1262 1263 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, 1264 unsigned numSymbols) 1265 : numDims(numDims), numSymbols(numSymbols), numLocals(0) { 1266 operandExprStack.reserve(8); 1267 } 1268 1269 // In pure affine t = expr * c, we multiply each coefficient of lhs with c. 1270 // 1271 // In case of semi affine multiplication expressions, t = expr * symbolic_expr, 1272 // introduce a local variable p (= expr * symbolic_expr), and the affine 1273 // expression expr * symbolic_expr is added to `localExprs`. 1274 LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { 1275 assert(operandExprStack.size() >= 2); 1276 SmallVector<int64_t, 8> rhs = operandExprStack.back(); 1277 operandExprStack.pop_back(); 1278 SmallVector<int64_t, 8> &lhs = operandExprStack.back(); 1279 1280 // Flatten semi-affine multiplication expressions by introducing a local 1281 // variable in place of the product; the affine expression 1282 // corresponding to the quantifier is added to `localExprs`. 1283 if (!isa<AffineConstantExpr>(expr.getRHS())) { 1284 SmallVector<int64_t, 8> mulLhs(lhs); 1285 MLIRContext *context = expr.getContext(); 1286 AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, 1287 localExprs, context); 1288 AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, 1289 localExprs, context); 1290 return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size()); 1291 } 1292 1293 // Get the RHS constant. 1294 int64_t rhsConst = rhs[getConstantIndex()]; 1295 for (int64_t &lhsElt : lhs) 1296 lhsElt *= rhsConst; 1297 1298 return success(); 1299 } 1300 1301 LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { 1302 assert(operandExprStack.size() >= 2); 1303 const auto &rhs = operandExprStack.back(); 1304 auto &lhs = operandExprStack[operandExprStack.size() - 2]; 1305 assert(lhs.size() == rhs.size()); 1306 // Update the LHS in place. 1307 for (unsigned i = 0, e = rhs.size(); i < e; i++) { 1308 lhs[i] += rhs[i]; 1309 } 1310 // Pop off the RHS. 1311 operandExprStack.pop_back(); 1312 return success(); 1313 } 1314 1315 // 1316 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 1317 // 1318 // A mod expression "expr mod c" is thus flattened by introducing a new local 1319 // variable q (= expr floordiv c), such that expr mod c is replaced with 1320 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. 1321 // 1322 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr, 1323 // introduce a local variable m (= expr mod symbolic_expr), and the affine 1324 // expression expr mod symbolic_expr is added to `localExprs`. 1325 LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { 1326 assert(operandExprStack.size() >= 2); 1327 1328 SmallVector<int64_t, 8> rhs = operandExprStack.back(); 1329 operandExprStack.pop_back(); 1330 SmallVector<int64_t, 8> &lhs = operandExprStack.back(); 1331 MLIRContext *context = expr.getContext(); 1332 1333 // Flatten semi affine modulo expressions by introducing a local 1334 // variable in place of the modulo value, and the affine expression 1335 // corresponding to the quantifier is added to `localExprs`. 1336 if (!isa<AffineConstantExpr>(expr.getRHS())) { 1337 SmallVector<int64_t, 8> modLhs(lhs); 1338 AffineExpr dividendExpr = getAffineExprFromFlatForm( 1339 lhs, numDims, numSymbols, localExprs, context); 1340 AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols, 1341 localExprs, context); 1342 AffineExpr modExpr = dividendExpr % divisorExpr; 1343 return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size()); 1344 } 1345 1346 int64_t rhsConst = rhs[getConstantIndex()]; 1347 if (rhsConst <= 0) 1348 return failure(); 1349 1350 // Check if the LHS expression is a multiple of modulo factor. 1351 unsigned i, e; 1352 for (i = 0, e = lhs.size(); i < e; i++) 1353 if (lhs[i] % rhsConst != 0) 1354 break; 1355 // If yes, modulo expression here simplifies to zero. 1356 if (i == lhs.size()) { 1357 std::fill(lhs.begin(), lhs.end(), 0); 1358 return success(); 1359 } 1360 1361 // Add a local variable for the quotient, i.e., expr % c is replaced by 1362 // (expr - q * c) where q = expr floordiv c. Do this while canceling out 1363 // the GCD of expr and c. 1364 SmallVector<int64_t, 8> floorDividend(lhs); 1365 uint64_t gcd = rhsConst; 1366 for (int64_t lhsElt : lhs) 1367 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt)); 1368 // Simplify the numerator and the denominator. 1369 if (gcd != 1) { 1370 for (int64_t &floorDividendElt : floorDividend) 1371 floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd); 1372 } 1373 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd); 1374 1375 // Construct the AffineExpr form of the floordiv to store in localExprs. 1376 1377 AffineExpr dividendExpr = getAffineExprFromFlatForm( 1378 floorDividend, numDims, numSymbols, localExprs, context); 1379 AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context); 1380 AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr); 1381 int loc; 1382 if ((loc = findLocalId(floorDivExpr)) == -1) { 1383 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); 1384 // Set result at top of stack to "lhs - rhsConst * q". 1385 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; 1386 } else { 1387 // Reuse the existing local id. 1388 lhs[getLocalVarStartIndex() + loc] -= rhsConst; 1389 } 1390 return success(); 1391 } 1392 1393 LogicalResult 1394 SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) { 1395 return visitDivExpr(expr, /*isCeil=*/true); 1396 } 1397 LogicalResult 1398 SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) { 1399 return visitDivExpr(expr, /*isCeil=*/false); 1400 } 1401 1402 LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) { 1403 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 1404 auto &eq = operandExprStack.back(); 1405 assert(expr.getPosition() < numDims && "Inconsistent number of dims"); 1406 eq[getDimStartIndex() + expr.getPosition()] = 1; 1407 return success(); 1408 } 1409 1410 LogicalResult 1411 SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) { 1412 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 1413 auto &eq = operandExprStack.back(); 1414 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); 1415 eq[getSymbolStartIndex() + expr.getPosition()] = 1; 1416 return success(); 1417 } 1418 1419 LogicalResult 1420 SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) { 1421 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 1422 auto &eq = operandExprStack.back(); 1423 eq[getConstantIndex()] = expr.getValue(); 1424 return success(); 1425 } 1426 1427 LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine( 1428 ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr, 1429 SmallVectorImpl<int64_t> &result, unsigned long resultSize) { 1430 assert(result.size() == resultSize && 1431 "`result` vector passed is not of correct size"); 1432 int loc; 1433 if ((loc = findLocalId(localExpr)) == -1) { 1434 if (failed(addLocalIdSemiAffine(lhs, rhs, localExpr))) 1435 return failure(); 1436 } 1437 std::fill(result.begin(), result.end(), 0); 1438 if (loc == -1) 1439 result[getLocalVarStartIndex() + numLocals - 1] = 1; 1440 else 1441 result[getLocalVarStartIndex() + loc] = 1; 1442 return success(); 1443 } 1444 1445 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 1446 // A floordiv is thus flattened by introducing a new local variable q, and 1447 // replacing that expression with 'q' while adding the constraints 1448 // c * q <= expr <= c * q + c - 1 to localVarCst (done by 1449 // IntegerRelation::addLocalFloorDiv). 1450 // 1451 // A ceildiv is similarly flattened: 1452 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c 1453 // 1454 // In case of semi affine division expressions, t = expr floordiv symbolic_expr 1455 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr 1456 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to 1457 // `localExprs`. 1458 LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, 1459 bool isCeil) { 1460 assert(operandExprStack.size() >= 2); 1461 1462 MLIRContext *context = expr.getContext(); 1463 SmallVector<int64_t, 8> rhs = operandExprStack.back(); 1464 operandExprStack.pop_back(); 1465 SmallVector<int64_t, 8> &lhs = operandExprStack.back(); 1466 1467 // Flatten semi affine division expressions by introducing a local 1468 // variable in place of the quotient, and the affine expression corresponding 1469 // to the quantifier is added to `localExprs`. 1470 if (!isa<AffineConstantExpr>(expr.getRHS())) { 1471 SmallVector<int64_t, 8> divLhs(lhs); 1472 AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, 1473 localExprs, context); 1474 AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, 1475 localExprs, context); 1476 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); 1477 return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size()); 1478 } 1479 1480 // This is a pure affine expr; the RHS is a positive constant. 1481 int64_t rhsConst = rhs[getConstantIndex()]; 1482 if (rhsConst <= 0) 1483 return failure(); 1484 1485 // Simplify the floordiv, ceildiv if possible by canceling out the greatest 1486 // common divisors of the numerator and denominator. 1487 uint64_t gcd = std::abs(rhsConst); 1488 for (int64_t lhsElt : lhs) 1489 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt)); 1490 // Simplify the numerator and the denominator. 1491 if (gcd != 1) { 1492 for (int64_t &lhsElt : lhs) 1493 lhsElt = lhsElt / static_cast<int64_t>(gcd); 1494 } 1495 int64_t divisor = rhsConst / static_cast<int64_t>(gcd); 1496 // If the divisor becomes 1, the updated LHS is the result. (The 1497 // divisor can't be negative since rhsConst is positive). 1498 if (divisor == 1) 1499 return success(); 1500 1501 // If the divisor cannot be simplified to one, we will have to retain 1502 // the ceil/floor expr (simplified up until here). Add an existential 1503 // quantifier to express its result, i.e., expr1 div expr2 is replaced 1504 // by a new identifier, q. 1505 AffineExpr a = 1506 getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context); 1507 AffineExpr b = getAffineConstantExpr(divisor, context); 1508 1509 int loc; 1510 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); 1511 if ((loc = findLocalId(divExpr)) == -1) { 1512 if (!isCeil) { 1513 SmallVector<int64_t, 8> dividend(lhs); 1514 addLocalFloorDivId(dividend, divisor, divExpr); 1515 } else { 1516 // lhs ceildiv c <=> (lhs + c - 1) floordiv c 1517 SmallVector<int64_t, 8> dividend(lhs); 1518 dividend.back() += divisor - 1; 1519 addLocalFloorDivId(dividend, divisor, divExpr); 1520 } 1521 } 1522 // Set the expression on stack to the local var introduced to capture the 1523 // result of the division (floor or ceil). 1524 std::fill(lhs.begin(), lhs.end(), 0); 1525 if (loc == -1) 1526 lhs[getLocalVarStartIndex() + numLocals - 1] = 1; 1527 else 1528 lhs[getLocalVarStartIndex() + loc] = 1; 1529 return success(); 1530 } 1531 1532 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 1533 // The local identifier added is always a floordiv of a pure add/mul affine 1534 // function of other identifiers, coefficients of which are specified in 1535 // dividend and with respect to a positive constant divisor. localExpr is the 1536 // simplified tree expression (AffineExpr) corresponding to the quantifier. 1537 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend, 1538 int64_t divisor, 1539 AffineExpr localExpr) { 1540 assert(divisor > 0 && "positive constant divisor expected"); 1541 for (SmallVector<int64_t, 8> &subExpr : operandExprStack) 1542 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); 1543 localExprs.push_back(localExpr); 1544 numLocals++; 1545 // dividend and divisor are not used here; an override of this method uses it. 1546 } 1547 1548 LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine( 1549 ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) { 1550 for (SmallVector<int64_t, 8> &subExpr : operandExprStack) 1551 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); 1552 localExprs.push_back(localExpr); 1553 ++numLocals; 1554 // lhs and rhs are not used here; an override of this method uses them. 1555 return success(); 1556 } 1557 1558 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) { 1559 SmallVectorImpl<AffineExpr>::iterator it; 1560 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end()) 1561 return -1; 1562 return it - localExprs.begin(); 1563 } 1564 1565 /// Simplify the affine expression by flattening it and reconstructing it. 1566 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, 1567 unsigned numSymbols) { 1568 // Simplify semi-affine expressions separately. 1569 if (!expr.isPureAffine()) 1570 expr = simplifySemiAffine(expr, numDims, numSymbols); 1571 1572 SimpleAffineExprFlattener flattener(numDims, numSymbols); 1573 // has poison expression 1574 if (failed(flattener.walkPostOrder(expr))) 1575 return expr; 1576 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back(); 1577 if (!expr.isPureAffine() && 1578 expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, 1579 flattener.localExprs, 1580 expr.getContext())) 1581 return expr; 1582 AffineExpr simplifiedExpr = 1583 expr.isPureAffine() 1584 ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, 1585 flattener.localExprs, expr.getContext()) 1586 : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, 1587 flattener.localExprs, 1588 expr.getContext()); 1589 1590 flattener.operandExprStack.pop_back(); 1591 assert(flattener.operandExprStack.empty()); 1592 return simplifiedExpr; 1593 } 1594 1595 std::optional<int64_t> mlir::getBoundForAffineExpr( 1596 AffineExpr expr, unsigned numDims, unsigned numSymbols, 1597 ArrayRef<std::optional<int64_t>> constLowerBounds, 1598 ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) { 1599 // Handle divs and mods. 1600 if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) { 1601 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we 1602 // can compute an upper bound. 1603 if (binOpExpr.getKind() == AffineExprKind::FloorDiv) { 1604 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS()); 1605 if (!rhsConst || rhsConst.getValue() < 1) 1606 return std::nullopt; 1607 auto bound = 1608 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols, 1609 constLowerBounds, constUpperBounds, isUpper); 1610 if (!bound) 1611 return std::nullopt; 1612 return divideFloorSigned(*bound, rhsConst.getValue()); 1613 } 1614 if (binOpExpr.getKind() == AffineExprKind::CeilDiv) { 1615 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS()); 1616 if (rhsConst && rhsConst.getValue() >= 1) { 1617 auto bound = 1618 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols, 1619 constLowerBounds, constUpperBounds, isUpper); 1620 if (!bound) 1621 return std::nullopt; 1622 return divideCeilSigned(*bound, rhsConst.getValue()); 1623 } 1624 return std::nullopt; 1625 } 1626 if (binOpExpr.getKind() == AffineExprKind::Mod) { 1627 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is 1628 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c 1629 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c. 1630 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS()); 1631 if (rhsConst && rhsConst.getValue() >= 1) { 1632 int64_t rhsConstVal = rhsConst.getValue(); 1633 auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols, 1634 constLowerBounds, constUpperBounds, 1635 /*isUpper=*/false); 1636 auto ub = 1637 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols, 1638 constLowerBounds, constUpperBounds, isUpper); 1639 if (ub && lb && 1640 divideFloorSigned(*lb, rhsConstVal) == 1641 divideFloorSigned(*ub, rhsConstVal)) 1642 return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal); 1643 return isUpper ? rhsConstVal - 1 : 0; 1644 } 1645 } 1646 } 1647 // Flatten the expression. 1648 SimpleAffineExprFlattener flattener(numDims, numSymbols); 1649 auto simpleResult = flattener.walkPostOrder(expr); 1650 // has poison expression 1651 if (failed(simpleResult)) 1652 return std::nullopt; 1653 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back(); 1654 // TODO: Handle local variables. We can get hold of flattener.localExprs and 1655 // get bound on the local expr recursively. 1656 if (flattener.numLocals > 0) 1657 return std::nullopt; 1658 int64_t bound = 0; 1659 // Substitute the constant lower or upper bound for the dimensional or 1660 // symbolic input depending on `isUpper` to determine the bound. 1661 for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) { 1662 if (flattenedExpr[i] > 0) { 1663 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i]; 1664 if (!constBound) 1665 return std::nullopt; 1666 bound += *constBound * flattenedExpr[i]; 1667 } else if (flattenedExpr[i] < 0) { 1668 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i]; 1669 if (!constBound) 1670 return std::nullopt; 1671 bound += *constBound * flattenedExpr[i]; 1672 } 1673 } 1674 // Constant term. 1675 bound += flattenedExpr.back(); 1676 return bound; 1677 } 1678