xref: /llvm-project/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp (revision 441b82b20bf3a622155354e17ae66e0ccff50796)
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Interfaces/InferIntRangeInterface.h"
11 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
12 
13 #include "llvm/Support/Debug.h"
14 #include <optional>
15 
16 #define DEBUG_TYPE "int-range-analysis"
17 
18 using namespace mlir;
19 using namespace mlir::arith;
20 using namespace mlir::intrange;
21 
22 static intrange::OverflowFlags
23 convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
24   intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
25   if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
26     retFlags |= intrange::OverflowFlags::Nsw;
27   if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
28     retFlags |= intrange::OverflowFlags::Nuw;
29   return retFlags;
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // ConstantOp
34 //===----------------------------------------------------------------------===//
35 
36 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
37                                           SetIntRangeFn setResultRange) {
38   if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
39     const APInt &value = scalarCstAttr.getValue();
40     setResultRange(getResult(), ConstantIntRanges::constant(value));
41     return;
42   }
43   if (auto arrayCstAttr =
44           llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
45     if (arrayCstAttr.isSplat()) {
46       setResultRange(getResult(), ConstantIntRanges::constant(
47                                       arrayCstAttr.getSplatValue<APInt>()));
48       return;
49     }
50 
51     std::optional<ConstantIntRanges> result;
52     for (const APInt &val : arrayCstAttr) {
53       auto range = ConstantIntRanges::constant(val);
54       result = (result ? result->rangeUnion(range) : range);
55     }
56 
57     assert(result && "Zero-sized vectors are not allowed");
58     setResultRange(getResult(), *result);
59     return;
60   }
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // AddIOp
65 //===----------------------------------------------------------------------===//
66 
67 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
68                                       SetIntRangeFn setResultRange) {
69   setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
70                                                       getOverflowFlags())));
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // SubIOp
75 //===----------------------------------------------------------------------===//
76 
77 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
78                                       SetIntRangeFn setResultRange) {
79   setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
80                                                       getOverflowFlags())));
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // MulIOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
88                                       SetIntRangeFn setResultRange) {
89   setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
90                                                       getOverflowFlags())));
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // DivUIOp
95 //===----------------------------------------------------------------------===//
96 
97 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
98                                        SetIntRangeFn setResultRange) {
99   setResultRange(getResult(), inferDivU(argRanges));
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // DivSIOp
104 //===----------------------------------------------------------------------===//
105 
106 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
107                                        SetIntRangeFn setResultRange) {
108   setResultRange(getResult(), inferDivS(argRanges));
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // CeilDivUIOp
113 //===----------------------------------------------------------------------===//
114 
115 void arith::CeilDivUIOp::inferResultRanges(
116     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
117   setResultRange(getResult(), inferCeilDivU(argRanges));
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // CeilDivSIOp
122 //===----------------------------------------------------------------------===//
123 
124 void arith::CeilDivSIOp::inferResultRanges(
125     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
126   setResultRange(getResult(), inferCeilDivS(argRanges));
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // FloorDivSIOp
131 //===----------------------------------------------------------------------===//
132 
133 void arith::FloorDivSIOp::inferResultRanges(
134     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
135   return setResultRange(getResult(), inferFloorDivS(argRanges));
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // RemUIOp
140 //===----------------------------------------------------------------------===//
141 
142 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
143                                        SetIntRangeFn setResultRange) {
144   setResultRange(getResult(), inferRemU(argRanges));
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // RemSIOp
149 //===----------------------------------------------------------------------===//
150 
151 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
152                                        SetIntRangeFn setResultRange) {
153   setResultRange(getResult(), inferRemS(argRanges));
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // AndIOp
158 //===----------------------------------------------------------------------===//
159 
160 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
161                                       SetIntRangeFn setResultRange) {
162   setResultRange(getResult(), inferAnd(argRanges));
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // OrIOp
167 //===----------------------------------------------------------------------===//
168 
169 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
170                                      SetIntRangeFn setResultRange) {
171   setResultRange(getResult(), inferOr(argRanges));
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // XOrIOp
176 //===----------------------------------------------------------------------===//
177 
178 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
179                                       SetIntRangeFn setResultRange) {
180   setResultRange(getResult(), inferXor(argRanges));
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // MaxSIOp
185 //===----------------------------------------------------------------------===//
186 
187 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
188                                        SetIntRangeFn setResultRange) {
189   setResultRange(getResult(), inferMaxS(argRanges));
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // MaxUIOp
194 //===----------------------------------------------------------------------===//
195 
196 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
197                                        SetIntRangeFn setResultRange) {
198   setResultRange(getResult(), inferMaxU(argRanges));
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // MinSIOp
203 //===----------------------------------------------------------------------===//
204 
205 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
206                                        SetIntRangeFn setResultRange) {
207   setResultRange(getResult(), inferMinS(argRanges));
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // MinUIOp
212 //===----------------------------------------------------------------------===//
213 
214 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
215                                        SetIntRangeFn setResultRange) {
216   setResultRange(getResult(), inferMinU(argRanges));
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // ExtUIOp
221 //===----------------------------------------------------------------------===//
222 
223 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
224                                        SetIntRangeFn setResultRange) {
225   unsigned destWidth =
226       ConstantIntRanges::getStorageBitwidth(getResult().getType());
227   setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // ExtSIOp
232 //===----------------------------------------------------------------------===//
233 
234 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
235                                        SetIntRangeFn setResultRange) {
236   unsigned destWidth =
237       ConstantIntRanges::getStorageBitwidth(getResult().getType());
238   setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // TruncIOp
243 //===----------------------------------------------------------------------===//
244 
245 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
246                                         SetIntRangeFn setResultRange) {
247   unsigned destWidth =
248       ConstantIntRanges::getStorageBitwidth(getResult().getType());
249   setResultRange(getResult(), truncRange(argRanges[0], destWidth));
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // IndexCastOp
254 //===----------------------------------------------------------------------===//
255 
256 void arith::IndexCastOp::inferResultRanges(
257     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
258   Type sourceType = getOperand().getType();
259   Type destType = getResult().getType();
260   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
261   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
262 
263   if (srcWidth < destWidth)
264     setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
265   else if (srcWidth > destWidth)
266     setResultRange(getResult(), truncRange(argRanges[0], destWidth));
267   else
268     setResultRange(getResult(), argRanges[0]);
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // IndexCastUIOp
273 //===----------------------------------------------------------------------===//
274 
275 void arith::IndexCastUIOp::inferResultRanges(
276     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
277   Type sourceType = getOperand().getType();
278   Type destType = getResult().getType();
279   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
280   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
281 
282   if (srcWidth < destWidth)
283     setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
284   else if (srcWidth > destWidth)
285     setResultRange(getResult(), truncRange(argRanges[0], destWidth));
286   else
287     setResultRange(getResult(), argRanges[0]);
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // CmpIOp
292 //===----------------------------------------------------------------------===//
293 
294 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
295                                       SetIntRangeFn setResultRange) {
296   arith::CmpIPredicate arithPred = getPredicate();
297   intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
298   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
299 
300   APInt min = APInt::getZero(1);
301   APInt max = APInt::getAllOnes(1);
302 
303   std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
304   if (truthValue.has_value() && *truthValue)
305     min = max;
306   else if (truthValue.has_value() && !(*truthValue))
307     max = min;
308 
309   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // SelectOp
314 //===----------------------------------------------------------------------===//
315 
316 void arith::SelectOp::inferResultRangesFromOptional(
317     ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
318   std::optional<APInt> mbCondVal =
319       argRanges[0].isUninitialized()
320           ? std::nullopt
321           : argRanges[0].getValue().getConstantValue();
322 
323   const IntegerValueRange &trueCase = argRanges[1];
324   const IntegerValueRange &falseCase = argRanges[2];
325 
326   if (mbCondVal) {
327     if (mbCondVal->isZero())
328       setResultRange(getResult(), falseCase);
329     else
330       setResultRange(getResult(), trueCase);
331     return;
332   }
333   setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // ShLIOp
338 //===----------------------------------------------------------------------===//
339 
340 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
341                                       SetIntRangeFn setResultRange) {
342   setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
343                                                       getOverflowFlags())));
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // ShRUIOp
348 //===----------------------------------------------------------------------===//
349 
350 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
351                                        SetIntRangeFn setResultRange) {
352   setResultRange(getResult(), inferShrU(argRanges));
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // ShRSIOp
357 //===----------------------------------------------------------------------===//
358 
359 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
360                                        SetIntRangeFn setResultRange) {
361   setResultRange(getResult(), inferShrS(argRanges));
362 }
363