xref: /llvm-project/mlir/lib/Interfaces/InferIntRangeInterface.cpp (revision 616aff126caaf93a0d9868d279e4c99d1e45fef0)
1 //===- InferIntRangeInterface.cpp -  Integer range inference interface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/InferIntRangeInterface.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
13 #include <optional>
14 
15 using namespace mlir;
16 
17 bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
18   return umin().getBitWidth() == other.umin().getBitWidth() &&
19          umin() == other.umin() && umax() == other.umax() &&
20          smin() == other.smin() && smax() == other.smax();
21 }
22 
23 const APInt &ConstantIntRanges::umin() const { return uminVal; }
24 
25 const APInt &ConstantIntRanges::umax() const { return umaxVal; }
26 
27 const APInt &ConstantIntRanges::smin() const { return sminVal; }
28 
29 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
30 
31 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
32   type = getElementTypeOrSelf(type);
33   if (type.isIndex())
34     return IndexType::kInternalStorageBitWidth;
35   if (auto integerType = dyn_cast<IntegerType>(type))
36     return integerType.getWidth();
37   // Non-integer types have their bounds stored in width 0 `APInt`s.
38   return 0;
39 }
40 
41 ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
42   return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
43 }
44 
45 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
46   return {value, value, value, value};
47 }
48 
49 ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
50                                            bool isSigned) {
51   if (isSigned)
52     return fromSigned(min, max);
53   return fromUnsigned(min, max);
54 }
55 
56 ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
57                                                 const APInt &smax) {
58   unsigned int width = smin.getBitWidth();
59   APInt umin, umax;
60   if (smin.isNonNegative() == smax.isNonNegative()) {
61     umin = smin.ult(smax) ? smin : smax;
62     umax = smin.ugt(smax) ? smin : smax;
63   } else {
64     umin = APInt::getMinValue(width);
65     umax = APInt::getMaxValue(width);
66   }
67   return {umin, umax, smin, smax};
68 }
69 
70 ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
71                                                   const APInt &umax) {
72   unsigned int width = umin.getBitWidth();
73   APInt smin, smax;
74   if (umin.isNonNegative() == umax.isNonNegative()) {
75     smin = umin.slt(umax) ? umin : umax;
76     smax = umin.sgt(umax) ? umin : umax;
77   } else {
78     smin = APInt::getSignedMinValue(width);
79     smax = APInt::getSignedMaxValue(width);
80   }
81   return {umin, umax, smin, smax};
82 }
83 
84 ConstantIntRanges
85 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
86   // "Not an integer" poisons everything and also cannot be fed to comparison
87   // operators.
88   if (umin().getBitWidth() == 0)
89     return *this;
90   if (other.umin().getBitWidth() == 0)
91     return other;
92 
93   const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
94   const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
95   const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
96   const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
97 
98   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
99 }
100 
101 ConstantIntRanges
102 ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
103   // "Not an integer" poisons everything and also cannot be fed to comparison
104   // operators.
105   if (umin().getBitWidth() == 0)
106     return *this;
107   if (other.umin().getBitWidth() == 0)
108     return other;
109 
110   const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
111   const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
112   const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
113   const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
114 
115   return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
116 }
117 
118 std::optional<APInt> ConstantIntRanges::getConstantValue() const {
119   // Note: we need to exclude the trivially-equal width 0 values here.
120   if (umin() == umax() && umin().getBitWidth() != 0)
121     return umin();
122   if (smin() == smax() && smin().getBitWidth() != 0)
123     return smin();
124   return std::nullopt;
125 }
126 
127 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
128   return os << "unsigned : [" << range.umin() << ", " << range.umax()
129             << "] signed : [" << range.smin() << ", " << range.smax() << "]";
130 }
131 
132 IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
133   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
134   if (width == 0)
135     return {};
136 
137   APInt umin = APInt::getMinValue(width);
138   APInt umax = APInt::getMaxValue(width);
139   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
140   APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
141   return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
142 }
143 
144 raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
145   range.print(os);
146   return os;
147 }
148 
149 void mlir::intrange::detail::defaultInferResultRanges(
150     InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
151     SetIntLatticeFn setResultRanges) {
152   llvm::SmallVector<ConstantIntRanges> unpacked;
153   unpacked.reserve(argRanges.size());
154 
155   for (const IntegerValueRange &range : argRanges) {
156     if (range.isUninitialized())
157       return;
158     unpacked.push_back(range.getValue());
159   }
160 
161   interface.inferResultRanges(
162       unpacked,
163       [&setResultRanges](Value value, const ConstantIntRanges &argRanges) {
164         setResultRanges(value, IntegerValueRange{argRanges});
165       });
166 }
167 
168 void mlir::intrange::detail::defaultInferResultRangesFromOptional(
169     InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
170     SetIntRangeFn setResultRanges) {
171   auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges);
172   interface.inferResultRangesFromOptional(
173       ranges,
174       [&setResultRanges](Value value, const IntegerValueRange &argRanges) {
175         if (!argRanges.isUninitialized())
176           setResultRanges(value, argRanges.getValue());
177       });
178 }
179