xref: /llvm-project/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp (revision 2034f2fc8729bd4645ef7caa3c5c6efa284d2d3f)
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Index/IR/IndexOps.h"
10 #include "mlir/Interfaces/InferIntRangeInterface.h"
11 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
12 
13 #include "llvm/Support/Debug.h"
14 #include <optional>
15 
16 #define DEBUG_TYPE "int-range-analysis"
17 
18 using namespace mlir;
19 using namespace mlir::index;
20 using namespace mlir::intrange;
21 
22 //===----------------------------------------------------------------------===//
23 // Constants
24 //===----------------------------------------------------------------------===//
25 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)26 void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
27                                    SetIntRangeFn setResultRange) {
28   const APInt &value = getValue();
29   setResultRange(getResult(), ConstantIntRanges::constant(value));
30 }
31 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)32 void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
33                                        SetIntRangeFn setResultRange) {
34   bool value = getValue();
35   APInt asInt(/*numBits=*/1, value);
36   setResultRange(getResult(), ConstantIntRanges::constant(asInt));
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Arithmetic operations. All of these operations will have their results
41 // inferred using both the 64-bit values and truncated 32-bit values of their
42 // inputs, with the results being the union of those inferences, except where
43 // the truncation of the 64-bit result is equal to the 32-bit result (at which
44 // time we take the 64-bit result).
45 //===----------------------------------------------------------------------===//
46 
47 // Some arithmetic inference functions allow specifying special overflow / wrap
48 // behavior. We do not require this for the IndexOps and use this helper to call
49 // the inference function without any `OverflowFlags`.
50 static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn)51 inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
52   return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
53     return inferWithOvfFn(argRanges, OverflowFlags::None);
54   };
55 }
56 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)57 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
58                               SetIntRangeFn setResultRange) {
59   setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
60                                            argRanges, CmpMode::Both));
61 }
62 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)63 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
64                               SetIntRangeFn setResultRange) {
65   setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
66                                            argRanges, CmpMode::Both));
67 }
68 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)69 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
70                               SetIntRangeFn setResultRange) {
71   setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
72                                            argRanges, CmpMode::Both));
73 }
74 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)75 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76                                SetIntRangeFn setResultRange) {
77   setResultRange(getResult(),
78                  inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
79 }
80 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)81 void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
82                                SetIntRangeFn setResultRange) {
83   setResultRange(getResult(),
84                  inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
85 }
86 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)87 void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
88                                    SetIntRangeFn setResultRange) {
89   setResultRange(getResult(),
90                  inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
91 }
92 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)93 void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
94                                    SetIntRangeFn setResultRange) {
95   setResultRange(getResult(),
96                  inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
97 }
98 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)99 void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
100                                     SetIntRangeFn setResultRange) {
101   return setResultRange(
102       getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
103 }
104 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)105 void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
106                                SetIntRangeFn setResultRange) {
107   setResultRange(getResult(),
108                  inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
109 }
110 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)111 void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
112                                SetIntRangeFn setResultRange) {
113   setResultRange(getResult(),
114                  inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
115 }
116 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)117 void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
118                                SetIntRangeFn setResultRange) {
119   setResultRange(getResult(),
120                  inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
121 }
122 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)123 void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
124                                SetIntRangeFn setResultRange) {
125   setResultRange(getResult(),
126                  inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
127 }
128 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)129 void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
130                                SetIntRangeFn setResultRange) {
131   setResultRange(getResult(),
132                  inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
133 }
134 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)135 void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
136                                SetIntRangeFn setResultRange) {
137   setResultRange(getResult(),
138                  inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
139 }
140 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)141 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
142                               SetIntRangeFn setResultRange) {
143   setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
144                                            argRanges, CmpMode::Both));
145 }
146 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)147 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
148                                SetIntRangeFn setResultRange) {
149   setResultRange(getResult(),
150                  inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
151 }
152 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)153 void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
154                                SetIntRangeFn setResultRange) {
155   setResultRange(getResult(),
156                  inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
157 }
158 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)159 void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
160                               SetIntRangeFn setResultRange) {
161   setResultRange(getResult(),
162                  inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
163 }
164 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)165 void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
166                              SetIntRangeFn setResultRange) {
167   setResultRange(getResult(),
168                  inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
169 }
170 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)171 void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
172                               SetIntRangeFn setResultRange) {
173   setResultRange(getResult(),
174                  inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // Casts
179 //===----------------------------------------------------------------------===//
180 
// Resizes `range` from `srcWidth` to `destWidth` bits.
// - Equal widths: the range is returned unchanged.
// - Narrowing: the range is truncated.
// - Widening: it is sign-extended when `isSigned` is true, zero-extended
//   otherwise.
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
                                      unsigned srcWidth, unsigned destWidth,
                                      bool isSigned) {
  if (srcWidth == destWidth)
    return range;
  if (srcWidth > destWidth)
    return truncRange(range, destWidth);
  if (isSigned)
    return extSIRange(range, destWidth);
  return extUIRange(range, destWidth);
}
191 
192 // When casting to `index`, we will take the union of the possible fixed-width
193 // casts.
inferIndexCast(const ConstantIntRanges & range,Type sourceType,Type destType,bool isSigned)194 static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
195                                         Type sourceType, Type destType,
196                                         bool isSigned) {
197   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
198   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
199   if (sourceType.isIndex())
200     return makeLikeDest(range, srcWidth, destWidth, isSigned);
201   // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
202   ConstantIntRanges storageRange =
203       makeLikeDest(range, srcWidth, destWidth, isSigned);
204   ConstantIntRanges minWidthRange =
205       makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
206   ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
207   ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
208   return ret;
209 }
210 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)211 void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
212                                 SetIntRangeFn setResultRange) {
213   Type sourceType = getOperand().getType();
214   Type destType = getResult().getType();
215   setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
216                                              /*isSigned=*/true));
217 }
218 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)219 void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
220                                 SetIntRangeFn setResultRange) {
221   Type sourceType = getOperand().getType();
222   Type destType = getResult().getType();
223   setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
224                                              /*isSigned=*/false));
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // CmpOp
229 //===----------------------------------------------------------------------===//
230 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)231 void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
232                               SetIntRangeFn setResultRange) {
233   index::IndexCmpPredicate indexPred = getPred();
234   intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
235   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
236 
237   APInt min = APInt::getZero(1);
238   APInt max = APInt::getAllOnes(1);
239 
240   std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
241 
242   ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
243                     rhsTrunc = truncRange(rhs, indexMinWidth);
244   std::optional<bool> truthValue32 =
245       intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
246 
247   if (truthValue64 == truthValue32) {
248     if (truthValue64.has_value() && *truthValue64)
249       min = max;
250     else if (truthValue64.has_value() && !(*truthValue64))
251       max = min;
252   }
253   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // SizeOf, which is bounded between the two supported bitwidths (32 and 64).
258 //===----------------------------------------------------------------------===//
259 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)260 void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
261                                  SetIntRangeFn setResultRange) {
262   unsigned storageWidth =
263       ConstantIntRanges::getStorageBitwidth(getResult().getType());
264   APInt min(/*numBits=*/storageWidth, indexMinWidth);
265   APInt max(/*numBits=*/storageWidth, indexMaxWidth);
266   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
267 }
268