1 //===- InferIntRangeInterface.cpp - Integer range inference interface ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/InferIntRangeInterface.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "mlir/IR/TypeUtilities.h" 12 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc" 13 #include <optional> 14 15 using namespace mlir; 16 17 bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const { 18 return umin().getBitWidth() == other.umin().getBitWidth() && 19 umin() == other.umin() && umax() == other.umax() && 20 smin() == other.smin() && smax() == other.smax(); 21 } 22 23 const APInt &ConstantIntRanges::umin() const { return uminVal; } 24 25 const APInt &ConstantIntRanges::umax() const { return umaxVal; } 26 27 const APInt &ConstantIntRanges::smin() const { return sminVal; } 28 29 const APInt &ConstantIntRanges::smax() const { return smaxVal; } 30 31 unsigned ConstantIntRanges::getStorageBitwidth(Type type) { 32 type = getElementTypeOrSelf(type); 33 if (type.isIndex()) 34 return IndexType::kInternalStorageBitWidth; 35 if (auto integerType = dyn_cast<IntegerType>(type)) 36 return integerType.getWidth(); 37 // Non-integer types have their bounds stored in width 0 `APInt`s. 38 return 0; 39 } 40 41 ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) { 42 return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth)); 43 } 44 45 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) { 46 return {value, value, value, value}; 47 } 48 49 ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max, 50 bool isSigned) { 51 if (isSigned) 52 return fromSigned(min, max); 53 return fromUnsigned(min, max); 54 } 55 56 ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin, 57 const APInt &smax) { 58 unsigned int width = smin.getBitWidth(); 59 APInt umin, umax; 60 if (smin.isNonNegative() == smax.isNonNegative()) { 61 umin = smin.ult(smax) ? smin : smax; 62 umax = smin.ugt(smax) ? smin : smax; 63 } else { 64 umin = APInt::getMinValue(width); 65 umax = APInt::getMaxValue(width); 66 } 67 return {umin, umax, smin, smax}; 68 } 69 70 ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin, 71 const APInt &umax) { 72 unsigned int width = umin.getBitWidth(); 73 APInt smin, smax; 74 if (umin.isNonNegative() == umax.isNonNegative()) { 75 smin = umin.slt(umax) ? umin : umax; 76 smax = umin.sgt(umax) ? umin : umax; 77 } else { 78 smin = APInt::getSignedMinValue(width); 79 smax = APInt::getSignedMaxValue(width); 80 } 81 return {umin, umax, smin, smax}; 82 } 83 84 ConstantIntRanges 85 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const { 86 // "Not an integer" poisons everything and also cannot be fed to comparison 87 // operators. 88 if (umin().getBitWidth() == 0) 89 return *this; 90 if (other.umin().getBitWidth() == 0) 91 return other; 92 93 const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin(); 94 const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax(); 95 const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin(); 96 const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax(); 97 98 return {uminUnion, umaxUnion, sminUnion, smaxUnion}; 99 } 100 101 ConstantIntRanges 102 ConstantIntRanges::intersection(const ConstantIntRanges &other) const { 103 // "Not an integer" poisons everything and also cannot be fed to comparison 104 // operators. 105 if (umin().getBitWidth() == 0) 106 return *this; 107 if (other.umin().getBitWidth() == 0) 108 return other; 109 110 const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin(); 111 const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax(); 112 const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin(); 113 const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax(); 114 115 return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect}; 116 } 117 118 std::optional<APInt> ConstantIntRanges::getConstantValue() const { 119 // Note: we need to exclude the trivially-equal width 0 values here. 120 if (umin() == umax() && umin().getBitWidth() != 0) 121 return umin(); 122 if (smin() == smax() && smin().getBitWidth() != 0) 123 return smin(); 124 return std::nullopt; 125 } 126 127 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) { 128 return os << "unsigned : [" << range.umin() << ", " << range.umax() 129 << "] signed : [" << range.smin() << ", " << range.smax() << "]"; 130 } 131 132 IntegerValueRange IntegerValueRange::getMaxRange(Value value) { 133 unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType()); 134 if (width == 0) 135 return {}; 136 137 APInt umin = APInt::getMinValue(width); 138 APInt umax = APInt::getMaxValue(width); 139 APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin; 140 APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax; 141 return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}}; 142 } 143 144 raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) { 145 range.print(os); 146 return os; 147 } 148 149 void mlir::intrange::detail::defaultInferResultRanges( 150 InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges, 151 SetIntLatticeFn setResultRanges) { 152 llvm::SmallVector<ConstantIntRanges> unpacked; 153 unpacked.reserve(argRanges.size()); 154 155 for (const IntegerValueRange &range : argRanges) { 156 if (range.isUninitialized()) 157 return; 158 unpacked.push_back(range.getValue()); 159 } 160 161 interface.inferResultRanges( 162 unpacked, 163 [&setResultRanges](Value value, const ConstantIntRanges &argRanges) { 164 setResultRanges(value, IntegerValueRange{argRanges}); 165 }); 166 } 167 168 void mlir::intrange::detail::defaultInferResultRangesFromOptional( 169 InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges, 170 SetIntRangeFn setResultRanges) { 171 auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges); 172 interface.inferResultRangesFromOptional( 173 ranges, 174 [&setResultRanges](Value value, const IntegerValueRange &argRanges) { 175 if (!argRanges.isUninitialized()) 176 setResultRanges(value, argRanges.getValue()); 177 }); 178 } 179