1 //===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains definitions of the integer range inference interface 10 // defined in `InferIntRange.td` 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H 15 #define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H 16 17 #include "mlir/IR/OpDefinition.h" 18 #include <optional> 19 20 namespace mlir { 21 /// A set of arbitrary-precision integers representing bounds on a given integer 22 /// value. These bounds are inclusive on both ends, so 23 /// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for 24 /// the unsigned and signed interpretations of values in order to enable more 25 /// precice inference of the interplay between operations with signed and 26 /// unsigned semantics. 27 class ConstantIntRanges { 28 public: 29 /// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax. 30 /// Non-integer values should be bounded by APInts of bitwidth 0. ConstantIntRanges(const APInt & umin,const APInt & umax,const APInt & smin,const APInt & smax)31 ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin, 32 const APInt &smax) 33 : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) { 34 assert(uminVal.getBitWidth() == umaxVal.getBitWidth() && 35 umaxVal.getBitWidth() == sminVal.getBitWidth() && 36 sminVal.getBitWidth() == smaxVal.getBitWidth() && 37 "All bounds in the ranges must have the same bitwidth"); 38 } 39 40 bool operator==(const ConstantIntRanges &other) const; 41 42 /// The minimum value of an integer when it is interpreted as unsigned. 43 const APInt &umin() const; 44 45 /// The maximum value of an integer when it is interpreted as unsigned. 46 const APInt &umax() const; 47 48 /// The minimum value of an integer when it is interpreted as signed. 49 const APInt &smin() const; 50 51 /// The maximum value of an integer when it is interpreted as signed. 52 const APInt &smax() const; 53 54 /// Return the bitwidth that should be used for integer ranges describing 55 /// `type`. For concrete integer types, this is their bitwidth, for `index`, 56 /// this is the internal storage bitwidth of `index` attributes, and for 57 /// non-integer types this is 0. 58 static unsigned getStorageBitwidth(Type type); 59 60 /// Create a `ConstantIntRanges` with the maximum bounds for the width 61 /// `bitwidth`, that is - [0, uint_max(width)]/[sint_min(width), 62 /// sint_max(width)]. 63 static ConstantIntRanges maxRange(unsigned bitwidth); 64 65 /// Create a `ConstantIntRanges` with a constant value - that is, with the 66 /// bounds [value, value] for both its signed interpretations. 67 static ConstantIntRanges constant(const APInt &value); 68 69 /// Create a `ConstantIntRanges` whose minimum is `min` and maximum is `max` 70 /// with `isSigned` specifying if the min and max should be interpreted as 71 /// signed or unsigned. 72 static ConstantIntRanges range(const APInt &min, const APInt &max, 73 bool isSigned); 74 75 /// Create an `ConstantIntRanges` with the signed minimum and maximum equal 76 /// to `smin` and `smax`, where the unsigned bounds are constructed from the 77 /// signed ones if they correspond to a contigious range of bit patterns when 78 /// viewed as unsigned values and are left at [0, int_max()] otherwise. 79 static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax); 80 81 /// Create an `ConstantIntRanges` with the unsigned minimum and maximum equal 82 /// to `umin` and `umax` and the signed part equal to `umin` and `umax` 83 /// unless the sign bit changes between the minimum and maximum. 84 static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax); 85 86 /// Returns the union (computed separately for signed and unsigned bounds) 87 /// of this range and `other`. 88 ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const; 89 90 /// Returns the intersection (computed separately for signed and unsigned 91 /// bounds) of this range and `other`. 92 ConstantIntRanges intersection(const ConstantIntRanges &other) const; 93 94 /// If either the signed or unsigned interpretations of the range 95 /// indicate that the value it bounds is a constant, return that constant 96 /// value. 97 std::optional<APInt> getConstantValue() const; 98 99 friend raw_ostream &operator<<(raw_ostream &os, 100 const ConstantIntRanges &range); 101 102 private: 103 APInt uminVal, umaxVal, sminVal, smaxVal; 104 }; 105 106 raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &); 107 108 /// This lattice value represents the integer range of an SSA value. 109 class IntegerValueRange { 110 public: 111 /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) 112 /// range that is used to mark the value as unable to be analyzed further, 113 /// where `t` is the type of `value`. 114 static IntegerValueRange getMaxRange(Value value); 115 116 /// Create an integer value range lattice value. IntegerValueRange(ConstantIntRanges value)117 IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {} 118 119 /// Create an integer value range lattice value. 120 IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt) value(std::move (value))121 : value(std::move(value)) {} 122 123 /// Whether the range is uninitialized. This happens when the state hasn't 124 /// been set during the analysis. isUninitialized()125 bool isUninitialized() const { return !value.has_value(); } 126 127 /// Get the known integer value range. getValue()128 const ConstantIntRanges &getValue() const { 129 assert(!isUninitialized()); 130 return *value; 131 } 132 133 /// Compare two ranges. 134 bool operator==(const IntegerValueRange &rhs) const { 135 return value == rhs.value; 136 } 137 138 /// Compute the least upper bound of two ranges. join(const IntegerValueRange & lhs,const IntegerValueRange & rhs)139 static IntegerValueRange join(const IntegerValueRange &lhs, 140 const IntegerValueRange &rhs) { 141 if (lhs.isUninitialized()) 142 return rhs; 143 if (rhs.isUninitialized()) 144 return lhs; 145 return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())}; 146 } 147 148 /// Print the integer value range. print(raw_ostream & os)149 void print(raw_ostream &os) const { os << value; } 150 151 private: 152 /// The known integer value range. 153 std::optional<ConstantIntRanges> value; 154 }; 155 156 raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &); 157 158 /// The type of the `setResultRanges` callback provided to ops implementing 159 /// InferIntRangeInterface. It should be called once for each integer result 160 /// value and be passed the ConstantIntRanges corresponding to that value. 161 using SetIntRangeFn = 162 llvm::function_ref<void(Value, const ConstantIntRanges &)>; 163 164 /// Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values. 165 /// This is the `setResultRanges` callback for the IntegerValueRange based 166 /// interface method. 167 using SetIntLatticeFn = 168 llvm::function_ref<void(Value, const IntegerValueRange &)>; 169 170 class InferIntRangeInterface; 171 172 namespace intrange::detail { 173 /// Default implementation of `inferResultRanges` which dispatches to the 174 /// `inferResultRangesFromOptional`. 175 void defaultInferResultRanges(InferIntRangeInterface interface, 176 ArrayRef<IntegerValueRange> argRanges, 177 SetIntLatticeFn setResultRanges); 178 179 /// Default implementation of `inferResultRangesFromOptional` which dispatches 180 /// to the `inferResultRanges`. 181 void defaultInferResultRangesFromOptional(InferIntRangeInterface interface, 182 ArrayRef<ConstantIntRanges> argRanges, 183 SetIntRangeFn setResultRanges); 184 } // end namespace intrange::detail 185 } // end namespace mlir 186 187 #include "mlir/Interfaces/InferIntRangeInterface.h.inc" 188 189 #endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H 190