1 //===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 10 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 14 #include "mlir/Interfaces/ViewLikeInterface.h" 15 #include "llvm/ADT/APSInt.h" 16 #include "llvm/Support/Debug.h" 17 18 #define DEBUG_TYPE "value-bounds-op-interface" 19 20 using namespace mlir; 21 using presburger::BoundType; 22 using presburger::VarKind; 23 24 namespace mlir { 25 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc" 26 } // namespace mlir 27 28 static Operation *getOwnerOfValue(Value value) { 29 if (auto bbArg = dyn_cast<BlockArgument>(value)) 30 return bbArg.getOwner()->getParentOp(); 31 return value.getDefiningOp(); 32 } 33 34 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets, 35 ArrayRef<OpFoldResult> sizes, 36 ArrayRef<OpFoldResult> strides) 37 : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) { 38 assert(offsets.size() == sizes.size() && 39 "expected same number of offsets, sizes, strides"); 40 assert(offsets.size() == strides.size() && 41 "expected same number of offsets, sizes, strides"); 42 } 43 44 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets, 45 ArrayRef<OpFoldResult> sizes) 46 : mixedOffsets(offsets), mixedSizes(sizes) { 47 assert(offsets.size() == sizes.size() && 48 "expected same number of offsets and sizes"); 49 // Assume that all strides are 1. 50 if (offsets.empty()) 51 return; 52 MLIRContext *ctx = offsets.front().getContext(); 53 mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1)); 54 } 55 56 HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op) 57 : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(), 58 op.getMixedStrides()) {} 59 60 /// If ofr is a constant integer or an IntegerAttr, return the integer. 61 static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { 62 // Case 1: Check for Constant integer. 63 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) { 64 APSInt intVal; 65 if (matchPattern(val, m_ConstantInt(&intVal))) 66 return intVal.getSExtValue(); 67 return std::nullopt; 68 } 69 // Case 2: Check for IntegerAttr. 70 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr); 71 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) 72 return intAttr.getValue().getSExtValue(); 73 return std::nullopt; 74 } 75 76 ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr) 77 : Variable(ofr, std::nullopt) {} 78 79 ValueBoundsConstraintSet::Variable::Variable(Value indexValue) 80 : Variable(static_cast<OpFoldResult>(indexValue)) {} 81 82 ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim) 83 : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {} 84 85 ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr, 86 std::optional<int64_t> dim) { 87 Builder b(ofr.getContext()); 88 if (auto constInt = ::getConstantIntValue(ofr)) { 89 assert(!dim && "expected no dim for index-typed values"); 90 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, 91 b.getAffineConstantExpr(*constInt)); 92 return; 93 } 94 Value value = cast<Value>(ofr); 95 #ifndef NDEBUG 96 if (dim) { 97 assert(isa<ShapedType>(value.getType()) && "expected shaped type"); 98 } else { 99 assert(value.getType().isIndex() && "expected index type"); 100 } 101 #endif // NDEBUG 102 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, 103 b.getAffineSymbolExpr(0)); 104 mapOperands.emplace_back(value, dim); 105 } 106 107 ValueBoundsConstraintSet::Variable::Variable(AffineMap map, 108 ArrayRef<Variable> mapOperands) { 109 assert(map.getNumResults() == 1 && "expected single result"); 110 111 // Turn all dims into symbols. 112 Builder b(map.getContext()); 113 SmallVector<AffineExpr> dimReplacements, symReplacements; 114 for (int64_t i = 0, e = map.getNumDims(); i < e; ++i) 115 dimReplacements.push_back(b.getAffineSymbolExpr(i)); 116 for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i) 117 symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims())); 118 AffineMap tmpMap = map.replaceDimsAndSymbols( 119 dimReplacements, symReplacements, /*numResultDims=*/0, 120 /*numResultSyms=*/map.getNumSymbols() + map.getNumDims()); 121 122 // Inline operands. 123 DenseMap<AffineExpr, AffineExpr> replacements; 124 for (auto [index, var] : llvm::enumerate(mapOperands)) { 125 assert(var.map.getNumResults() == 1 && "expected single result"); 126 assert(var.map.getNumDims() == 0 && "expected only symbols"); 127 SmallVector<AffineExpr> symReplacements; 128 for (auto valueDim : var.mapOperands) { 129 auto it = llvm::find(this->mapOperands, valueDim); 130 if (it != this->mapOperands.end()) { 131 // There is already a symbol for this operand. 132 symReplacements.push_back(b.getAffineSymbolExpr( 133 std::distance(this->mapOperands.begin(), it))); 134 } else { 135 // This is a new operand: add a new symbol. 136 symReplacements.push_back( 137 b.getAffineSymbolExpr(this->mapOperands.size())); 138 this->mapOperands.push_back(valueDim); 139 } 140 } 141 replacements[b.getAffineSymbolExpr(index)] = 142 var.map.getResult(0).replaceSymbols(symReplacements); 143 } 144 this->map = tmpMap.replace(replacements, /*numResultDims=*/0, 145 /*numResultSyms=*/this->mapOperands.size()); 146 } 147 148 ValueBoundsConstraintSet::Variable::Variable(AffineMap map, 149 ArrayRef<Value> mapOperands) 150 : Variable(map, llvm::map_to_vector(mapOperands, 151 [](Value v) { return Variable(v); })) {} 152 153 ValueBoundsConstraintSet::ValueBoundsConstraintSet( 154 MLIRContext *ctx, StopConditionFn stopCondition, 155 bool addConservativeSemiAffineBounds) 156 : builder(ctx), stopCondition(stopCondition), 157 addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) { 158 assert(stopCondition && "expected non-null stop condition"); 159 } 160 161 char ValueBoundsConstraintSet::ID = 0; 162 163 #ifndef NDEBUG 164 static void assertValidValueDim(Value value, std::optional<int64_t> dim) { 165 if (value.getType().isIndex()) { 166 assert(!dim.has_value() && "invalid dim value"); 167 } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) { 168 assert(*dim >= 0 && "invalid dim value"); 169 if (shapedType.hasRank()) 170 assert(*dim < shapedType.getRank() && "invalid dim value"); 171 } else { 172 llvm_unreachable("unsupported type"); 173 } 174 } 175 #endif // NDEBUG 176 177 void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos, 178 AffineExpr expr) { 179 // Note: If `addConservativeSemiAffineBounds` is true then the bound 180 // computation function needs to handle the case that the constraints set 181 // could become empty. This is because the conservative bounds add assumptions 182 // (e.g. for `mod` it assumes `rhs > 0`). If these constraints are later found 183 // not to hold, then the bound is invalid. 184 LogicalResult status = cstr.addBound( 185 type, pos, 186 AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr), 187 addConservativeSemiAffineBounds 188 ? FlatLinearConstraints::AddConservativeSemiAffineBounds::Yes 189 : FlatLinearConstraints::AddConservativeSemiAffineBounds::No); 190 if (failed(status)) { 191 // Not all semi-affine expressions are not yet supported by 192 // FlatLinearConstraints. However, we can just ignore such failures here. 193 // Even without this bound, there may be enough information in the 194 // constraint system to compute the requested bound. In case this bound is 195 // actually needed, `computeBound` will return `failure`. 196 LLVM_DEBUG(llvm::dbgs() << "Failed to add bound: " << expr << "\n"); 197 } 198 } 199 200 AffineExpr ValueBoundsConstraintSet::getExpr(Value value, 201 std::optional<int64_t> dim) { 202 #ifndef NDEBUG 203 assertValidValueDim(value, dim); 204 #endif // NDEBUG 205 206 // Check if the value/dim is statically known. In that case, an affine 207 // constant expression should be returned. This allows us to support 208 // multiplications with constants. (Multiplications of two columns in the 209 // constraint set is not supported.) 210 std::optional<int64_t> constSize = std::nullopt; 211 auto shapedType = dyn_cast<ShapedType>(value.getType()); 212 if (shapedType) { 213 if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) 214 constSize = shapedType.getDimSize(*dim); 215 } else if (auto constInt = ::getConstantIntValue(value)) { 216 constSize = *constInt; 217 } 218 219 // If the value/dim is already mapped, return the corresponding expression 220 // directly. 221 ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); 222 if (valueDimToPosition.contains(valueDim)) { 223 // If it is a constant, return an affine constant expression. Otherwise, 224 // return an affine expression that represents the respective column in the 225 // constraint set. 226 if (constSize) 227 return builder.getAffineConstantExpr(*constSize); 228 return getPosExpr(getPos(value, dim)); 229 } 230 231 if (constSize) { 232 // Constant index value/dim: add column to the constraint set, add EQ bound 233 // and return an affine constant expression without pushing the newly added 234 // column to the worklist. 235 (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false); 236 if (shapedType) 237 bound(value)[*dim] == *constSize; 238 else 239 bound(value) == *constSize; 240 return builder.getAffineConstantExpr(*constSize); 241 } 242 243 // Dynamic value/dim: insert column to the constraint set and put it on the 244 // worklist. Return an affine expression that represents the newly inserted 245 // column in the constraint set. 246 return getPosExpr(insert(value, dim, /*isSymbol=*/true)); 247 } 248 249 AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) { 250 if (Value value = llvm::dyn_cast_if_present<Value>(ofr)) 251 return getExpr(value, /*dim=*/std::nullopt); 252 auto constInt = ::getConstantIntValue(ofr); 253 assert(constInt.has_value() && "expected Integer constant"); 254 return builder.getAffineConstantExpr(*constInt); 255 } 256 257 AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) { 258 return builder.getAffineConstantExpr(constant); 259 } 260 261 int64_t ValueBoundsConstraintSet::insert(Value value, 262 std::optional<int64_t> dim, 263 bool isSymbol, bool addToWorklist) { 264 #ifndef NDEBUG 265 assertValidValueDim(value, dim); 266 #endif // NDEBUG 267 268 ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); 269 assert(!valueDimToPosition.contains(valueDim) && "already mapped"); 270 int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol) 271 : cstr.appendVar(VarKind::SetDim); 272 LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos 273 << " for: " << value 274 << " (dim: " << dim.value_or(kIndexValue) 275 << ", owner: " << getOwnerOfValue(value)->getName() 276 << ")\n"); 277 positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim); 278 // Update reverse mapping. 279 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) 280 if (positionToValueDim[i].has_value()) 281 valueDimToPosition[*positionToValueDim[i]] = i; 282 283 if (addToWorklist) { 284 LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value 285 << " (dim: " << dim.value_or(kIndexValue) << ")\n"); 286 worklist.push(pos); 287 } 288 289 return pos; 290 } 291 292 int64_t ValueBoundsConstraintSet::insert(bool isSymbol) { 293 int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol) 294 : cstr.appendVar(VarKind::SetDim); 295 LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos 296 << "\n"); 297 positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt); 298 // Update reverse mapping. 299 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) 300 if (positionToValueDim[i].has_value()) 301 valueDimToPosition[*positionToValueDim[i]] = i; 302 return pos; 303 } 304 305 int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands, 306 bool isSymbol) { 307 assert(map.getNumResults() == 1 && "expected affine map with one result"); 308 int64_t pos = insert(isSymbol); 309 310 // Add map and operands to the constraint set. Dimensions are converted to 311 // symbols. All operands are added to the worklist (unless they were already 312 // processed). 313 auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) { 314 return getExpr(v.first, v.second); 315 }; 316 SmallVector<AffineExpr> dimReplacements = llvm::to_vector( 317 llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper)); 318 SmallVector<AffineExpr> symReplacements = llvm::to_vector( 319 llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper)); 320 addBound( 321 presburger::BoundType::EQ, pos, 322 map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements)); 323 324 return pos; 325 } 326 327 int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) { 328 return insert(var.map, var.mapOperands, isSymbol); 329 } 330 331 int64_t ValueBoundsConstraintSet::getPos(Value value, 332 std::optional<int64_t> dim) const { 333 #ifndef NDEBUG 334 assertValidValueDim(value, dim); 335 assert((isa<OpResult>(value) || 336 cast<BlockArgument>(value).getOwner()->isEntryBlock()) && 337 "unstructured control flow is not supported"); 338 #endif // NDEBUG 339 LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value 340 << " (dim: " << dim.value_or(kIndexValue) 341 << ", owner: " << getOwnerOfValue(value)->getName() 342 << ")\n"); 343 auto it = 344 valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue))); 345 assert(it != valueDimToPosition.end() && "expected mapped entry"); 346 return it->second; 347 } 348 349 AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) { 350 assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position"); 351 return pos < cstr.getNumDimVars() 352 ? builder.getAffineDimExpr(pos) 353 : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); 354 } 355 356 bool ValueBoundsConstraintSet::isMapped(Value value, 357 std::optional<int64_t> dim) const { 358 auto it = 359 valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue))); 360 return it != valueDimToPosition.end(); 361 } 362 363 void ValueBoundsConstraintSet::processWorklist() { 364 LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n"); 365 while (!worklist.empty()) { 366 int64_t pos = worklist.front(); 367 worklist.pop(); 368 assert(positionToValueDim[pos].has_value() && 369 "did not expect std::nullopt on worklist"); 370 ValueDim valueDim = *positionToValueDim[pos]; 371 Value value = valueDim.first; 372 int64_t dim = valueDim.second; 373 374 // Check for static dim size. 375 if (dim != kIndexValue) { 376 auto shapedType = cast<ShapedType>(value.getType()); 377 if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) { 378 bound(value)[dim] == getExpr(shapedType.getDimSize(dim)); 379 continue; 380 } 381 } 382 383 // Do not process any further if the stop condition is met. 384 auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim); 385 if (stopCondition(value, maybeDim, *this)) { 386 LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value 387 << " (dim: " << maybeDim << ")\n"); 388 continue; 389 } 390 391 // Query `ValueBoundsOpInterface` for constraints. New items may be added to 392 // the worklist. 393 auto valueBoundsOp = 394 dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value)); 395 LLVM_DEBUG(llvm::dbgs() 396 << "Query value bounds for: " << value 397 << " (owner: " << getOwnerOfValue(value)->getName() << ")\n"); 398 if (valueBoundsOp) { 399 if (dim == kIndexValue) { 400 valueBoundsOp.populateBoundsForIndexValue(value, *this); 401 } else { 402 valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this); 403 } 404 continue; 405 } 406 LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n"); 407 408 // If the op does not implement `ValueBoundsOpInterface`, check if it 409 // implements the `DestinationStyleOpInterface`. OpResults of such ops are 410 // tied to OpOperands. Tied values have the same shape. 411 auto dstOp = value.getDefiningOp<DestinationStyleOpInterface>(); 412 if (!dstOp || dim == kIndexValue) 413 continue; 414 Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get(); 415 bound(value)[dim] == getExpr(tiedOperand, dim); 416 } 417 } 418 419 void ValueBoundsConstraintSet::projectOut(int64_t pos) { 420 assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) && 421 "invalid position"); 422 cstr.projectOut(pos); 423 if (positionToValueDim[pos].has_value()) { 424 bool erased = valueDimToPosition.erase(*positionToValueDim[pos]); 425 (void)erased; 426 assert(erased && "inconsistent reverse mapping"); 427 } 428 positionToValueDim.erase(positionToValueDim.begin() + pos); 429 // Update reverse mapping. 430 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) 431 if (positionToValueDim[i].has_value()) 432 valueDimToPosition[*positionToValueDim[i]] = i; 433 } 434 435 void ValueBoundsConstraintSet::projectOut( 436 function_ref<bool(ValueDim)> condition) { 437 int64_t nextPos = 0; 438 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) { 439 if (positionToValueDim[nextPos].has_value() && 440 condition(*positionToValueDim[nextPos])) { 441 projectOut(nextPos); 442 // The column was projected out so another column is now at that position. 443 // Do not increase the counter. 444 } else { 445 ++nextPos; 446 } 447 } 448 } 449 450 void ValueBoundsConstraintSet::projectOutAnonymous( 451 std::optional<int64_t> except) { 452 int64_t nextPos = 0; 453 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) { 454 if (positionToValueDim[nextPos].has_value() || except == nextPos) { 455 ++nextPos; 456 } else { 457 projectOut(nextPos); 458 // The column was projected out so another column is now at that position. 459 // Do not increase the counter. 460 } 461 } 462 } 463 464 LogicalResult ValueBoundsConstraintSet::computeBound( 465 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, 466 const Variable &var, StopConditionFn stopCondition, bool closedUB) { 467 MLIRContext *ctx = var.getContext(); 468 int64_t ubAdjustment = closedUB ? 0 : 1; 469 Builder b(ctx); 470 mapOperands.clear(); 471 472 // Process the backward slice of `value` (i.e., reverse use-def chain) until 473 // `stopCondition` is met. 474 ValueBoundsConstraintSet cstr(ctx, stopCondition); 475 int64_t pos = cstr.insert(var, /*isSymbol=*/false); 476 assert(pos == 0 && "expected first column"); 477 cstr.processWorklist(); 478 479 // Project out all variables (apart from `valueDim`) that do not match the 480 // stop condition. 481 cstr.projectOut([&](ValueDim p) { 482 auto maybeDim = 483 p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); 484 return !stopCondition(p.first, maybeDim, cstr); 485 }); 486 cstr.projectOutAnonymous(/*except=*/pos); 487 488 // Compute lower and upper bounds for `valueDim`. 489 SmallVector<AffineMap> lb(1), ub(1); 490 cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub, 491 /*closedUB=*/true); 492 493 // Note: There are TODOs in the implementation of `getSliceBounds`. In such a 494 // case, no lower/upper bound can be computed at the moment. 495 // EQ, UB bounds: upper bound is needed. 496 if ((type != BoundType::LB) && 497 (ub.empty() || !ub[0] || ub[0].getNumResults() == 0)) 498 return failure(); 499 // EQ, LB bounds: lower bound is needed. 500 if ((type != BoundType::UB) && 501 (lb.empty() || !lb[0] || lb[0].getNumResults() == 0)) 502 return failure(); 503 504 // TODO: Generate an affine map with multiple results. 505 if (type != BoundType::LB) 506 assert(ub.size() == 1 && ub[0].getNumResults() == 1 && 507 "multiple bounds not supported"); 508 if (type != BoundType::UB) 509 assert(lb.size() == 1 && lb[0].getNumResults() == 1 && 510 "multiple bounds not supported"); 511 512 // EQ bound: lower and upper bound must match. 513 if (type == BoundType::EQ && ub[0] != lb[0]) 514 return failure(); 515 516 AffineMap bound; 517 if (type == BoundType::EQ || type == BoundType::LB) { 518 bound = lb[0]; 519 } else { 520 // Computed UB is a closed bound. 521 bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(), 522 ub[0].getResult(0) + ubAdjustment); 523 } 524 525 // Gather all SSA values that are used in the computed bound. 526 assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() && 527 "inconsistent mapping state"); 528 SmallVector<AffineExpr> replacementDims, replacementSymbols; 529 int64_t numDims = 0, numSymbols = 0; 530 for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) { 531 // Skip `value`. 532 if (i == pos) 533 continue; 534 // Check if the position `i` is used in the generated bound. If so, it must 535 // be included in the generated affine.apply op. 536 bool used = false; 537 bool isDim = i < cstr.cstr.getNumDimVars(); 538 if (isDim) { 539 if (bound.isFunctionOfDim(i)) 540 used = true; 541 } else { 542 if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) 543 used = true; 544 } 545 546 if (!used) { 547 // Not used: Remove dim/symbol from the result. 548 if (isDim) { 549 replacementDims.push_back(b.getAffineConstantExpr(0)); 550 } else { 551 replacementSymbols.push_back(b.getAffineConstantExpr(0)); 552 } 553 continue; 554 } 555 556 if (isDim) { 557 replacementDims.push_back(b.getAffineDimExpr(numDims++)); 558 } else { 559 replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++)); 560 } 561 562 assert(cstr.positionToValueDim[i].has_value() && 563 "cannot build affine map in terms of anonymous column"); 564 ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i]; 565 Value value = valueDim.first; 566 int64_t dim = valueDim.second; 567 if (dim == ValueBoundsConstraintSet::kIndexValue) { 568 // An index-type value is used: can be used directly in the affine.apply 569 // op. 570 assert(value.getType().isIndex() && "expected index type"); 571 mapOperands.push_back(std::make_pair(value, std::nullopt)); 572 continue; 573 } 574 575 assert(cast<ShapedType>(value.getType()).isDynamicDim(dim) && 576 "expected dynamic dim"); 577 mapOperands.push_back(std::make_pair(value, dim)); 578 } 579 580 resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols, 581 numDims, numSymbols); 582 return success(); 583 } 584 585 LogicalResult ValueBoundsConstraintSet::computeDependentBound( 586 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, 587 const Variable &var, ValueDimList dependencies, bool closedUB) { 588 return computeBound( 589 resultMap, mapOperands, type, var, 590 [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { 591 return llvm::is_contained(dependencies, std::make_pair(v, d)); 592 }, 593 closedUB); 594 } 595 596 LogicalResult ValueBoundsConstraintSet::computeIndependentBound( 597 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, 598 const Variable &var, ValueRange independencies, bool closedUB) { 599 // Return "true" if the given value is independent of all values in 600 // `independencies`. I.e., neither the value itself nor any value in the 601 // backward slice (reverse use-def chain) is contained in `independencies`. 602 auto isIndependent = [&](Value v) { 603 SmallVector<Value> worklist; 604 DenseSet<Value> visited; 605 worklist.push_back(v); 606 while (!worklist.empty()) { 607 Value next = worklist.pop_back_val(); 608 if (!visited.insert(next).second) 609 continue; 610 if (llvm::is_contained(independencies, next)) 611 return false; 612 // TODO: DominanceInfo could be used to stop the traversal early. 613 Operation *op = next.getDefiningOp(); 614 if (!op) 615 continue; 616 worklist.append(op->getOperands().begin(), op->getOperands().end()); 617 } 618 return true; 619 }; 620 621 // Reify bounds in terms of any independent values. 622 return computeBound( 623 resultMap, mapOperands, type, var, 624 [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { 625 return isIndependent(v); 626 }, 627 closedUB); 628 } 629 630 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( 631 presburger::BoundType type, const Variable &var, 632 StopConditionFn stopCondition, bool closedUB) { 633 // Default stop condition if none was specified: Keep adding constraints until 634 // a bound could be computed. 635 int64_t pos = 0; 636 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, 637 ValueBoundsConstraintSet &cstr) { 638 return cstr.cstr.getConstantBound64(type, pos).has_value(); 639 }; 640 641 ValueBoundsConstraintSet cstr( 642 var.getContext(), stopCondition ? stopCondition : defaultStopCondition); 643 pos = cstr.populateConstraints(var.map, var.mapOperands); 644 assert(pos == 0 && "expected `map` is the first column"); 645 646 // Compute constant bound for `valueDim`. 647 int64_t ubAdjustment = closedUB ? 0 : 1; 648 if (auto bound = cstr.cstr.getConstantBound64(type, pos)) 649 return type == BoundType::UB ? *bound + ubAdjustment : *bound; 650 return failure(); 651 } 652 653 void ValueBoundsConstraintSet::populateConstraints(Value value, 654 std::optional<int64_t> dim) { 655 #ifndef NDEBUG 656 assertValidValueDim(value, dim); 657 #endif // NDEBUG 658 659 // `getExpr` pushes the value/dim onto the worklist (unless it was already 660 // analyzed). 661 (void)getExpr(value, dim); 662 // Process all values/dims on the worklist. This may traverse and analyze 663 // additional IR, depending the current stop function. 664 processWorklist(); 665 } 666 667 int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map, 668 ValueDimList operands) { 669 int64_t pos = insert(map, operands, /*isSymbol=*/false); 670 // Process the backward slice of `operands` (i.e., reverse use-def chain) 671 // until `stopCondition` is met. 672 processWorklist(); 673 return pos; 674 } 675 676 FailureOr<int64_t> 677 ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, 678 std::optional<int64_t> dim1, 679 std::optional<int64_t> dim2) { 680 #ifndef NDEBUG 681 assertValidValueDim(value1, dim1); 682 assertValidValueDim(value2, dim2); 683 #endif // NDEBUG 684 685 Builder b(value1.getContext()); 686 AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, 687 b.getAffineDimExpr(0) - b.getAffineDimExpr(1)); 688 return computeConstantBound(presburger::BoundType::EQ, 689 Variable(map, {{value1, dim1}, {value2, dim2}})); 690 } 691 692 bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, 693 ComparisonOperator cmp, 694 int64_t rhsPos) { 695 // This function returns "true" if "lhs CMP rhs" is proven to hold. 696 // 697 // Example for ComparisonOperator::LE and index-typed values: We would like to 698 // prove that lhs <= rhs. Proof by contradiction: add the inverse 699 // relation (lhs > rhs) to the constraint set and check if the resulting 700 // constraint set is "empty" (i.e. has no solution). In that case, 701 // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds. 702 703 // We cannot prove anything if the constraint set is already empty. 704 if (cstr.isEmpty()) { 705 LLVM_DEBUG( 706 llvm::dbgs() 707 << "cannot compare value/dims: constraint system is already empty"); 708 return false; 709 } 710 711 // EQ can be expressed as LE and GE. 712 if (cmp == EQ) 713 return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) && 714 comparePos(lhsPos, ComparisonOperator::GE, rhsPos); 715 716 // Construct inequality. 717 SmallVector<int64_t> eq(cstr.getNumCols(), 0); 718 if (cmp == LT || cmp == LE) { 719 ++eq[lhsPos]; 720 --eq[rhsPos]; 721 } else if (cmp == GT || cmp == GE) { 722 --eq[lhsPos]; 723 ++eq[rhsPos]; 724 } else { 725 llvm_unreachable("unsupported comparison operator"); 726 } 727 if (cmp == LE || cmp == GE) 728 eq[cstr.getNumCols() - 1] -= 1; 729 730 // Add inequality to the constraint set and check if it made the constraint 731 // set empty. 732 int64_t ineqPos = cstr.getNumInequalities(); 733 cstr.addInequality(eq); 734 bool isEmpty = cstr.isEmpty(); 735 cstr.removeInequality(ineqPos); 736 return isEmpty; 737 } 738 739 bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs, 740 ComparisonOperator cmp, 741 const Variable &rhs) { 742 int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands); 743 int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands); 744 return comparePos(lhsPos, cmp, rhsPos); 745 } 746 747 bool ValueBoundsConstraintSet::compare(const Variable &lhs, 748 ComparisonOperator cmp, 749 const Variable &rhs) { 750 int64_t lhsPos = -1, rhsPos = -1; 751 auto stopCondition = [&](Value v, std::optional<int64_t> dim, 752 ValueBoundsConstraintSet &cstr) { 753 // Keep processing as long as lhs/rhs were not processed. 754 if (size_t(lhsPos) >= cstr.positionToValueDim.size() || 755 size_t(rhsPos) >= cstr.positionToValueDim.size()) 756 return false; 757 // Keep processing as long as the relation cannot be proven. 758 return cstr.comparePos(lhsPos, cmp, rhsPos); 759 }; 760 ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); 761 lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands); 762 rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands); 763 return cstr.comparePos(lhsPos, cmp, rhsPos); 764 } 765 766 FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1, 767 const Variable &var2) { 768 if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2)) 769 return true; 770 if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) || 771 ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2)) 772 return false; 773 return failure(); 774 } 775 776 FailureOr<bool> 777 ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, 778 HyperrectangularSlice slice1, 779 HyperrectangularSlice slice2) { 780 assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() && 781 "expected slices of same rank"); 782 assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() && 783 "expected slices of same rank"); 784 assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() && 785 "expected slices of same rank"); 786 787 Builder b(ctx); 788 bool foundUnknownBound = false; 789 for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) { 790 AffineMap map = 791 AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4, 792 b.getAffineSymbolExpr(0) + 793 b.getAffineSymbolExpr(1) * b.getAffineSymbolExpr(2) - 794 b.getAffineSymbolExpr(3)); 795 { 796 // Case 1: Slices are guaranteed to be non-overlapping if 797 // offset1 + size1 * stride1 <= offset2 (for at least one dimension). 798 SmallVector<OpFoldResult> ofrOperands; 799 ofrOperands.push_back(slice1.getMixedOffsets()[i]); 800 ofrOperands.push_back(slice1.getMixedSizes()[i]); 801 ofrOperands.push_back(slice1.getMixedStrides()[i]); 802 ofrOperands.push_back(slice2.getMixedOffsets()[i]); 803 SmallVector<Value> valueOperands; 804 AffineMap foldedMap = 805 foldAttributesIntoMap(b, map, ofrOperands, valueOperands); 806 FailureOr<int64_t> constBound = computeConstantBound( 807 presburger::BoundType::EQ, Variable(foldedMap, valueOperands)); 808 foundUnknownBound |= failed(constBound); 809 if (succeeded(constBound) && *constBound <= 0) 810 return false; 811 } 812 { 813 // Case 2: Slices are guaranteed to be non-overlapping if 814 // offset2 + size2 * stride2 <= offset1 (for at least one dimension). 815 SmallVector<OpFoldResult> ofrOperands; 816 ofrOperands.push_back(slice2.getMixedOffsets()[i]); 817 ofrOperands.push_back(slice2.getMixedSizes()[i]); 818 ofrOperands.push_back(slice2.getMixedStrides()[i]); 819 ofrOperands.push_back(slice1.getMixedOffsets()[i]); 820 SmallVector<Value> valueOperands; 821 AffineMap foldedMap = 822 foldAttributesIntoMap(b, map, ofrOperands, valueOperands); 823 FailureOr<int64_t> constBound = computeConstantBound( 824 presburger::BoundType::EQ, Variable(foldedMap, valueOperands)); 825 foundUnknownBound |= failed(constBound); 826 if (succeeded(constBound) && *constBound <= 0) 827 return false; 828 } 829 } 830 831 // If at least one bound could not be computed, we cannot be certain that the 832 // slices are really overlapping. 833 if (foundUnknownBound) 834 return failure(); 835 836 // All bounds could be computed and none of the above cases applied. 837 // Therefore, the slices are guaranteed to overlap. 838 return true; 839 } 840 841 FailureOr<bool> 842 ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx, 843 HyperrectangularSlice slice1, 844 HyperrectangularSlice slice2) { 845 assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() && 846 "expected slices of same rank"); 847 assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() && 848 "expected slices of same rank"); 849 assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() && 850 "expected slices of same rank"); 851 852 // The two slices are equivalent if all of their offsets, sizes and strides 853 // are equal. If equality cannot be determined for at least one of those 854 // values, equivalence cannot be determined and this function returns 855 // "failure". 856 for (auto [offset1, offset2] : 857 llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) { 858 FailureOr<bool> equal = areEqual(offset1, offset2); 859 if (failed(equal)) 860 return failure(); 861 if (!equal.value()) 862 return false; 863 } 864 for (auto [size1, size2] : 865 llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) { 866 FailureOr<bool> equal = areEqual(size1, size2); 867 if (failed(equal)) 868 return failure(); 869 if (!equal.value()) 870 return false; 871 } 872 for (auto [stride1, stride2] : 873 llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) { 874 FailureOr<bool> equal = areEqual(stride1, stride2); 875 if (failed(equal)) 876 return failure(); 877 if (!equal.value()) 878 return false; 879 } 880 return true; 881 } 882 883 void ValueBoundsConstraintSet::dump() const { 884 llvm::errs() << "==========\nColumns:\n"; 885 llvm::errs() << "(column\tdim\tvalue)\n"; 886 for (auto [index, valueDim] : llvm::enumerate(positionToValueDim)) { 887 llvm::errs() << " " << index << "\t"; 888 if (valueDim) { 889 if (valueDim->second == kIndexValue) { 890 llvm::errs() << "n/a\t"; 891 } else { 892 llvm::errs() << valueDim->second << "\t"; 893 } 894 llvm::errs() << getOwnerOfValue(valueDim->first)->getName() << " "; 895 if (OpResult result = dyn_cast<OpResult>(valueDim->first)) { 896 llvm::errs() << "(result " << result.getResultNumber() << ")"; 897 } else { 898 llvm::errs() << "(bbarg " 899 << cast<BlockArgument>(valueDim->first).getArgNumber() 900 << ")"; 901 } 902 llvm::errs() << "\n"; 903 } else { 904 llvm::errs() << "n/a\tn/a\n"; 905 } 906 } 907 llvm::errs() << "\nConstraint set:\n"; 908 cstr.dump(); 909 llvm::errs() << "==========\n"; 910 } 911 912 ValueBoundsConstraintSet::BoundBuilder & 913 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) { 914 assert(!this->dim.has_value() && "dim was already set"); 915 this->dim = dim; 916 #ifndef NDEBUG 917 assertValidValueDim(value, this->dim); 918 #endif // NDEBUG 919 return *this; 920 } 921 922 void ValueBoundsConstraintSet::BoundBuilder::operator<(AffineExpr expr) { 923 #ifndef NDEBUG 924 assertValidValueDim(value, this->dim); 925 #endif // NDEBUG 926 cstr.addBound(BoundType::UB, cstr.getPos(value, this->dim), expr); 927 } 928 929 void ValueBoundsConstraintSet::BoundBuilder::operator<=(AffineExpr expr) { 930 operator<(expr + 1); 931 } 932 933 void ValueBoundsConstraintSet::BoundBuilder::operator>(AffineExpr expr) { 934 operator>=(expr + 1); 935 } 936 937 void ValueBoundsConstraintSet::BoundBuilder::operator>=(AffineExpr expr) { 938 #ifndef NDEBUG 939 assertValidValueDim(value, this->dim); 940 #endif // NDEBUG 941 cstr.addBound(BoundType::LB, cstr.getPos(value, this->dim), expr); 942 } 943 944 void ValueBoundsConstraintSet::BoundBuilder::operator==(AffineExpr expr) { 945 #ifndef NDEBUG 946 assertValidValueDim(value, this->dim); 947 #endif // NDEBUG 948 cstr.addBound(BoundType::EQ, cstr.getPos(value, this->dim), expr); 949 } 950 951 void ValueBoundsConstraintSet::BoundBuilder::operator<(OpFoldResult ofr) { 952 operator<(cstr.getExpr(ofr)); 953 } 954 955 void ValueBoundsConstraintSet::BoundBuilder::operator<=(OpFoldResult ofr) { 956 operator<=(cstr.getExpr(ofr)); 957 } 958 959 void ValueBoundsConstraintSet::BoundBuilder::operator>(OpFoldResult ofr) { 960 operator>(cstr.getExpr(ofr)); 961 } 962 963 void ValueBoundsConstraintSet::BoundBuilder::operator>=(OpFoldResult ofr) { 964 operator>=(cstr.getExpr(ofr)); 965 } 966 967 void ValueBoundsConstraintSet::BoundBuilder::operator==(OpFoldResult ofr) { 968 operator==(cstr.getExpr(ofr)); 969 } 970 971 void ValueBoundsConstraintSet::BoundBuilder::operator<(int64_t i) { 972 operator<(cstr.getExpr(i)); 973 } 974 975 void ValueBoundsConstraintSet::BoundBuilder::operator<=(int64_t i) { 976 operator<=(cstr.getExpr(i)); 977 } 978 979 void ValueBoundsConstraintSet::BoundBuilder::operator>(int64_t i) { 980 operator>(cstr.getExpr(i)); 981 } 982 983 void ValueBoundsConstraintSet::BoundBuilder::operator>=(int64_t i) { 984 operator>=(cstr.getExpr(i)); 985 } 986 987 void ValueBoundsConstraintSet::BoundBuilder::operator==(int64_t i) { 988 operator==(cstr.getExpr(i)); 989 } 990