18c885658SMatthias Springer //===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===// 28c885658SMatthias Springer // 38c885658SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48c885658SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 58c885658SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68c885658SMatthias Springer // 78c885658SMatthias Springer //===----------------------------------------------------------------------===// 88c885658SMatthias Springer 98c885658SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h" 108c885658SMatthias Springer 118c885658SMatthias Springer #include "mlir/IR/BuiltinTypes.h" 128c885658SMatthias Springer #include "mlir/IR/Matchers.h" 13c5624dc0SMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.h" 141abd8d1aSMatthias Springer #include "mlir/Interfaces/ViewLikeInterface.h" 158c885658SMatthias Springer #include "llvm/ADT/APSInt.h" 16d8804ecdSMatthias Springer #include "llvm/Support/Debug.h" 17d8804ecdSMatthias Springer 18d8804ecdSMatthias Springer #define DEBUG_TYPE "value-bounds-op-interface" 198c885658SMatthias Springer 208c885658SMatthias Springer using namespace mlir; 218c885658SMatthias Springer using presburger::BoundType; 228c885658SMatthias Springer using presburger::VarKind; 238c885658SMatthias Springer 248c885658SMatthias Springer namespace mlir { 258c885658SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc" 268c885658SMatthias Springer } // namespace mlir 278c885658SMatthias Springer 2840dd3aa9SMatthias Springer static Operation *getOwnerOfValue(Value value) { 2940dd3aa9SMatthias Springer if (auto bbArg = dyn_cast<BlockArgument>(value)) 3040dd3aa9SMatthias Springer return bbArg.getOwner()->getParentOp(); 3140dd3aa9SMatthias Springer return value.getDefiningOp(); 3240dd3aa9SMatthias Springer } 3340dd3aa9SMatthias Springer 34ff614a57SMatthias Springer HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets, 35ff614a57SMatthias Springer ArrayRef<OpFoldResult> sizes, 36ff614a57SMatthias Springer ArrayRef<OpFoldResult> strides) 37ff614a57SMatthias Springer : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) { 38ff614a57SMatthias Springer assert(offsets.size() == sizes.size() && 39ff614a57SMatthias Springer "expected same number of offsets, sizes, strides"); 40ff614a57SMatthias Springer assert(offsets.size() == strides.size() && 41ff614a57SMatthias Springer "expected same number of offsets, sizes, strides"); 42ff614a57SMatthias Springer } 43ff614a57SMatthias Springer 44ff614a57SMatthias Springer HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets, 45ff614a57SMatthias Springer ArrayRef<OpFoldResult> sizes) 46ff614a57SMatthias Springer : mixedOffsets(offsets), mixedSizes(sizes) { 47ff614a57SMatthias Springer assert(offsets.size() == sizes.size() && 48ff614a57SMatthias Springer "expected same number of offsets and sizes"); 49ff614a57SMatthias Springer // Assume that all strides are 1. 50ff614a57SMatthias Springer if (offsets.empty()) 51ff614a57SMatthias Springer return; 52ff614a57SMatthias Springer MLIRContext *ctx = offsets.front().getContext(); 53ff614a57SMatthias Springer mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1)); 54ff614a57SMatthias Springer } 55ff614a57SMatthias Springer 56ff614a57SMatthias Springer HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op) 57ff614a57SMatthias Springer : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(), 58ff614a57SMatthias Springer op.getMixedStrides()) {} 59ff614a57SMatthias Springer 608c885658SMatthias Springer /// If ofr is a constant integer or an IntegerAttr, return the integer. 618c885658SMatthias Springer static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { 628c885658SMatthias Springer // Case 1: Check for Constant integer. 6368f58812STres Popp if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) { 648c885658SMatthias Springer APSInt intVal; 658c885658SMatthias Springer if (matchPattern(val, m_ConstantInt(&intVal))) 668c885658SMatthias Springer return intVal.getSExtValue(); 678c885658SMatthias Springer return std::nullopt; 688c885658SMatthias Springer } 698c885658SMatthias Springer // Case 2: Check for IntegerAttr. 7068f58812STres Popp Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr); 715550c821STres Popp if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) 728c885658SMatthias Springer return intAttr.getValue().getSExtValue(); 738c885658SMatthias Springer return std::nullopt; 748c885658SMatthias Springer } 758c885658SMatthias Springer 7640dd3aa9SMatthias Springer ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr) 7740dd3aa9SMatthias Springer : Variable(ofr, std::nullopt) {} 7840dd3aa9SMatthias Springer 7940dd3aa9SMatthias Springer ValueBoundsConstraintSet::Variable::Variable(Value indexValue) 8040dd3aa9SMatthias Springer : Variable(static_cast<OpFoldResult>(indexValue)) {} 8140dd3aa9SMatthias Springer 8240dd3aa9SMatthias Springer ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim) 8340dd3aa9SMatthias Springer : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {} 8440dd3aa9SMatthias Springer 8540dd3aa9SMatthias Springer ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr, 8640dd3aa9SMatthias Springer std::optional<int64_t> dim) { 8740dd3aa9SMatthias Springer Builder b(ofr.getContext()); 8840dd3aa9SMatthias Springer if (auto constInt = ::getConstantIntValue(ofr)) { 8940dd3aa9SMatthias Springer assert(!dim && "expected no dim for index-typed values"); 9040dd3aa9SMatthias Springer map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, 9140dd3aa9SMatthias Springer b.getAffineConstantExpr(*constInt)); 9240dd3aa9SMatthias Springer return; 9340dd3aa9SMatthias Springer } 9440dd3aa9SMatthias Springer Value value = cast<Value>(ofr); 9540dd3aa9SMatthias Springer #ifndef NDEBUG 9640dd3aa9SMatthias Springer if (dim) { 9740dd3aa9SMatthias Springer assert(isa<ShapedType>(value.getType()) && "expected shaped type"); 9840dd3aa9SMatthias Springer } else { 9940dd3aa9SMatthias Springer assert(value.getType().isIndex() && "expected index type"); 10040dd3aa9SMatthias Springer } 10140dd3aa9SMatthias Springer #endif // NDEBUG 10240dd3aa9SMatthias Springer map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, 10340dd3aa9SMatthias Springer b.getAffineSymbolExpr(0)); 10440dd3aa9SMatthias Springer mapOperands.emplace_back(value, dim); 10540dd3aa9SMatthias Springer } 10640dd3aa9SMatthias Springer 10740dd3aa9SMatthias Springer ValueBoundsConstraintSet::Variable::Variable(AffineMap map, 10840dd3aa9SMatthias Springer ArrayRef<Variable> mapOperands) { 10940dd3aa9SMatthias Springer assert(map.getNumResults() == 1 && "expected single result"); 11040dd3aa9SMatthias Springer 11140dd3aa9SMatthias Springer // Turn all dims into symbols. 11240dd3aa9SMatthias Springer Builder b(map.getContext()); 11340dd3aa9SMatthias Springer SmallVector<AffineExpr> dimReplacements, symReplacements; 11440dd3aa9SMatthias Springer for (int64_t i = 0, e = map.getNumDims(); i < e; ++i) 11540dd3aa9SMatthias Springer dimReplacements.push_back(b.getAffineSymbolExpr(i)); 11640dd3aa9SMatthias Springer for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i) 11740dd3aa9SMatthias Springer symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims())); 11840dd3aa9SMatthias Springer AffineMap tmpMap = map.replaceDimsAndSymbols( 11940dd3aa9SMatthias Springer dimReplacements, symReplacements, /*numResultDims=*/0, 12040dd3aa9SMatthias Springer /*numResultSyms=*/map.getNumSymbols() + map.getNumDims()); 12140dd3aa9SMatthias Springer 12240dd3aa9SMatthias Springer // Inline operands. 12340dd3aa9SMatthias Springer DenseMap<AffineExpr, AffineExpr> replacements; 12440dd3aa9SMatthias Springer for (auto [index, var] : llvm::enumerate(mapOperands)) { 12540dd3aa9SMatthias Springer assert(var.map.getNumResults() == 1 && "expected single result"); 12640dd3aa9SMatthias Springer assert(var.map.getNumDims() == 0 && "expected only symbols"); 12740dd3aa9SMatthias Springer SmallVector<AffineExpr> symReplacements; 12840dd3aa9SMatthias Springer for (auto valueDim : var.mapOperands) { 12940dd3aa9SMatthias Springer auto it = llvm::find(this->mapOperands, valueDim); 13040dd3aa9SMatthias Springer if (it != this->mapOperands.end()) { 13140dd3aa9SMatthias Springer // There is already a symbol for this operand. 13240dd3aa9SMatthias Springer symReplacements.push_back(b.getAffineSymbolExpr( 13340dd3aa9SMatthias Springer std::distance(this->mapOperands.begin(), it))); 13440dd3aa9SMatthias Springer } else { 13540dd3aa9SMatthias Springer // This is a new operand: add a new symbol. 13640dd3aa9SMatthias Springer symReplacements.push_back( 13740dd3aa9SMatthias Springer b.getAffineSymbolExpr(this->mapOperands.size())); 13840dd3aa9SMatthias Springer this->mapOperands.push_back(valueDim); 13940dd3aa9SMatthias Springer } 14040dd3aa9SMatthias Springer } 14140dd3aa9SMatthias Springer replacements[b.getAffineSymbolExpr(index)] = 14240dd3aa9SMatthias Springer var.map.getResult(0).replaceSymbols(symReplacements); 14340dd3aa9SMatthias Springer } 14440dd3aa9SMatthias Springer this->map = tmpMap.replace(replacements, /*numResultDims=*/0, 14540dd3aa9SMatthias Springer /*numResultSyms=*/this->mapOperands.size()); 14640dd3aa9SMatthias Springer } 14740dd3aa9SMatthias Springer 14840dd3aa9SMatthias Springer ValueBoundsConstraintSet::Variable::Variable(AffineMap map, 14940dd3aa9SMatthias Springer ArrayRef<Value> mapOperands) 15040dd3aa9SMatthias Springer : Variable(map, llvm::map_to_vector(mapOperands, 15140dd3aa9SMatthias Springer [](Value v) { return Variable(v); })) {} 15240dd3aa9SMatthias Springer 1535e4a4438SMatthias Springer ValueBoundsConstraintSet::ValueBoundsConstraintSet( 15429a925abSBenjamin Maxwell MLIRContext *ctx, StopConditionFn stopCondition, 15529a925abSBenjamin Maxwell bool addConservativeSemiAffineBounds) 15629a925abSBenjamin Maxwell : builder(ctx), stopCondition(stopCondition), 15729a925abSBenjamin Maxwell addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) { 1585e4a4438SMatthias Springer assert(stopCondition && "expected non-null stop condition"); 1595e4a4438SMatthias Springer } 1608c885658SMatthias Springer 1612861856bSBenjamin Maxwell char ValueBoundsConstraintSet::ID = 0; 1622861856bSBenjamin Maxwell 1638c885658SMatthias Springer #ifndef NDEBUG 1648c885658SMatthias Springer static void assertValidValueDim(Value value, std::optional<int64_t> dim) { 1658c885658SMatthias Springer if (value.getType().isIndex()) { 1668c885658SMatthias Springer assert(!dim.has_value() && "invalid dim value"); 1678c885658SMatthias Springer } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) { 1688c885658SMatthias Springer assert(*dim >= 0 && "invalid dim value"); 1698c885658SMatthias Springer if (shapedType.hasRank()) 1708c885658SMatthias Springer assert(*dim < shapedType.getRank() && "invalid dim value"); 1718c885658SMatthias Springer } else { 1728c885658SMatthias Springer llvm_unreachable("unsupported type"); 1738c885658SMatthias Springer } 1748c885658SMatthias Springer } 1758c885658SMatthias Springer #endif // NDEBUG 1768c885658SMatthias Springer 1778c885658SMatthias Springer void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos, 1788c885658SMatthias Springer AffineExpr expr) { 17929a925abSBenjamin Maxwell // Note: If `addConservativeSemiAffineBounds` is true then the bound 18029a925abSBenjamin Maxwell // computation function needs to handle the case that the constraints set 18129a925abSBenjamin Maxwell // could become empty. This is because the conservative bounds add assumptions 18229a925abSBenjamin Maxwell // (e.g. for `mod` it assumes `rhs > 0`). If these constraints are later found 18329a925abSBenjamin Maxwell // not to hold, then the bound is invalid. 1848c885658SMatthias Springer LogicalResult status = cstr.addBound( 1858c885658SMatthias Springer type, pos, 18629a925abSBenjamin Maxwell AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr), 18729a925abSBenjamin Maxwell addConservativeSemiAffineBounds 18829a925abSBenjamin Maxwell ? FlatLinearConstraints::AddConservativeSemiAffineBounds::Yes 18929a925abSBenjamin Maxwell : FlatLinearConstraints::AddConservativeSemiAffineBounds::No); 190d8804ecdSMatthias Springer if (failed(status)) { 19129a925abSBenjamin Maxwell // Not all semi-affine expressions are not yet supported by 192d8804ecdSMatthias Springer // FlatLinearConstraints. However, we can just ignore such failures here. 193d8804ecdSMatthias Springer // Even without this bound, there may be enough information in the 194d8804ecdSMatthias Springer // constraint system to compute the requested bound. In case this bound is 195d8804ecdSMatthias Springer // actually needed, `computeBound` will return `failure`. 196d8804ecdSMatthias Springer LLVM_DEBUG(llvm::dbgs() << "Failed to add bound: " << expr << "\n"); 197d8804ecdSMatthias Springer } 1988c885658SMatthias Springer } 1998c885658SMatthias Springer 2008c885658SMatthias Springer AffineExpr ValueBoundsConstraintSet::getExpr(Value value, 2018c885658SMatthias Springer std::optional<int64_t> dim) { 2028c885658SMatthias Springer #ifndef NDEBUG 2038c885658SMatthias Springer assertValidValueDim(value, dim); 2048c885658SMatthias Springer #endif // NDEBUG 2058c885658SMatthias Springer 20676435f2dSMatthias Springer // Check if the value/dim is statically known. In that case, an affine 20776435f2dSMatthias Springer // constant expression should be returned. This allows us to support 20876435f2dSMatthias Springer // multiplications with constants. (Multiplications of two columns in the 20976435f2dSMatthias Springer // constraint set is not supported.) 21076435f2dSMatthias Springer std::optional<int64_t> constSize = std::nullopt; 2118c885658SMatthias Springer auto shapedType = dyn_cast<ShapedType>(value.getType()); 2128c885658SMatthias Springer if (shapedType) { 2138c885658SMatthias Springer if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) 21476435f2dSMatthias Springer constSize = shapedType.getDimSize(*dim); 21576435f2dSMatthias Springer } else if (auto constInt = ::getConstantIntValue(value)) { 21676435f2dSMatthias Springer constSize = *constInt; 2178c885658SMatthias Springer } 2188c885658SMatthias Springer 21976435f2dSMatthias Springer // If the value/dim is already mapped, return the corresponding expression 22076435f2dSMatthias Springer // directly. 2218c885658SMatthias Springer ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); 22276435f2dSMatthias Springer if (valueDimToPosition.contains(valueDim)) { 22376435f2dSMatthias Springer // If it is a constant, return an affine constant expression. Otherwise, 22476435f2dSMatthias Springer // return an affine expression that represents the respective column in the 22576435f2dSMatthias Springer // constraint set. 22676435f2dSMatthias Springer if (constSize) 22776435f2dSMatthias Springer return builder.getAffineConstantExpr(*constSize); 22876435f2dSMatthias Springer return getPosExpr(getPos(value, dim)); 22976435f2dSMatthias Springer } 23076435f2dSMatthias Springer 23176435f2dSMatthias Springer if (constSize) { 23276435f2dSMatthias Springer // Constant index value/dim: add column to the constraint set, add EQ bound 23376435f2dSMatthias Springer // and return an affine constant expression without pushing the newly added 23476435f2dSMatthias Springer // column to the worklist. 23576435f2dSMatthias Springer (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false); 23676435f2dSMatthias Springer if (shapedType) 23776435f2dSMatthias Springer bound(value)[*dim] == *constSize; 23876435f2dSMatthias Springer else 23976435f2dSMatthias Springer bound(value) == *constSize; 24076435f2dSMatthias Springer return builder.getAffineConstantExpr(*constSize); 24176435f2dSMatthias Springer } 24276435f2dSMatthias Springer 24376435f2dSMatthias Springer // Dynamic value/dim: insert column to the constraint set and put it on the 24476435f2dSMatthias Springer // worklist. Return an affine expression that represents the newly inserted 24576435f2dSMatthias Springer // column in the constraint set. 24676435f2dSMatthias Springer return getPosExpr(insert(value, dim, /*isSymbol=*/true)); 2478c885658SMatthias Springer } 2488c885658SMatthias Springer 2498c885658SMatthias Springer AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) { 25068f58812STres Popp if (Value value = llvm::dyn_cast_if_present<Value>(ofr)) 2518c885658SMatthias Springer return getExpr(value, /*dim=*/std::nullopt); 2521abd8d1aSMatthias Springer auto constInt = ::getConstantIntValue(ofr); 2538c885658SMatthias Springer assert(constInt.has_value() && "expected Integer constant"); 2548c885658SMatthias Springer return builder.getAffineConstantExpr(*constInt); 2558c885658SMatthias Springer } 2568c885658SMatthias Springer 2578c885658SMatthias Springer AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) { 2588c885658SMatthias Springer return builder.getAffineConstantExpr(constant); 2598c885658SMatthias Springer } 2608c885658SMatthias Springer 2618c885658SMatthias Springer int64_t ValueBoundsConstraintSet::insert(Value value, 2628c885658SMatthias Springer std::optional<int64_t> dim, 26376435f2dSMatthias Springer bool isSymbol, bool addToWorklist) { 2648c885658SMatthias Springer #ifndef NDEBUG 2658c885658SMatthias Springer assertValidValueDim(value, dim); 2668c885658SMatthias Springer #endif // NDEBUG 2678c885658SMatthias Springer 2688c885658SMatthias Springer ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); 269af6ad7acSKazu Hirata assert(!valueDimToPosition.contains(valueDim) && "already mapped"); 2708c885658SMatthias Springer int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol) 2718c885658SMatthias Springer : cstr.appendVar(VarKind::SetDim); 27240dd3aa9SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos 27340dd3aa9SMatthias Springer << " for: " << value 27440dd3aa9SMatthias Springer << " (dim: " << dim.value_or(kIndexValue) 27540dd3aa9SMatthias Springer << ", owner: " << getOwnerOfValue(value)->getName() 27640dd3aa9SMatthias Springer << ")\n"); 2778c885658SMatthias Springer positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim); 2788c885658SMatthias Springer // Update reverse mapping. 2798c885658SMatthias Springer for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) 280ff930645SMatthias Springer if (positionToValueDim[i].has_value()) 281ff930645SMatthias Springer valueDimToPosition[*positionToValueDim[i]] = i; 2828c885658SMatthias Springer 28376435f2dSMatthias Springer if (addToWorklist) { 28476435f2dSMatthias Springer LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value 28576435f2dSMatthias Springer << " (dim: " << dim.value_or(kIndexValue) << ")\n"); 2860dc9087aSMatthias Springer worklist.push(pos); 28776435f2dSMatthias Springer } 28876435f2dSMatthias Springer 2898c885658SMatthias Springer return pos; 2908c885658SMatthias Springer } 2918c885658SMatthias Springer 292ff930645SMatthias Springer int64_t ValueBoundsConstraintSet::insert(bool isSymbol) { 293ff930645SMatthias Springer int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol) 294ff930645SMatthias Springer : cstr.appendVar(VarKind::SetDim); 29540dd3aa9SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos 29640dd3aa9SMatthias Springer << "\n"); 297ff930645SMatthias Springer positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt); 298ff930645SMatthias Springer // Update reverse mapping. 299ff930645SMatthias Springer for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) 300ff930645SMatthias Springer if (positionToValueDim[i].has_value()) 301ff930645SMatthias Springer valueDimToPosition[*positionToValueDim[i]] = i; 302ff930645SMatthias Springer return pos; 303ff930645SMatthias Springer } 304ff930645SMatthias Springer 305297eca98SMatthias Springer int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands, 306297eca98SMatthias Springer bool isSymbol) { 307297eca98SMatthias Springer assert(map.getNumResults() == 1 && "expected affine map with one result"); 30821265f69SMatthias Springer int64_t pos = insert(isSymbol); 309297eca98SMatthias Springer 310297eca98SMatthias Springer // Add map and operands to the constraint set. Dimensions are converted to 311297eca98SMatthias Springer // symbols. All operands are added to the worklist (unless they were already 312297eca98SMatthias Springer // processed). 313297eca98SMatthias Springer auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) { 314297eca98SMatthias Springer return getExpr(v.first, v.second); 315297eca98SMatthias Springer }; 316297eca98SMatthias Springer SmallVector<AffineExpr> dimReplacements = llvm::to_vector( 317297eca98SMatthias Springer llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper)); 318297eca98SMatthias Springer SmallVector<AffineExpr> symReplacements = llvm::to_vector( 319297eca98SMatthias Springer llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper)); 320297eca98SMatthias Springer addBound( 321297eca98SMatthias Springer presburger::BoundType::EQ, pos, 322297eca98SMatthias Springer map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements)); 323297eca98SMatthias Springer 324297eca98SMatthias Springer return pos; 325297eca98SMatthias Springer } 326297eca98SMatthias Springer 32740dd3aa9SMatthias Springer int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) { 32840dd3aa9SMatthias Springer return insert(var.map, var.mapOperands, isSymbol); 32940dd3aa9SMatthias Springer } 33040dd3aa9SMatthias Springer 3318c885658SMatthias Springer int64_t ValueBoundsConstraintSet::getPos(Value value, 3328c885658SMatthias Springer std::optional<int64_t> dim) const { 3338c885658SMatthias Springer #ifndef NDEBUG 3348c885658SMatthias Springer assertValidValueDim(value, dim); 3355550c821STres Popp assert((isa<OpResult>(value) || 3365550c821STres Popp cast<BlockArgument>(value).getOwner()->isEntryBlock()) && 3378c885658SMatthias Springer "unstructured control flow is not supported"); 3388c885658SMatthias Springer #endif // NDEBUG 33940dd3aa9SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value 34040dd3aa9SMatthias Springer << " (dim: " << dim.value_or(kIndexValue) 34140dd3aa9SMatthias Springer << ", owner: " << getOwnerOfValue(value)->getName() 34240dd3aa9SMatthias Springer << ")\n"); 3438c885658SMatthias Springer auto it = 3448c885658SMatthias Springer valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue))); 3458c885658SMatthias Springer assert(it != valueDimToPosition.end() && "expected mapped entry"); 3468c885658SMatthias Springer return it->second; 3478c885658SMatthias Springer } 3488c885658SMatthias Springer 34976435f2dSMatthias Springer AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) { 35076435f2dSMatthias Springer assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position"); 35176435f2dSMatthias Springer return pos < cstr.getNumDimVars() 35276435f2dSMatthias Springer ? builder.getAffineDimExpr(pos) 35376435f2dSMatthias Springer : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); 35476435f2dSMatthias Springer } 35576435f2dSMatthias Springer 356297eca98SMatthias Springer bool ValueBoundsConstraintSet::isMapped(Value value, 357297eca98SMatthias Springer std::optional<int64_t> dim) const { 358297eca98SMatthias Springer auto it = 359297eca98SMatthias Springer valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue))); 360297eca98SMatthias Springer return it != valueDimToPosition.end(); 361297eca98SMatthias Springer } 362297eca98SMatthias Springer 3635e4a4438SMatthias Springer void ValueBoundsConstraintSet::processWorklist() { 3645e4a4438SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n"); 3658c885658SMatthias Springer while (!worklist.empty()) { 3660dc9087aSMatthias Springer int64_t pos = worklist.front(); 3670dc9087aSMatthias Springer worklist.pop(); 368ff930645SMatthias Springer assert(positionToValueDim[pos].has_value() && 369ff930645SMatthias Springer "did not expect std::nullopt on worklist"); 370ff930645SMatthias Springer ValueDim valueDim = *positionToValueDim[pos]; 3718c885658SMatthias Springer Value value = valueDim.first; 3728c885658SMatthias Springer int64_t dim = valueDim.second; 3738c885658SMatthias Springer 3748c885658SMatthias Springer // Check for static dim size. 3758c885658SMatthias Springer if (dim != kIndexValue) { 3768c885658SMatthias Springer auto shapedType = cast<ShapedType>(value.getType()); 3778c885658SMatthias Springer if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) { 3788c885658SMatthias Springer bound(value)[dim] == getExpr(shapedType.getDimSize(dim)); 3798c885658SMatthias Springer continue; 3808c885658SMatthias Springer } 3818c885658SMatthias Springer } 3828c885658SMatthias Springer 3838c885658SMatthias Springer // Do not process any further if the stop condition is met. 384c3f5fd76SMatthias Springer auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim); 3855e4a4438SMatthias Springer if (stopCondition(value, maybeDim, *this)) { 3865e4a4438SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value 3875e4a4438SMatthias Springer << " (dim: " << maybeDim << ")\n"); 3888c885658SMatthias Springer continue; 3895e4a4438SMatthias Springer } 3908c885658SMatthias Springer 3918c885658SMatthias Springer // Query `ValueBoundsOpInterface` for constraints. New items may be added to 3928c885658SMatthias Springer // the worklist. 3938c885658SMatthias Springer auto valueBoundsOp = 3948c885658SMatthias Springer dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value)); 3955e4a4438SMatthias Springer LLVM_DEBUG(llvm::dbgs() 3965e4a4438SMatthias Springer << "Query value bounds for: " << value 3975e4a4438SMatthias Springer << " (owner: " << getOwnerOfValue(value)->getName() << ")\n"); 398c5624dc0SMatthias Springer if (valueBoundsOp) { 3998c885658SMatthias Springer if (dim == kIndexValue) { 4008c885658SMatthias Springer valueBoundsOp.populateBoundsForIndexValue(value, *this); 4018c885658SMatthias Springer } else { 4028c885658SMatthias Springer valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this); 4038c885658SMatthias Springer } 404c5624dc0SMatthias Springer continue; 405c5624dc0SMatthias Springer } 4065e4a4438SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n"); 407c5624dc0SMatthias Springer 408c5624dc0SMatthias Springer // If the op does not implement `ValueBoundsOpInterface`, check if it 409c5624dc0SMatthias Springer // implements the `DestinationStyleOpInterface`. OpResults of such ops are 410c5624dc0SMatthias Springer // tied to OpOperands. Tied values have the same shape. 411c5624dc0SMatthias Springer auto dstOp = value.getDefiningOp<DestinationStyleOpInterface>(); 412c5624dc0SMatthias Springer if (!dstOp || dim == kIndexValue) 413c5624dc0SMatthias Springer continue; 414c5624dc0SMatthias Springer Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get(); 415c5624dc0SMatthias Springer bound(value)[dim] == getExpr(tiedOperand, dim); 4168c885658SMatthias Springer } 4178c885658SMatthias Springer } 4188c885658SMatthias Springer 4198c885658SMatthias Springer void ValueBoundsConstraintSet::projectOut(int64_t pos) { 4208c885658SMatthias Springer assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) && 4218c885658SMatthias Springer "invalid position"); 4228c885658SMatthias Springer cstr.projectOut(pos); 423ff930645SMatthias Springer if (positionToValueDim[pos].has_value()) { 424ff930645SMatthias Springer bool erased = valueDimToPosition.erase(*positionToValueDim[pos]); 4258c885658SMatthias Springer (void)erased; 4268c885658SMatthias Springer assert(erased && "inconsistent reverse mapping"); 427ff930645SMatthias Springer } 4288c885658SMatthias Springer positionToValueDim.erase(positionToValueDim.begin() + pos); 4298c885658SMatthias Springer // Update reverse mapping. 4308c885658SMatthias Springer for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) 431ff930645SMatthias Springer if (positionToValueDim[i].has_value()) 432ff930645SMatthias Springer valueDimToPosition[*positionToValueDim[i]] = i; 4338c885658SMatthias Springer } 4348c885658SMatthias Springer 4358c885658SMatthias Springer void ValueBoundsConstraintSet::projectOut( 4368c885658SMatthias Springer function_ref<bool(ValueDim)> condition) { 4378c885658SMatthias Springer int64_t nextPos = 0; 4388c885658SMatthias Springer while (nextPos < static_cast<int64_t>(positionToValueDim.size())) { 439ff930645SMatthias Springer if (positionToValueDim[nextPos].has_value() && 440ff930645SMatthias Springer condition(*positionToValueDim[nextPos])) { 4418c885658SMatthias Springer projectOut(nextPos); 4428c885658SMatthias Springer // The column was projected out so another column is now at that position. 4438c885658SMatthias Springer // Do not increase the counter. 4448c885658SMatthias Springer } else { 4458c885658SMatthias Springer ++nextPos; 4468c885658SMatthias Springer } 4478c885658SMatthias Springer } 4488c885658SMatthias Springer } 4498c885658SMatthias Springer 45040dd3aa9SMatthias Springer void ValueBoundsConstraintSet::projectOutAnonymous( 45140dd3aa9SMatthias Springer std::optional<int64_t> except) { 45240dd3aa9SMatthias Springer int64_t nextPos = 0; 45340dd3aa9SMatthias Springer while (nextPos < static_cast<int64_t>(positionToValueDim.size())) { 45440dd3aa9SMatthias Springer if (positionToValueDim[nextPos].has_value() || except == nextPos) { 45540dd3aa9SMatthias Springer ++nextPos; 45640dd3aa9SMatthias Springer } else { 45740dd3aa9SMatthias Springer projectOut(nextPos); 45840dd3aa9SMatthias Springer // The column was projected out so another column is now at that position. 45940dd3aa9SMatthias Springer // Do not increase the counter. 46040dd3aa9SMatthias Springer } 46140dd3aa9SMatthias Springer } 46240dd3aa9SMatthias Springer } 46340dd3aa9SMatthias Springer 4648c885658SMatthias Springer LogicalResult ValueBoundsConstraintSet::computeBound( 4658c885658SMatthias Springer AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, 46640dd3aa9SMatthias Springer const Variable &var, StopConditionFn stopCondition, bool closedUB) { 46740dd3aa9SMatthias Springer MLIRContext *ctx = var.getContext(); 468eabb6ccdSMatthias Springer int64_t ubAdjustment = closedUB ? 0 : 1; 46940dd3aa9SMatthias Springer Builder b(ctx); 4708c885658SMatthias Springer mapOperands.clear(); 4718c885658SMatthias Springer 4728c885658SMatthias Springer // Process the backward slice of `value` (i.e., reverse use-def chain) until 4738c885658SMatthias Springer // `stopCondition` is met. 47440dd3aa9SMatthias Springer ValueBoundsConstraintSet cstr(ctx, stopCondition); 47540dd3aa9SMatthias Springer int64_t pos = cstr.insert(var, /*isSymbol=*/false); 47640dd3aa9SMatthias Springer assert(pos == 0 && "expected first column"); 4775e4a4438SMatthias Springer cstr.processWorklist(); 4788c885658SMatthias Springer 4798c885658SMatthias Springer // Project out all variables (apart from `valueDim`) that do not match the 4808c885658SMatthias Springer // stop condition. 4818c885658SMatthias Springer cstr.projectOut([&](ValueDim p) { 482c3f5fd76SMatthias Springer auto maybeDim = 483c3f5fd76SMatthias Springer p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); 4845e4a4438SMatthias Springer return !stopCondition(p.first, maybeDim, cstr); 4858c885658SMatthias Springer }); 48640dd3aa9SMatthias Springer cstr.projectOutAnonymous(/*except=*/pos); 4878c885658SMatthias Springer 4888c885658SMatthias Springer // Compute lower and upper bounds for `valueDim`. 4898c885658SMatthias Springer SmallVector<AffineMap> lb(1), ub(1); 49040dd3aa9SMatthias Springer cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub, 4919b514235SMehdi Amini /*closedUB=*/true); 492041bc485SMatthias Springer 4938c885658SMatthias Springer // Note: There are TODOs in the implementation of `getSliceBounds`. In such a 4948c885658SMatthias Springer // case, no lower/upper bound can be computed at the moment. 495041bc485SMatthias Springer // EQ, UB bounds: upper bound is needed. 496041bc485SMatthias Springer if ((type != BoundType::LB) && 497041bc485SMatthias Springer (ub.empty() || !ub[0] || ub[0].getNumResults() == 0)) 498041bc485SMatthias Springer return failure(); 499041bc485SMatthias Springer // EQ, LB bounds: lower bound is needed. 500041bc485SMatthias Springer if ((type != BoundType::UB) && 501041bc485SMatthias Springer (lb.empty() || !lb[0] || lb[0].getNumResults() == 0)) 5028c885658SMatthias Springer return failure(); 5038c885658SMatthias Springer 504041bc485SMatthias Springer // TODO: Generate an affine map with multiple results. 505041bc485SMatthias Springer if (type != BoundType::LB) 506041bc485SMatthias Springer assert(ub.size() == 1 && ub[0].getNumResults() == 1 && 507041bc485SMatthias Springer "multiple bounds not supported"); 508041bc485SMatthias Springer if (type != BoundType::UB) 509041bc485SMatthias Springer assert(lb.size() == 1 && lb[0].getNumResults() == 1 && 510041bc485SMatthias Springer "multiple bounds not supported"); 511041bc485SMatthias Springer 512041bc485SMatthias Springer // EQ bound: lower and upper bound must match. 513041bc485SMatthias Springer if (type == BoundType::EQ && ub[0] != lb[0]) 5148c885658SMatthias Springer return failure(); 5158c885658SMatthias Springer 516041bc485SMatthias Springer AffineMap bound; 517041bc485SMatthias Springer if (type == BoundType::EQ || type == BoundType::LB) { 518041bc485SMatthias Springer bound = lb[0]; 519041bc485SMatthias Springer } else { 520eabb6ccdSMatthias Springer // Computed UB is a closed bound. 521041bc485SMatthias Springer bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(), 522eabb6ccdSMatthias Springer ub[0].getResult(0) + ubAdjustment); 523041bc485SMatthias Springer } 524041bc485SMatthias Springer 5258c885658SMatthias Springer // Gather all SSA values that are used in the computed bound. 5268c885658SMatthias Springer assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() && 5278c885658SMatthias Springer "inconsistent mapping state"); 5288c885658SMatthias Springer SmallVector<AffineExpr> replacementDims, replacementSymbols; 5298c885658SMatthias Springer int64_t numDims = 0, numSymbols = 0; 5308c885658SMatthias Springer for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) { 5318c885658SMatthias Springer // Skip `value`. 5328c885658SMatthias Springer if (i == pos) 5338c885658SMatthias Springer continue; 5348c885658SMatthias Springer // Check if the position `i` is used in the generated bound. If so, it must 5358c885658SMatthias Springer // be included in the generated affine.apply op. 5368c885658SMatthias Springer bool used = false; 5378c885658SMatthias Springer bool isDim = i < cstr.cstr.getNumDimVars(); 5388c885658SMatthias Springer if (isDim) { 539041bc485SMatthias Springer if (bound.isFunctionOfDim(i)) 5408c885658SMatthias Springer used = true; 5418c885658SMatthias Springer } else { 542041bc485SMatthias Springer if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) 5438c885658SMatthias Springer used = true; 5448c885658SMatthias Springer } 5458c885658SMatthias Springer 5468c885658SMatthias Springer if (!used) { 5478c885658SMatthias Springer // Not used: Remove dim/symbol from the result. 5488c885658SMatthias Springer if (isDim) { 5498c885658SMatthias Springer replacementDims.push_back(b.getAffineConstantExpr(0)); 5508c885658SMatthias Springer } else { 5518c885658SMatthias Springer replacementSymbols.push_back(b.getAffineConstantExpr(0)); 5528c885658SMatthias Springer } 5538c885658SMatthias Springer continue; 5548c885658SMatthias Springer } 5558c885658SMatthias Springer 5568c885658SMatthias Springer if (isDim) { 5578c885658SMatthias Springer replacementDims.push_back(b.getAffineDimExpr(numDims++)); 5588c885658SMatthias Springer } else { 5598c885658SMatthias Springer replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++)); 5608c885658SMatthias Springer } 5618c885658SMatthias Springer 562ff930645SMatthias Springer assert(cstr.positionToValueDim[i].has_value() && 563ff930645SMatthias Springer "cannot build affine map in terms of anonymous column"); 564ff930645SMatthias Springer ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i]; 5658c885658SMatthias Springer Value value = valueDim.first; 5668c885658SMatthias Springer int64_t dim = valueDim.second; 5678c885658SMatthias Springer if (dim == ValueBoundsConstraintSet::kIndexValue) { 5688c885658SMatthias Springer // An index-type value is used: can be used directly in the affine.apply 5698c885658SMatthias Springer // op. 5708c885658SMatthias Springer assert(value.getType().isIndex() && "expected index type"); 5718c885658SMatthias Springer mapOperands.push_back(std::make_pair(value, std::nullopt)); 5728c885658SMatthias Springer continue; 5738c885658SMatthias Springer } 5748c885658SMatthias Springer 5758c885658SMatthias Springer assert(cast<ShapedType>(value.getType()).isDynamicDim(dim) && 5768c885658SMatthias Springer "expected dynamic dim"); 5778c885658SMatthias Springer mapOperands.push_back(std::make_pair(value, dim)); 5788c885658SMatthias Springer } 5798c885658SMatthias Springer 580041bc485SMatthias Springer resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols, 5818c885658SMatthias Springer numDims, numSymbols); 5828c885658SMatthias Springer return success(); 5838c885658SMatthias Springer } 5848c885658SMatthias Springer 58577124386SMatthias Springer LogicalResult ValueBoundsConstraintSet::computeDependentBound( 586c3f5fd76SMatthias Springer AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, 58740dd3aa9SMatthias Springer const Variable &var, ValueDimList dependencies, bool closedUB) { 588eabb6ccdSMatthias Springer return computeBound( 58940dd3aa9SMatthias Springer resultMap, mapOperands, type, var, 5905e4a4438SMatthias Springer [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { 591eabb6ccdSMatthias Springer return llvm::is_contained(dependencies, std::make_pair(v, d)); 592eabb6ccdSMatthias Springer }, 593eabb6ccdSMatthias Springer closedUB); 594c3f5fd76SMatthias Springer } 595c3f5fd76SMatthias Springer 59677124386SMatthias Springer LogicalResult ValueBoundsConstraintSet::computeIndependentBound( 59777124386SMatthias Springer AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, 59840dd3aa9SMatthias Springer const Variable &var, ValueRange independencies, bool closedUB) { 59977124386SMatthias Springer // Return "true" if the given value is independent of all values in 60077124386SMatthias Springer // `independencies`. I.e., neither the value itself nor any value in the 60177124386SMatthias Springer // backward slice (reverse use-def chain) is contained in `independencies`. 60277124386SMatthias Springer auto isIndependent = [&](Value v) { 60377124386SMatthias Springer SmallVector<Value> worklist; 60477124386SMatthias Springer DenseSet<Value> visited; 60577124386SMatthias Springer worklist.push_back(v); 60677124386SMatthias Springer while (!worklist.empty()) { 60777124386SMatthias Springer Value next = worklist.pop_back_val(); 6086ffa7cd8SKazu Hirata if (!visited.insert(next).second) 60977124386SMatthias Springer continue; 61077124386SMatthias Springer if (llvm::is_contained(independencies, next)) 61177124386SMatthias Springer return false; 61277124386SMatthias Springer // TODO: DominanceInfo could be used to stop the traversal early. 61377124386SMatthias Springer Operation *op = next.getDefiningOp(); 61477124386SMatthias Springer if (!op) 61577124386SMatthias Springer continue; 61677124386SMatthias Springer worklist.append(op->getOperands().begin(), op->getOperands().end()); 61777124386SMatthias Springer } 61877124386SMatthias Springer return true; 61977124386SMatthias Springer }; 62077124386SMatthias Springer 62177124386SMatthias Springer // Reify bounds in terms of any independent values. 62277124386SMatthias Springer return computeBound( 62340dd3aa9SMatthias Springer resultMap, mapOperands, type, var, 6245e4a4438SMatthias Springer [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { 6255e4a4438SMatthias Springer return isIndependent(v); 6265e4a4438SMatthias Springer }, 62777124386SMatthias Springer closedUB); 62877124386SMatthias Springer } 62977124386SMatthias Springer 6300dc9087aSMatthias Springer FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( 63140dd3aa9SMatthias Springer presburger::BoundType type, const Variable &var, 632eabb6ccdSMatthias Springer StopConditionFn stopCondition, bool closedUB) { 6335e4a4438SMatthias Springer // Default stop condition if none was specified: Keep adding constraints until 6345e4a4438SMatthias Springer // a bound could be computed. 63576435f2dSMatthias Springer int64_t pos = 0; 6365e4a4438SMatthias Springer auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, 6375e4a4438SMatthias Springer ValueBoundsConstraintSet &cstr) { 6382861856bSBenjamin Maxwell return cstr.cstr.getConstantBound64(type, pos).has_value(); 6395e4a4438SMatthias Springer }; 6405e4a4438SMatthias Springer 6415e4a4438SMatthias Springer ValueBoundsConstraintSet cstr( 64240dd3aa9SMatthias Springer var.getContext(), stopCondition ? stopCondition : defaultStopCondition); 64340dd3aa9SMatthias Springer pos = cstr.populateConstraints(var.map, var.mapOperands); 64476435f2dSMatthias Springer assert(pos == 0 && "expected `map` is the first column"); 6455e4a4438SMatthias Springer 6462861856bSBenjamin Maxwell // Compute constant bound for `valueDim`. 6472861856bSBenjamin Maxwell int64_t ubAdjustment = closedUB ? 0 : 1; 6482861856bSBenjamin Maxwell if (auto bound = cstr.cstr.getConstantBound64(type, pos)) 6492861856bSBenjamin Maxwell return type == BoundType::UB ? *bound + ubAdjustment : *bound; 6502861856bSBenjamin Maxwell return failure(); 6512861856bSBenjamin Maxwell } 6522861856bSBenjamin Maxwell 65376435f2dSMatthias Springer void ValueBoundsConstraintSet::populateConstraints(Value value, 6545e4a4438SMatthias Springer std::optional<int64_t> dim) { 6552861856bSBenjamin Maxwell #ifndef NDEBUG 6562861856bSBenjamin Maxwell assertValidValueDim(value, dim); 6572861856bSBenjamin Maxwell #endif // NDEBUG 6582861856bSBenjamin Maxwell 65976435f2dSMatthias Springer // `getExpr` pushes the value/dim onto the worklist (unless it was already 66076435f2dSMatthias Springer // analyzed). 66176435f2dSMatthias Springer (void)getExpr(value, dim); 66276435f2dSMatthias Springer // Process all values/dims on the worklist. This may traverse and analyze 66376435f2dSMatthias Springer // additional IR, depending the current stop function. 66476435f2dSMatthias Springer processWorklist(); 6652861856bSBenjamin Maxwell } 6662861856bSBenjamin Maxwell 66776435f2dSMatthias Springer int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map, 66876435f2dSMatthias Springer ValueDimList operands) { 669297eca98SMatthias Springer int64_t pos = insert(map, operands, /*isSymbol=*/false); 6702861856bSBenjamin Maxwell // Process the backward slice of `operands` (i.e., reverse use-def chain) 6712861856bSBenjamin Maxwell // until `stopCondition` is met. 6725e4a4438SMatthias Springer processWorklist(); 6732861856bSBenjamin Maxwell return pos; 6742861856bSBenjamin Maxwell } 6752861856bSBenjamin Maxwell 6763049ac44SLei Zhang FailureOr<int64_t> 6773049ac44SLei Zhang ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, 678ff930645SMatthias Springer std::optional<int64_t> dim1, 679ff930645SMatthias Springer std::optional<int64_t> dim2) { 680ff930645SMatthias Springer #ifndef NDEBUG 681ff930645SMatthias Springer assertValidValueDim(value1, dim1); 682ff930645SMatthias Springer assertValidValueDim(value2, dim2); 683ff930645SMatthias Springer #endif // NDEBUG 684ff930645SMatthias Springer 685ff930645SMatthias Springer Builder b(value1.getContext()); 686ff930645SMatthias Springer AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, 687ff930645SMatthias Springer b.getAffineDimExpr(0) - b.getAffineDimExpr(1)); 68840dd3aa9SMatthias Springer return computeConstantBound(presburger::BoundType::EQ, 68940dd3aa9SMatthias Springer Variable(map, {{value1, dim1}, {value2, dim2}})); 6903049ac44SLei Zhang } 6913049ac44SLei Zhang 69240dd3aa9SMatthias Springer bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, 693297eca98SMatthias Springer ComparisonOperator cmp, 69440dd3aa9SMatthias Springer int64_t rhsPos) { 69576435f2dSMatthias Springer // This function returns "true" if "lhs CMP rhs" is proven to hold. 69676435f2dSMatthias Springer // 69776435f2dSMatthias Springer // Example for ComparisonOperator::LE and index-typed values: We would like to 69876435f2dSMatthias Springer // prove that lhs <= rhs. Proof by contradiction: add the inverse 69976435f2dSMatthias Springer // relation (lhs > rhs) to the constraint set and check if the resulting 70076435f2dSMatthias Springer // constraint set is "empty" (i.e. has no solution). In that case, 70176435f2dSMatthias Springer // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds. 70276435f2dSMatthias Springer 70376435f2dSMatthias Springer // We cannot prove anything if the constraint set is already empty. 70476435f2dSMatthias Springer if (cstr.isEmpty()) { 70576435f2dSMatthias Springer LLVM_DEBUG( 70676435f2dSMatthias Springer llvm::dbgs() 70776435f2dSMatthias Springer << "cannot compare value/dims: constraint system is already empty"); 70876435f2dSMatthias Springer return false; 70976435f2dSMatthias Springer } 71076435f2dSMatthias Springer 71176435f2dSMatthias Springer // EQ can be expressed as LE and GE. 71276435f2dSMatthias Springer if (cmp == EQ) 713297eca98SMatthias Springer return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) && 714297eca98SMatthias Springer comparePos(lhsPos, ComparisonOperator::GE, rhsPos); 715297eca98SMatthias Springer 716297eca98SMatthias Springer // Construct inequality. 717297eca98SMatthias Springer SmallVector<int64_t> eq(cstr.getNumCols(), 0); 718297eca98SMatthias Springer if (cmp == LT || cmp == LE) { 719297eca98SMatthias Springer ++eq[lhsPos]; 720297eca98SMatthias Springer --eq[rhsPos]; 721297eca98SMatthias Springer } else if (cmp == GT || cmp == GE) { 722297eca98SMatthias Springer --eq[lhsPos]; 723297eca98SMatthias Springer ++eq[rhsPos]; 724297eca98SMatthias Springer } else { 725297eca98SMatthias Springer llvm_unreachable("unsupported comparison operator"); 726297eca98SMatthias Springer } 727297eca98SMatthias Springer if (cmp == LE || cmp == GE) 728297eca98SMatthias Springer eq[cstr.getNumCols() - 1] -= 1; 729297eca98SMatthias Springer 730297eca98SMatthias Springer // Add inequality to the constraint set and check if it made the constraint 731297eca98SMatthias Springer // set empty. 732297eca98SMatthias Springer int64_t ineqPos = cstr.getNumInequalities(); 733297eca98SMatthias Springer cstr.addInequality(eq); 734297eca98SMatthias Springer bool isEmpty = cstr.isEmpty(); 735297eca98SMatthias Springer cstr.removeInequality(ineqPos); 736297eca98SMatthias Springer return isEmpty; 737ff930645SMatthias Springer } 738ff930645SMatthias Springer 73940dd3aa9SMatthias Springer bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs, 74040dd3aa9SMatthias Springer ComparisonOperator cmp, 74140dd3aa9SMatthias Springer const Variable &rhs) { 74240dd3aa9SMatthias Springer int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands); 74340dd3aa9SMatthias Springer int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands); 74440dd3aa9SMatthias Springer return comparePos(lhsPos, cmp, rhsPos); 745ff614a57SMatthias Springer } 746297eca98SMatthias Springer 74740dd3aa9SMatthias Springer bool ValueBoundsConstraintSet::compare(const Variable &lhs, 74840dd3aa9SMatthias Springer ComparisonOperator cmp, 74940dd3aa9SMatthias Springer const Variable &rhs) { 750297eca98SMatthias Springer int64_t lhsPos = -1, rhsPos = -1; 751297eca98SMatthias Springer auto stopCondition = [&](Value v, std::optional<int64_t> dim, 752297eca98SMatthias Springer ValueBoundsConstraintSet &cstr) { 753297eca98SMatthias Springer // Keep processing as long as lhs/rhs were not processed. 754dc390289SJie Fu if (size_t(lhsPos) >= cstr.positionToValueDim.size() || 755dc390289SJie Fu size_t(rhsPos) >= cstr.positionToValueDim.size()) 756297eca98SMatthias Springer return false; 757297eca98SMatthias Springer // Keep processing as long as the relation cannot be proven. 758297eca98SMatthias Springer return cstr.comparePos(lhsPos, cmp, rhsPos); 759297eca98SMatthias Springer }; 760297eca98SMatthias Springer ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); 76140dd3aa9SMatthias Springer lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands); 76240dd3aa9SMatthias Springer rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands); 763297eca98SMatthias Springer return cstr.comparePos(lhsPos, cmp, rhsPos); 764297eca98SMatthias Springer } 765297eca98SMatthias Springer 76640dd3aa9SMatthias Springer FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1, 76740dd3aa9SMatthias Springer const Variable &var2) { 76840dd3aa9SMatthias Springer if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2)) 769297eca98SMatthias Springer return true; 77040dd3aa9SMatthias Springer if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) || 77140dd3aa9SMatthias Springer ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2)) 772297eca98SMatthias Springer return false; 773ff614a57SMatthias Springer return failure(); 774ff614a57SMatthias Springer } 775ff614a57SMatthias Springer 776ff614a57SMatthias Springer FailureOr<bool> 777ff614a57SMatthias Springer ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, 778ff614a57SMatthias Springer HyperrectangularSlice slice1, 779ff614a57SMatthias Springer HyperrectangularSlice slice2) { 780*9f2dd085Sklensy assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() && 7811abd8d1aSMatthias Springer "expected slices of same rank"); 782*9f2dd085Sklensy assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() && 7831abd8d1aSMatthias Springer "expected slices of same rank"); 784*9f2dd085Sklensy assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() && 7851abd8d1aSMatthias Springer "expected slices of same rank"); 7861abd8d1aSMatthias Springer 787ff614a57SMatthias Springer Builder b(ctx); 7881abd8d1aSMatthias Springer bool foundUnknownBound = false; 789ff614a57SMatthias Springer for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) { 7901abd8d1aSMatthias Springer AffineMap map = 7911abd8d1aSMatthias Springer AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4, 7921abd8d1aSMatthias Springer b.getAffineSymbolExpr(0) + 7931abd8d1aSMatthias Springer b.getAffineSymbolExpr(1) * b.getAffineSymbolExpr(2) - 7941abd8d1aSMatthias Springer b.getAffineSymbolExpr(3)); 7951abd8d1aSMatthias Springer { 7961abd8d1aSMatthias Springer // Case 1: Slices are guaranteed to be non-overlapping if 7971abd8d1aSMatthias Springer // offset1 + size1 * stride1 <= offset2 (for at least one dimension). 7981abd8d1aSMatthias Springer SmallVector<OpFoldResult> ofrOperands; 7991abd8d1aSMatthias Springer ofrOperands.push_back(slice1.getMixedOffsets()[i]); 8001abd8d1aSMatthias Springer ofrOperands.push_back(slice1.getMixedSizes()[i]); 8011abd8d1aSMatthias Springer ofrOperands.push_back(slice1.getMixedStrides()[i]); 8021abd8d1aSMatthias Springer ofrOperands.push_back(slice2.getMixedOffsets()[i]); 8031abd8d1aSMatthias Springer SmallVector<Value> valueOperands; 8041abd8d1aSMatthias Springer AffineMap foldedMap = 8051abd8d1aSMatthias Springer foldAttributesIntoMap(b, map, ofrOperands, valueOperands); 8061abd8d1aSMatthias Springer FailureOr<int64_t> constBound = computeConstantBound( 80740dd3aa9SMatthias Springer presburger::BoundType::EQ, Variable(foldedMap, valueOperands)); 8081abd8d1aSMatthias Springer foundUnknownBound |= failed(constBound); 8091abd8d1aSMatthias Springer if (succeeded(constBound) && *constBound <= 0) 8101abd8d1aSMatthias Springer return false; 8111abd8d1aSMatthias Springer } 8121abd8d1aSMatthias Springer { 8131abd8d1aSMatthias Springer // Case 2: Slices are guaranteed to be non-overlapping if 8141abd8d1aSMatthias Springer // offset2 + size2 * stride2 <= offset1 (for at least one dimension). 8151abd8d1aSMatthias Springer SmallVector<OpFoldResult> ofrOperands; 8161abd8d1aSMatthias Springer ofrOperands.push_back(slice2.getMixedOffsets()[i]); 8171abd8d1aSMatthias Springer ofrOperands.push_back(slice2.getMixedSizes()[i]); 8181abd8d1aSMatthias Springer ofrOperands.push_back(slice2.getMixedStrides()[i]); 8191abd8d1aSMatthias Springer ofrOperands.push_back(slice1.getMixedOffsets()[i]); 8201abd8d1aSMatthias Springer SmallVector<Value> valueOperands; 8211abd8d1aSMatthias Springer AffineMap foldedMap = 8221abd8d1aSMatthias Springer foldAttributesIntoMap(b, map, ofrOperands, valueOperands); 8231abd8d1aSMatthias Springer FailureOr<int64_t> constBound = computeConstantBound( 82440dd3aa9SMatthias Springer presburger::BoundType::EQ, Variable(foldedMap, valueOperands)); 8251abd8d1aSMatthias Springer foundUnknownBound |= failed(constBound); 8261abd8d1aSMatthias Springer if (succeeded(constBound) && *constBound <= 0) 8271abd8d1aSMatthias Springer return false; 8281abd8d1aSMatthias Springer } 8291abd8d1aSMatthias Springer } 8301abd8d1aSMatthias Springer 8311abd8d1aSMatthias Springer // If at least one bound could not be computed, we cannot be certain that the 8321abd8d1aSMatthias Springer // slices are really overlapping. 8331abd8d1aSMatthias Springer if (foundUnknownBound) 8341abd8d1aSMatthias Springer return failure(); 8351abd8d1aSMatthias Springer 8361abd8d1aSMatthias Springer // All bounds could be computed and none of the above cases applied. 8371abd8d1aSMatthias Springer // Therefore, the slices are guaranteed to overlap. 8381abd8d1aSMatthias Springer return true; 8391abd8d1aSMatthias Springer } 8401abd8d1aSMatthias Springer 841ff614a57SMatthias Springer FailureOr<bool> 842ff614a57SMatthias Springer ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx, 843ff614a57SMatthias Springer HyperrectangularSlice slice1, 844ff614a57SMatthias Springer HyperrectangularSlice slice2) { 845*9f2dd085Sklensy assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() && 846ff614a57SMatthias Springer "expected slices of same rank"); 847*9f2dd085Sklensy assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() && 848ff614a57SMatthias Springer "expected slices of same rank"); 849*9f2dd085Sklensy assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() && 850ff614a57SMatthias Springer "expected slices of same rank"); 851ff614a57SMatthias Springer 852ff614a57SMatthias Springer // The two slices are equivalent if all of their offsets, sizes and strides 853ff614a57SMatthias Springer // are equal. If equality cannot be determined for at least one of those 854ff614a57SMatthias Springer // values, equivalence cannot be determined and this function returns 855ff614a57SMatthias Springer // "failure". 856ff614a57SMatthias Springer for (auto [offset1, offset2] : 857ff614a57SMatthias Springer llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) { 858ff614a57SMatthias Springer FailureOr<bool> equal = areEqual(offset1, offset2); 859ff614a57SMatthias Springer if (failed(equal)) 860ff614a57SMatthias Springer return failure(); 861ff614a57SMatthias Springer if (!equal.value()) 862ff614a57SMatthias Springer return false; 863ff614a57SMatthias Springer } 864ff614a57SMatthias Springer for (auto [size1, size2] : 865ff614a57SMatthias Springer llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) { 866ff614a57SMatthias Springer FailureOr<bool> equal = areEqual(size1, size2); 867ff614a57SMatthias Springer if (failed(equal)) 868ff614a57SMatthias Springer return failure(); 869ff614a57SMatthias Springer if (!equal.value()) 870ff614a57SMatthias Springer return false; 871ff614a57SMatthias Springer } 872ff614a57SMatthias Springer for (auto [stride1, stride2] : 873ff614a57SMatthias Springer llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) { 874ff614a57SMatthias Springer FailureOr<bool> equal = areEqual(stride1, stride2); 875ff614a57SMatthias Springer if (failed(equal)) 876ff614a57SMatthias Springer return failure(); 877ff614a57SMatthias Springer if (!equal.value()) 878ff614a57SMatthias Springer return false; 879ff614a57SMatthias Springer } 880ff614a57SMatthias Springer return true; 881ff614a57SMatthias Springer } 882ff614a57SMatthias Springer 88310b07f23SMatthias Springer void ValueBoundsConstraintSet::dump() const { 88410b07f23SMatthias Springer llvm::errs() << "==========\nColumns:\n"; 88510b07f23SMatthias Springer llvm::errs() << "(column\tdim\tvalue)\n"; 88610b07f23SMatthias Springer for (auto [index, valueDim] : llvm::enumerate(positionToValueDim)) { 88710b07f23SMatthias Springer llvm::errs() << " " << index << "\t"; 88810b07f23SMatthias Springer if (valueDim) { 88910b07f23SMatthias Springer if (valueDim->second == kIndexValue) { 89010b07f23SMatthias Springer llvm::errs() << "n/a\t"; 89110b07f23SMatthias Springer } else { 89210b07f23SMatthias Springer llvm::errs() << valueDim->second << "\t"; 89310b07f23SMatthias Springer } 89410b07f23SMatthias Springer llvm::errs() << getOwnerOfValue(valueDim->first)->getName() << " "; 89510b07f23SMatthias Springer if (OpResult result = dyn_cast<OpResult>(valueDim->first)) { 89610b07f23SMatthias Springer llvm::errs() << "(result " << result.getResultNumber() << ")"; 89710b07f23SMatthias Springer } else { 89810b07f23SMatthias Springer llvm::errs() << "(bbarg " 89910b07f23SMatthias Springer << cast<BlockArgument>(valueDim->first).getArgNumber() 90010b07f23SMatthias Springer << ")"; 90110b07f23SMatthias Springer } 90210b07f23SMatthias Springer llvm::errs() << "\n"; 90310b07f23SMatthias Springer } else { 90410b07f23SMatthias Springer llvm::errs() << "n/a\tn/a\n"; 90510b07f23SMatthias Springer } 90610b07f23SMatthias Springer } 90710b07f23SMatthias Springer llvm::errs() << "\nConstraint set:\n"; 90810b07f23SMatthias Springer cstr.dump(); 90910b07f23SMatthias Springer llvm::errs() << "==========\n"; 91010b07f23SMatthias Springer } 91110b07f23SMatthias Springer 9128c885658SMatthias Springer ValueBoundsConstraintSet::BoundBuilder & 9138c885658SMatthias Springer ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) { 9148c885658SMatthias Springer assert(!this->dim.has_value() && "dim was already set"); 9158c885658SMatthias Springer this->dim = dim; 9168c885658SMatthias Springer #ifndef NDEBUG 9178c885658SMatthias Springer assertValidValueDim(value, this->dim); 9188c885658SMatthias Springer #endif // NDEBUG 9198c885658SMatthias Springer return *this; 9208c885658SMatthias Springer } 9218c885658SMatthias Springer 9228c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator<(AffineExpr expr) { 9238c885658SMatthias Springer #ifndef NDEBUG 9248c885658SMatthias Springer assertValidValueDim(value, this->dim); 9258c885658SMatthias Springer #endif // NDEBUG 9268c885658SMatthias Springer cstr.addBound(BoundType::UB, cstr.getPos(value, this->dim), expr); 9278c885658SMatthias Springer } 9288c885658SMatthias Springer 9298c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator<=(AffineExpr expr) { 9308c885658SMatthias Springer operator<(expr + 1); 9318c885658SMatthias Springer } 9328c885658SMatthias Springer 9338c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator>(AffineExpr expr) { 9348c885658SMatthias Springer operator>=(expr + 1); 9358c885658SMatthias Springer } 9368c885658SMatthias Springer 9378c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator>=(AffineExpr expr) { 9388c885658SMatthias Springer #ifndef NDEBUG 9398c885658SMatthias Springer assertValidValueDim(value, this->dim); 9408c885658SMatthias Springer #endif // NDEBUG 9418c885658SMatthias Springer cstr.addBound(BoundType::LB, cstr.getPos(value, this->dim), expr); 9428c885658SMatthias Springer } 9438c885658SMatthias Springer 9448c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator==(AffineExpr expr) { 9458c885658SMatthias Springer #ifndef NDEBUG 9468c885658SMatthias Springer assertValidValueDim(value, this->dim); 9478c885658SMatthias Springer #endif // NDEBUG 9488c885658SMatthias Springer cstr.addBound(BoundType::EQ, cstr.getPos(value, this->dim), expr); 9498c885658SMatthias Springer } 9508c885658SMatthias Springer 9518c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator<(OpFoldResult ofr) { 9528c885658SMatthias Springer operator<(cstr.getExpr(ofr)); 9538c885658SMatthias Springer } 9548c885658SMatthias Springer 9558c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator<=(OpFoldResult ofr) { 9568c885658SMatthias Springer operator<=(cstr.getExpr(ofr)); 9578c885658SMatthias Springer } 9588c885658SMatthias Springer 9598c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator>(OpFoldResult ofr) { 9608c885658SMatthias Springer operator>(cstr.getExpr(ofr)); 9618c885658SMatthias Springer } 9628c885658SMatthias Springer 9638c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator>=(OpFoldResult ofr) { 9648c885658SMatthias Springer operator>=(cstr.getExpr(ofr)); 9658c885658SMatthias Springer } 9668c885658SMatthias Springer 9678c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator==(OpFoldResult ofr) { 9688c885658SMatthias Springer operator==(cstr.getExpr(ofr)); 9698c885658SMatthias Springer } 9708c885658SMatthias Springer 9718c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator<(int64_t i) { 9728c885658SMatthias Springer operator<(cstr.getExpr(i)); 9738c885658SMatthias Springer } 9748c885658SMatthias Springer 9758c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator<=(int64_t i) { 9768c885658SMatthias Springer operator<=(cstr.getExpr(i)); 9778c885658SMatthias Springer } 9788c885658SMatthias Springer 9798c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator>(int64_t i) { 9808c885658SMatthias Springer operator>(cstr.getExpr(i)); 9818c885658SMatthias Springer } 9828c885658SMatthias Springer 9838c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator>=(int64_t i) { 9848c885658SMatthias Springer operator>=(cstr.getExpr(i)); 9858c885658SMatthias Springer } 9868c885658SMatthias Springer 9878c885658SMatthias Springer void ValueBoundsConstraintSet::BoundBuilder::operator==(int64_t i) { 9888c885658SMatthias Springer operator==(cstr.getExpr(i)); 9898c885658SMatthias Springer } 990