1ab701975SMogball //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// 2ab701975SMogball // 3ab701975SMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ab701975SMogball // See https://llvm.org/LICENSE.txt for license information. 5ab701975SMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ab701975SMogball // 7ab701975SMogball //===----------------------------------------------------------------------===// 8ab701975SMogball // 9ab701975SMogball // This file defines the dataflow analysis class for integer range inference 10ab701975SMogball // which is used in transformations over the `arith` dialect such as 11ab701975SMogball // branch elimination or signed->unsigned rewriting 12ab701975SMogball // 13ab701975SMogball //===----------------------------------------------------------------------===// 14ab701975SMogball 15ab701975SMogball #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 16ab701975SMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 17383f2bd5SMehdi Amini #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 18383f2bd5SMehdi Amini #include "mlir/Analysis/DataFlowFramework.h" 19383f2bd5SMehdi Amini #include "mlir/IR/BuiltinAttributes.h" 20383f2bd5SMehdi Amini #include "mlir/IR/Dialect.h" 21383f2bd5SMehdi Amini #include "mlir/IR/OpDefinition.h" 22f54cdc5dSIvan Butygin #include "mlir/IR/TypeUtilities.h" 23383f2bd5SMehdi Amini #include "mlir/IR/Value.h" 24383f2bd5SMehdi Amini #include "mlir/Interfaces/ControlFlowInterfaces.h" 25ab701975SMogball #include "mlir/Interfaces/InferIntRangeInterface.h" 26ab701975SMogball #include "mlir/Interfaces/LoopLikeInterface.h" 27383f2bd5SMehdi Amini #include "mlir/Support/LLVM.h" 28383f2bd5SMehdi Amini #include "llvm/ADT/STLExtras.h" 29383f2bd5SMehdi Amini #include "llvm/Support/Casting.h" 30ab701975SMogball #include "llvm/Support/Debug.h" 31383f2bd5SMehdi Amini #include <cassert> 32a1fe1f5fSKazu Hirata #include <optional> 33383f2bd5SMehdi Amini #include <utility> 34ab701975SMogball 35ab701975SMogball #define DEBUG_TYPE "int-range-analysis" 36ab701975SMogball 37ab701975SMogball using namespace mlir; 38ab701975SMogball using namespace mlir::dataflow; 39ab701975SMogball 40ab701975SMogball void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { 41ab701975SMogball Lattice::onUpdate(solver); 42ab701975SMogball 43ab701975SMogball // If the integer range can be narrowed to a constant, update the constant 44ab701975SMogball // value of the SSA value. 450a81ace0SKazu Hirata std::optional<APInt> constant = getValue().getValue().getConstantValue(); 46*b5c5c2b2SKazu Hirata auto value = cast<Value>(anchor); 47ab701975SMogball auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value); 48ab701975SMogball if (!constant) 49de0ebc52SZhixun Tan return solver->propagateIfChanged( 50de0ebc52SZhixun Tan cv, cv->join(ConstantValue::getUnknownConstant())); 51ab701975SMogball 52ab701975SMogball Dialect *dialect; 53ab701975SMogball if (auto *parent = value.getDefiningOp()) 54ab701975SMogball dialect = parent->getDialect(); 55ab701975SMogball else 56ab701975SMogball dialect = value.getParentBlock()->getParentOp()->getDialect(); 57f54cdc5dSIvan Butygin 58f54cdc5dSIvan Butygin Type type = getElementTypeOrSelf(value); 59ab701975SMogball solver->propagateIfChanged( 60f54cdc5dSIvan Butygin cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect))); 61ab701975SMogball } 62ab701975SMogball 6315e915a4SIvan Butygin LogicalResult IntegerRangeAnalysis::visitOperation( 64ab701975SMogball Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands, 65ab701975SMogball ArrayRef<IntegerValueRangeLattice *> results) { 66ab701975SMogball auto inferrable = dyn_cast<InferIntRangeInterface>(op); 6715e915a4SIvan Butygin if (!inferrable) { 6815e915a4SIvan Butygin setAllToEntryStates(results); 6915e915a4SIvan Butygin return success(); 7015e915a4SIvan Butygin } 71ab701975SMogball 72ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); 736aeea700SSpenser Bauman auto argRanges = llvm::map_to_vector( 746aeea700SSpenser Bauman operands, [](const IntegerValueRangeLattice *lattice) { 756aeea700SSpenser Bauman return lattice->getValue(); 766aeea700SSpenser Bauman }); 77ab701975SMogball 786aeea700SSpenser Bauman auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { 795550c821STres Popp auto result = dyn_cast<OpResult>(v); 80ab701975SMogball if (!result) 81ab701975SMogball return; 82360c1111SKazu Hirata assert(llvm::is_contained(op->getResults(), result)); 83ab701975SMogball 84ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); 85ab701975SMogball IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; 8647bf3e38SZhixun Tan IntegerValueRange oldRange = lattice->getValue(); 87ab701975SMogball 886aeea700SSpenser Bauman ChangeResult changed = lattice->join(attrs); 89ab701975SMogball 90ab701975SMogball // Catch loop results with loop variant bounds and conservatively make 91ab701975SMogball // them [-inf, inf] so we don't circle around infinitely often (because 92ab701975SMogball // the dataflow analysis in MLIR doesn't attempt to work out trip counts 93ab701975SMogball // and often can't). 94ab701975SMogball bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { 95ab701975SMogball return op->hasTrait<OpTrait::IsTerminator>(); 96ab701975SMogball }); 9747bf3e38SZhixun Tan if (isYieldedResult && !oldRange.isUninitialized() && 9847bf3e38SZhixun Tan !(lattice->getValue() == oldRange)) { 99ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); 100de0ebc52SZhixun Tan changed |= lattice->join(IntegerValueRange::getMaxRange(v)); 101ab701975SMogball } 102ab701975SMogball propagateIfChanged(lattice, changed); 103ab701975SMogball }; 104ab701975SMogball 1056aeea700SSpenser Bauman inferrable.inferResultRangesFromOptional(argRanges, joinCallback); 10615e915a4SIvan Butygin return success(); 107ab701975SMogball } 108ab701975SMogball 109ab701975SMogball void IntegerRangeAnalysis::visitNonControlFlowArguments( 110ab701975SMogball Operation *op, const RegionSuccessor &successor, 111ab701975SMogball ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { 112ab701975SMogball if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { 113ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); 114ab701975SMogball 1156aeea700SSpenser Bauman auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { 1164b3f251bSdonald chen return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); 1176aeea700SSpenser Bauman }); 1186aeea700SSpenser Bauman 1196aeea700SSpenser Bauman auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { 1205550c821STres Popp auto arg = dyn_cast<BlockArgument>(v); 121ab701975SMogball if (!arg) 122ab701975SMogball return; 123360c1111SKazu Hirata if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) 124ab701975SMogball return; 125ab701975SMogball 126ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); 127ab701975SMogball IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; 12847bf3e38SZhixun Tan IntegerValueRange oldRange = lattice->getValue(); 129ab701975SMogball 1306aeea700SSpenser Bauman ChangeResult changed = lattice->join(attrs); 131ab701975SMogball 132ab701975SMogball // Catch loop results with loop variant bounds and conservatively make 133ab701975SMogball // them [-inf, inf] so we don't circle around infinitely often (because 134ab701975SMogball // the dataflow analysis in MLIR doesn't attempt to work out trip counts 135ab701975SMogball // and often can't). 136ab701975SMogball bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { 137ab701975SMogball return op->hasTrait<OpTrait::IsTerminator>(); 138ab701975SMogball }); 13947bf3e38SZhixun Tan if (isYieldedValue && !oldRange.isUninitialized() && 14047bf3e38SZhixun Tan !(lattice->getValue() == oldRange)) { 141ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); 142de0ebc52SZhixun Tan changed |= lattice->join(IntegerValueRange::getMaxRange(v)); 143ab701975SMogball } 144ab701975SMogball propagateIfChanged(lattice, changed); 145ab701975SMogball }; 146ab701975SMogball 1476aeea700SSpenser Bauman inferrable.inferResultRangesFromOptional(argRanges, joinCallback); 148ab701975SMogball return; 149ab701975SMogball } 150ab701975SMogball 151ab701975SMogball /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() 152ab701975SMogball /// on a LoopLikeInterface return the lower/upper bound for that result if 153ab701975SMogball /// possible. 15422426110SRamkumar Ramachandra auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound, 155ab701975SMogball Type boundType, bool getUpper) { 156ab701975SMogball unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); 157491d2701SKazu Hirata if (loopBound.has_value()) { 158*b5c5c2b2SKazu Hirata if (auto attr = dyn_cast<Attribute>(*loopBound)) { 159*b5c5c2b2SKazu Hirata if (auto bound = dyn_cast_or_null<IntegerAttr>(attr)) 160ab701975SMogball return bound.getValue(); 16168f58812STres Popp } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) { 162ab701975SMogball const IntegerValueRangeLattice *lattice = 1634b3f251bSdonald chen getLatticeElementFor(getProgramPointAfter(op), value); 16413c648f6SVictor Perez if (lattice != nullptr && !lattice->getValue().isUninitialized()) 165ab701975SMogball return getUpper ? lattice->getValue().getValue().smax() 166ab701975SMogball : lattice->getValue().getValue().smin(); 167ab701975SMogball } 168ab701975SMogball } 169ab701975SMogball // Given the results of getConstant{Lower,Upper}Bound() 170ab701975SMogball // or getConstantStep() on a LoopLikeInterface return the lower/upper 171ab701975SMogball // bound 172ab701975SMogball return getUpper ? APInt::getSignedMaxValue(width) 173ab701975SMogball : APInt::getSignedMinValue(width); 174ab701975SMogball }; 175ab701975SMogball 176ab701975SMogball // Infer bounds for loop arguments that have static bounds 177ab701975SMogball if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { 17822426110SRamkumar Ramachandra std::optional<Value> iv = loop.getSingleInductionVar(); 179ab701975SMogball if (!iv) { 180b2b7efb9SAlex Zinenko return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( 181ab701975SMogball op, successor, argLattices, firstIndex); 182ab701975SMogball } 18322426110SRamkumar Ramachandra std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); 18422426110SRamkumar Ramachandra std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); 18522426110SRamkumar Ramachandra std::optional<OpFoldResult> step = loop.getSingleStep(); 186ab701975SMogball APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), 187ab701975SMogball /*getUpper=*/false); 188ab701975SMogball APInt max = getLoopBoundFromFold(upperBound, iv->getType(), 189ab701975SMogball /*getUpper=*/true); 190ab701975SMogball // Assume positivity for uniscoverable steps by way of getUpper = true. 191ab701975SMogball APInt stepVal = 192ab701975SMogball getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true); 193ab701975SMogball 194ab701975SMogball if (stepVal.isNegative()) { 195ab701975SMogball std::swap(min, max); 196ab701975SMogball } else { 197ab701975SMogball // Correct the upper bound by subtracting 1 so that it becomes a <= 198ab701975SMogball // bound, because loops do not generally include their upper bound. 199ab701975SMogball max -= 1; 200ab701975SMogball } 201ab701975SMogball 202b78883fcSFelix Schneider // If we infer the lower bound to be larger than the upper bound, the 203b78883fcSFelix Schneider // resulting range is meaningless and should not be used in further 204b78883fcSFelix Schneider // inferences. 205b78883fcSFelix Schneider if (max.sge(min)) { 206ab701975SMogball IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); 207ab701975SMogball auto ivRange = ConstantIntRanges::fromSigned(min, max); 20847bf3e38SZhixun Tan propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); 209b78883fcSFelix Schneider } 210ab701975SMogball return; 211ab701975SMogball } 212ab701975SMogball 213b2b7efb9SAlex Zinenko return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( 214ab701975SMogball op, successor, argLattices, firstIndex); 215ab701975SMogball } 216