1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the dataflow analysis class for integer range inference 10 // which is used in transformations over the `arith` dialect such as 11 // branch elimination or signed->unsigned rewriting 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 17 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 18 #include "mlir/Analysis/DataFlowFramework.h" 19 #include "mlir/IR/BuiltinAttributes.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/OpDefinition.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/IR/Value.h" 24 #include "mlir/Interfaces/ControlFlowInterfaces.h" 25 #include "mlir/Interfaces/InferIntRangeInterface.h" 26 #include "mlir/Interfaces/LoopLikeInterface.h" 27 #include "mlir/Support/LLVM.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/Support/Casting.h" 30 #include "llvm/Support/Debug.h" 31 #include <cassert> 32 #include <optional> 33 #include <utility> 34 35 #define DEBUG_TYPE "int-range-analysis" 36 37 using namespace mlir; 38 using namespace mlir::dataflow; 39 40 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { 41 Lattice::onUpdate(solver); 42 43 // If the integer range can be narrowed to a constant, update the constant 44 // value of the SSA value. 45 std::optional<APInt> constant = getValue().getValue().getConstantValue(); 46 auto value = cast<Value>(anchor); 47 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value); 48 if (!constant) 49 return solver->propagateIfChanged( 50 cv, cv->join(ConstantValue::getUnknownConstant())); 51 52 Dialect *dialect; 53 if (auto *parent = value.getDefiningOp()) 54 dialect = parent->getDialect(); 55 else 56 dialect = value.getParentBlock()->getParentOp()->getDialect(); 57 58 Type type = getElementTypeOrSelf(value); 59 solver->propagateIfChanged( 60 cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect))); 61 } 62 63 LogicalResult IntegerRangeAnalysis::visitOperation( 64 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands, 65 ArrayRef<IntegerValueRangeLattice *> results) { 66 auto inferrable = dyn_cast<InferIntRangeInterface>(op); 67 if (!inferrable) { 68 setAllToEntryStates(results); 69 return success(); 70 } 71 72 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); 73 auto argRanges = llvm::map_to_vector( 74 operands, [](const IntegerValueRangeLattice *lattice) { 75 return lattice->getValue(); 76 }); 77 78 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { 79 auto result = dyn_cast<OpResult>(v); 80 if (!result) 81 return; 82 assert(llvm::is_contained(op->getResults(), result)); 83 84 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); 85 IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; 86 IntegerValueRange oldRange = lattice->getValue(); 87 88 ChangeResult changed = lattice->join(attrs); 89 90 // Catch loop results with loop variant bounds and conservatively make 91 // them [-inf, inf] so we don't circle around infinitely often (because 92 // the dataflow analysis in MLIR doesn't attempt to work out trip counts 93 // and often can't). 94 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { 95 return op->hasTrait<OpTrait::IsTerminator>(); 96 }); 97 if (isYieldedResult && !oldRange.isUninitialized() && 98 !(lattice->getValue() == oldRange)) { 99 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); 100 changed |= lattice->join(IntegerValueRange::getMaxRange(v)); 101 } 102 propagateIfChanged(lattice, changed); 103 }; 104 105 inferrable.inferResultRangesFromOptional(argRanges, joinCallback); 106 return success(); 107 } 108 109 void IntegerRangeAnalysis::visitNonControlFlowArguments( 110 Operation *op, const RegionSuccessor &successor, 111 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { 112 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { 113 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); 114 115 auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { 116 return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); 117 }); 118 119 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { 120 auto arg = dyn_cast<BlockArgument>(v); 121 if (!arg) 122 return; 123 if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) 124 return; 125 126 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); 127 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; 128 IntegerValueRange oldRange = lattice->getValue(); 129 130 ChangeResult changed = lattice->join(attrs); 131 132 // Catch loop results with loop variant bounds and conservatively make 133 // them [-inf, inf] so we don't circle around infinitely often (because 134 // the dataflow analysis in MLIR doesn't attempt to work out trip counts 135 // and often can't). 136 bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { 137 return op->hasTrait<OpTrait::IsTerminator>(); 138 }); 139 if (isYieldedValue && !oldRange.isUninitialized() && 140 !(lattice->getValue() == oldRange)) { 141 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); 142 changed |= lattice->join(IntegerValueRange::getMaxRange(v)); 143 } 144 propagateIfChanged(lattice, changed); 145 }; 146 147 inferrable.inferResultRangesFromOptional(argRanges, joinCallback); 148 return; 149 } 150 151 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() 152 /// on a LoopLikeInterface return the lower/upper bound for that result if 153 /// possible. 154 auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound, 155 Type boundType, bool getUpper) { 156 unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); 157 if (loopBound.has_value()) { 158 if (auto attr = dyn_cast<Attribute>(*loopBound)) { 159 if (auto bound = dyn_cast_or_null<IntegerAttr>(attr)) 160 return bound.getValue(); 161 } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) { 162 const IntegerValueRangeLattice *lattice = 163 getLatticeElementFor(getProgramPointAfter(op), value); 164 if (lattice != nullptr && !lattice->getValue().isUninitialized()) 165 return getUpper ? lattice->getValue().getValue().smax() 166 : lattice->getValue().getValue().smin(); 167 } 168 } 169 // Given the results of getConstant{Lower,Upper}Bound() 170 // or getConstantStep() on a LoopLikeInterface return the lower/upper 171 // bound 172 return getUpper ? APInt::getSignedMaxValue(width) 173 : APInt::getSignedMinValue(width); 174 }; 175 176 // Infer bounds for loop arguments that have static bounds 177 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { 178 std::optional<Value> iv = loop.getSingleInductionVar(); 179 if (!iv) { 180 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( 181 op, successor, argLattices, firstIndex); 182 } 183 std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); 184 std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); 185 std::optional<OpFoldResult> step = loop.getSingleStep(); 186 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), 187 /*getUpper=*/false); 188 APInt max = getLoopBoundFromFold(upperBound, iv->getType(), 189 /*getUpper=*/true); 190 // Assume positivity for uniscoverable steps by way of getUpper = true. 191 APInt stepVal = 192 getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true); 193 194 if (stepVal.isNegative()) { 195 std::swap(min, max); 196 } else { 197 // Correct the upper bound by subtracting 1 so that it becomes a <= 198 // bound, because loops do not generally include their upper bound. 199 max -= 1; 200 } 201 202 // If we infer the lower bound to be larger than the upper bound, the 203 // resulting range is meaningless and should not be used in further 204 // inferences. 205 if (max.sge(min)) { 206 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); 207 auto ivRange = ConstantIntRanges::fromSigned(min, max); 208 propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); 209 } 210 return; 211 } 212 213 return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( 214 op, successor, argLattices, firstIndex); 215 } 216