//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Index/IR/IndexOps.h"
10 #include "mlir/Interfaces/InferIntRangeInterface.h"
11 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
12
13 #include "llvm/Support/Debug.h"
14 #include <optional>
15
16 #define DEBUG_TYPE "int-range-analysis"
17
18 using namespace mlir;
19 using namespace mlir::index;
20 using namespace mlir::intrange;
21
22 //===----------------------------------------------------------------------===//
23 // Constants
24 //===----------------------------------------------------------------------===//
25
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)26 void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
27 SetIntRangeFn setResultRange) {
28 const APInt &value = getValue();
29 setResultRange(getResult(), ConstantIntRanges::constant(value));
30 }
31
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)32 void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
33 SetIntRangeFn setResultRange) {
34 bool value = getValue();
35 APInt asInt(/*numBits=*/1, value);
36 setResultRange(getResult(), ConstantIntRanges::constant(asInt));
37 }
38
39 //===----------------------------------------------------------------------===//
// Arithmetic operations. All of these operations will have their results
// inferred using both the 64-bit values and truncated 32-bit values of their
// inputs, with the results being the union of those inferences, except where
// the truncation of the 64-bit result is equal to the 32-bit result (at which
// time we take the 64-bit result).
45 //===----------------------------------------------------------------------===//
46
47 // Some arithmetic inference functions allow specifying special overflow / wrap
48 // behavior. We do not require this for the IndexOps and use this helper to call
49 // the inference function without any `OverflowFlags`.
50 static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn)51 inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
52 return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
53 return inferWithOvfFn(argRanges, OverflowFlags::None);
54 };
55 }
56
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)57 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
58 SetIntRangeFn setResultRange) {
59 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
60 argRanges, CmpMode::Both));
61 }
62
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)63 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
64 SetIntRangeFn setResultRange) {
65 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
66 argRanges, CmpMode::Both));
67 }
68
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)69 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
70 SetIntRangeFn setResultRange) {
71 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
72 argRanges, CmpMode::Both));
73 }
74
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)75 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76 SetIntRangeFn setResultRange) {
77 setResultRange(getResult(),
78 inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
79 }
80
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)81 void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
82 SetIntRangeFn setResultRange) {
83 setResultRange(getResult(),
84 inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
85 }
86
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)87 void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
88 SetIntRangeFn setResultRange) {
89 setResultRange(getResult(),
90 inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
91 }
92
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)93 void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
94 SetIntRangeFn setResultRange) {
95 setResultRange(getResult(),
96 inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
97 }
98
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)99 void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
100 SetIntRangeFn setResultRange) {
101 return setResultRange(
102 getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
103 }
104
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)105 void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
106 SetIntRangeFn setResultRange) {
107 setResultRange(getResult(),
108 inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
109 }
110
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)111 void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
112 SetIntRangeFn setResultRange) {
113 setResultRange(getResult(),
114 inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
115 }
116
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)117 void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
118 SetIntRangeFn setResultRange) {
119 setResultRange(getResult(),
120 inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
121 }
122
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)123 void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
124 SetIntRangeFn setResultRange) {
125 setResultRange(getResult(),
126 inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
127 }
128
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)129 void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
130 SetIntRangeFn setResultRange) {
131 setResultRange(getResult(),
132 inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
133 }
134
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)135 void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
136 SetIntRangeFn setResultRange) {
137 setResultRange(getResult(),
138 inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
139 }
140
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)141 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
142 SetIntRangeFn setResultRange) {
143 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
144 argRanges, CmpMode::Both));
145 }
146
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)147 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
148 SetIntRangeFn setResultRange) {
149 setResultRange(getResult(),
150 inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
151 }
152
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)153 void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
154 SetIntRangeFn setResultRange) {
155 setResultRange(getResult(),
156 inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
157 }
158
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)159 void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
160 SetIntRangeFn setResultRange) {
161 setResultRange(getResult(),
162 inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
163 }
164
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)165 void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
166 SetIntRangeFn setResultRange) {
167 setResultRange(getResult(),
168 inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
169 }
170
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)171 void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
172 SetIntRangeFn setResultRange) {
173 setResultRange(getResult(),
174 inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
175 }
176
177 //===----------------------------------------------------------------------===//
178 // Casts
179 //===----------------------------------------------------------------------===//
180
// Converts `range` from `srcWidth` bits to `destWidth` bits:
//  - equal widths: the range is returned unchanged;
//  - narrowing:    the range goes through `truncRange`;
//  - widening:     the range goes through `extSIRange` when `isSigned` is
//                  set, and through `extUIRange` otherwise.
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
                                      unsigned srcWidth, unsigned destWidth,
                                      bool isSigned) {
  if (srcWidth == destWidth)
    return range;
  if (srcWidth > destWidth)
    return truncRange(range, destWidth);
  // Only the widening case (srcWidth < destWidth) remains.
  if (isSigned)
    return extSIRange(range, destWidth);
  return extUIRange(range, destWidth);
}
191
192 // When casting to `index`, we will take the union of the possible fixed-width
193 // casts.
inferIndexCast(const ConstantIntRanges & range,Type sourceType,Type destType,bool isSigned)194 static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
195 Type sourceType, Type destType,
196 bool isSigned) {
197 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
198 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
199 if (sourceType.isIndex())
200 return makeLikeDest(range, srcWidth, destWidth, isSigned);
201 // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
202 ConstantIntRanges storageRange =
203 makeLikeDest(range, srcWidth, destWidth, isSigned);
204 ConstantIntRanges minWidthRange =
205 makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
206 ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
207 ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
208 return ret;
209 }
210
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)211 void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
212 SetIntRangeFn setResultRange) {
213 Type sourceType = getOperand().getType();
214 Type destType = getResult().getType();
215 setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
216 /*isSigned=*/true));
217 }
218
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)219 void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
220 SetIntRangeFn setResultRange) {
221 Type sourceType = getOperand().getType();
222 Type destType = getResult().getType();
223 setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
224 /*isSigned=*/false));
225 }
226
227 //===----------------------------------------------------------------------===//
228 // CmpOp
229 //===----------------------------------------------------------------------===//
230
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)231 void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
232 SetIntRangeFn setResultRange) {
233 index::IndexCmpPredicate indexPred = getPred();
234 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
235 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
236
237 APInt min = APInt::getZero(1);
238 APInt max = APInt::getAllOnes(1);
239
240 std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
241
242 ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
243 rhsTrunc = truncRange(rhs, indexMinWidth);
244 std::optional<bool> truthValue32 =
245 intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
246
247 if (truthValue64 == truthValue32) {
248 if (truthValue64.has_value() && *truthValue64)
249 min = max;
250 else if (truthValue64.has_value() && !(*truthValue64))
251 max = min;
252 }
253 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
254 }
255
256 //===----------------------------------------------------------------------===//
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
258 //===----------------------------------------------------------------------===//
259
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)260 void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
261 SetIntRangeFn setResultRange) {
262 unsigned storageWidth =
263 ConstantIntRanges::getStorageBitwidth(getResult().getType());
264 APInt min(/*numBits=*/storageWidth, indexMinWidth);
265 APInt max(/*numBits=*/storageWidth, indexMaxWidth);
266 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
267 }
268