xref: /llvm-project/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h (revision 6aeea700df6f3f8db9e6a79be4aa593c6fcc7d18)
1 //===- InferIntRangeCommon.h - Inference for common ops ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares implementations of range inference for operations that are
10 // common to both the `arith` and `index` dialects to facilitate reuse.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
15 #define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
16 
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitmaskEnum.h"
#include <cstdint>
#include <functional>
#include <optional>
21 
22 namespace mlir {
23 namespace intrange {
24 /// Function that performs inference on an array of `ConstantIntRanges`,
25 /// abstracted away here to permit writing the function that handles both
26 /// 64- and 32-bit index types.
27 using InferRangeFn =
28     std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
29 
30 /// Function that performs inferrence on an array of `IntegerValueRange`.
31 using InferIntegerValueRangeFn =
32     std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>;
33 
34 static constexpr unsigned indexMinWidth = 32;
35 static constexpr unsigned indexMaxWidth = 64;
36 
37 enum class CmpMode : uint32_t { Both, Signed, Unsigned };
38 
39 enum class OverflowFlags : uint32_t {
40   None = 0,
41   Nsw = 1,
42   Nuw = 2,
43   LLVM_MARK_AS_BITMASK_ENUM(Nuw)
44 };
45 
46 /// Function that performs inference on an array of `ConstantIntRanges` while
47 /// taking special overflow behavior into account.
48 using InferRangeWithOvfFlagsFn =
49     function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
50 
51 /// Compute `inferFn` on `ranges`, whose size should be the index storage
52 /// bitwidth. Then, compute the function on `argRanges` again after truncating
53 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
54 /// equal to the 32-bit result, use it (to preserve compatibility with folders
55 /// and inference precision), and take the union of the results otherwise.
56 ///
57 /// The `mode` argument specifies if the unsigned, signed, or both results of
58 /// the inference computation should be used when comparing the results.
59 ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn,
60                                ArrayRef<ConstantIntRanges> argRanges,
61                                CmpMode mode);
62 
63 /// Independently zero-extend the unsigned values and sign-extend the signed
64 /// values in `range` to `destWidth` bits, returning the resulting range.
65 ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth);
66 
67 /// Use the unsigned values in `range` to zero-extend it to `destWidth`.
68 ConstantIntRanges extUIRange(const ConstantIntRanges &range,
69                              unsigned destWidth);
70 
71 /// Use the signed values in `range` to sign-extend it to `destWidth`.
72 ConstantIntRanges extSIRange(const ConstantIntRanges &range,
73                              unsigned destWidth);
74 
75 /// Truncate `range` to `destWidth` bits, taking care to handle cases such as
76 /// the truncation of [255, 256] to i8 not being a uniform range.
77 ConstantIntRanges truncRange(const ConstantIntRanges &range,
78                              unsigned destWidth);
79 
80 ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
81                            OverflowFlags ovfFlags = OverflowFlags::None);
82 
83 ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
84                            OverflowFlags ovfFlags = OverflowFlags::None);
85 
86 ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
87                            OverflowFlags ovfFlags = OverflowFlags::None);
88 
89 ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
90 
91 ConstantIntRanges inferDivU(ArrayRef<ConstantIntRanges> argRanges);
92 
93 ConstantIntRanges inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges);
94 
95 ConstantIntRanges inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges);
96 
97 ConstantIntRanges inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges);
98 
99 ConstantIntRanges inferRemS(ArrayRef<ConstantIntRanges> argRanges);
100 
101 ConstantIntRanges inferRemU(ArrayRef<ConstantIntRanges> argRanges);
102 
103 ConstantIntRanges inferMaxS(ArrayRef<ConstantIntRanges> argRanges);
104 
105 ConstantIntRanges inferMaxU(ArrayRef<ConstantIntRanges> argRanges);
106 
107 ConstantIntRanges inferMinS(ArrayRef<ConstantIntRanges> argRanges);
108 
109 ConstantIntRanges inferMinU(ArrayRef<ConstantIntRanges> argRanges);
110 
111 ConstantIntRanges inferAnd(ArrayRef<ConstantIntRanges> argRanges);
112 
113 ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
114 
115 ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
116 
117 ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
118                            OverflowFlags ovfFlags = OverflowFlags::None);
119 
120 ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
121 
122 ConstantIntRanges inferShrU(ArrayRef<ConstantIntRanges> argRanges);
123 
124 /// Copy of the enum from `arith` and `index` to allow the common integer range
125 /// infrastructure to not depend on either dialect.
126 enum class CmpPredicate : uint64_t {
127   eq,
128   ne,
129   slt,
130   sle,
131   sgt,
132   sge,
133   ult,
134   ule,
135   ugt,
136   uge,
137 };
138 
139 /// Returns a boolean value if `pred` is statically true or false for
140 /// anypossible inputs falling within `lhs` and `rhs`, and std::nullopt if the
141 /// value of the predicate cannot be determined.
142 std::optional<bool> evaluatePred(CmpPredicate pred,
143                                  const ConstantIntRanges &lhs,
144                                  const ConstantIntRanges &rhs);
145 
146 } // namespace intrange
147 } // namespace mlir
148 
149 #endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
150