1 //===- InferIntRangeCommon.cpp - Inference for common ops --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares implementations of range inference for operations that are 10 // common to both the `arith` and `index` dialects to facilitate reuse. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H 15 #define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H 16 17 #include "mlir/Interfaces/InferIntRangeInterface.h" 18 #include "llvm/ADT/ArrayRef.h" 19 #include "llvm/ADT/BitmaskEnum.h" 20 #include <optional> 21 22 namespace mlir { 23 namespace intrange { 24 /// Function that performs inference on an array of `ConstantIntRanges`, 25 /// abstracted away here to permit writing the function that handles both 26 /// 64- and 32-bit index types. 27 using InferRangeFn = 28 std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>; 29 30 /// Function that performs inferrence on an array of `IntegerValueRange`. 31 using InferIntegerValueRangeFn = 32 std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>; 33 34 static constexpr unsigned indexMinWidth = 32; 35 static constexpr unsigned indexMaxWidth = 64; 36 37 enum class CmpMode : uint32_t { Both, Signed, Unsigned }; 38 39 enum class OverflowFlags : uint32_t { 40 None = 0, 41 Nsw = 1, 42 Nuw = 2, 43 LLVM_MARK_AS_BITMASK_ENUM(Nuw) 44 }; 45 46 /// Function that performs inference on an array of `ConstantIntRanges` while 47 /// taking special overflow behavior into account. 48 using InferRangeWithOvfFlagsFn = 49 function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>; 50 51 /// Compute `inferFn` on `ranges`, whose size should be the index storage 52 /// bitwidth. Then, compute the function on `argRanges` again after truncating 53 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is 54 /// equal to the 32-bit result, use it (to preserve compatibility with folders 55 /// and inference precision), and take the union of the results otherwise. 56 /// 57 /// The `mode` argument specifies if the unsigned, signed, or both results of 58 /// the inference computation should be used when comparing the results. 59 ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, 60 ArrayRef<ConstantIntRanges> argRanges, 61 CmpMode mode); 62 63 /// Independently zero-extend the unsigned values and sign-extend the signed 64 /// values in `range` to `destWidth` bits, returning the resulting range. 65 ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth); 66 67 /// Use the unsigned values in `range` to zero-extend it to `destWidth`. 68 ConstantIntRanges extUIRange(const ConstantIntRanges &range, 69 unsigned destWidth); 70 71 /// Use the signed values in `range` to sign-extend it to `destWidth`. 72 ConstantIntRanges extSIRange(const ConstantIntRanges &range, 73 unsigned destWidth); 74 75 /// Truncate `range` to `destWidth` bits, taking care to handle cases such as 76 /// the truncation of [255, 256] to i8 not being a uniform range. 77 ConstantIntRanges truncRange(const ConstantIntRanges &range, 78 unsigned destWidth); 79 80 ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges, 81 OverflowFlags ovfFlags = OverflowFlags::None); 82 83 ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges, 84 OverflowFlags ovfFlags = OverflowFlags::None); 85 86 ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges, 87 OverflowFlags ovfFlags = OverflowFlags::None); 88 89 ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges); 90 91 ConstantIntRanges inferDivU(ArrayRef<ConstantIntRanges> argRanges); 92 93 ConstantIntRanges inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges); 94 95 ConstantIntRanges inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges); 96 97 ConstantIntRanges inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges); 98 99 ConstantIntRanges inferRemS(ArrayRef<ConstantIntRanges> argRanges); 100 101 ConstantIntRanges inferRemU(ArrayRef<ConstantIntRanges> argRanges); 102 103 ConstantIntRanges inferMaxS(ArrayRef<ConstantIntRanges> argRanges); 104 105 ConstantIntRanges inferMaxU(ArrayRef<ConstantIntRanges> argRanges); 106 107 ConstantIntRanges inferMinS(ArrayRef<ConstantIntRanges> argRanges); 108 109 ConstantIntRanges inferMinU(ArrayRef<ConstantIntRanges> argRanges); 110 111 ConstantIntRanges inferAnd(ArrayRef<ConstantIntRanges> argRanges); 112 113 ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges); 114 115 ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges); 116 117 ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges, 118 OverflowFlags ovfFlags = OverflowFlags::None); 119 120 ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges); 121 122 ConstantIntRanges inferShrU(ArrayRef<ConstantIntRanges> argRanges); 123 124 /// Copy of the enum from `arith` and `index` to allow the common integer range 125 /// infrastructure to not depend on either dialect. 126 enum class CmpPredicate : uint64_t { 127 eq, 128 ne, 129 slt, 130 sle, 131 sgt, 132 sge, 133 ult, 134 ule, 135 ugt, 136 uge, 137 }; 138 139 /// Returns a boolean value if `pred` is statically true or false for 140 /// anypossible inputs falling within `lhs` and `rhs`, and std::nullopt if the 141 /// value of the predicate cannot be determined. 142 std::optional<bool> evaluatePred(CmpPredicate pred, 143 const ConstantIntRanges &lhs, 144 const ConstantIntRanges &rhs); 145 146 } // namespace intrange 147 } // namespace mlir 148 149 #endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H 150