xref: /llvm-project/mlir/include/mlir/Interfaces/InferIntRangeInterface.h (revision 6aeea700df6f3f8db9e6a79be4aa593c6fcc7d18)
1 //===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains the C++ support classes and declarations for the integer
// range inference interface defined in `InferIntRangeInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
15 #define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
16 
17 #include "mlir/IR/OpDefinition.h"
18 #include <optional>
19 
20 namespace mlir {
21 /// A set of arbitrary-precision integers representing bounds on a given integer
22 /// value. These bounds are inclusive on both ends, so
23 /// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
24 /// the unsigned and signed interpretations of values in order to enable more
25 /// precice inference of the interplay between operations with signed and
26 /// unsigned semantics.
27 class ConstantIntRanges {
28 public:
29   /// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
30   /// Non-integer values should be bounded by APInts of bitwidth 0.
ConstantIntRanges(const APInt & umin,const APInt & umax,const APInt & smin,const APInt & smax)31   ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
32                     const APInt &smax)
33       : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
34     assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
35            umaxVal.getBitWidth() == sminVal.getBitWidth() &&
36            sminVal.getBitWidth() == smaxVal.getBitWidth() &&
37            "All bounds in the ranges must have the same bitwidth");
38   }
39 
40   bool operator==(const ConstantIntRanges &other) const;
41 
42   /// The minimum value of an integer when it is interpreted as unsigned.
43   const APInt &umin() const;
44 
45   /// The maximum value of an integer when it is interpreted as unsigned.
46   const APInt &umax() const;
47 
48   /// The minimum value of an integer when it is interpreted as signed.
49   const APInt &smin() const;
50 
51   /// The maximum value of an integer when it is interpreted as signed.
52   const APInt &smax() const;
53 
54   /// Return the bitwidth that should be used for integer ranges describing
55   /// `type`. For concrete integer types, this is their bitwidth, for `index`,
56   /// this is the internal storage bitwidth of `index` attributes, and for
57   /// non-integer types this is 0.
58   static unsigned getStorageBitwidth(Type type);
59 
60   /// Create a `ConstantIntRanges` with the maximum bounds for the width
61   /// `bitwidth`, that is - [0, uint_max(width)]/[sint_min(width),
62   /// sint_max(width)].
63   static ConstantIntRanges maxRange(unsigned bitwidth);
64 
65   /// Create a `ConstantIntRanges` with a constant value - that is, with the
66   /// bounds [value, value] for both its signed interpretations.
67   static ConstantIntRanges constant(const APInt &value);
68 
69   /// Create a `ConstantIntRanges` whose minimum is `min` and maximum is `max`
70   /// with `isSigned` specifying if the min and max should be interpreted as
71   /// signed or unsigned.
72   static ConstantIntRanges range(const APInt &min, const APInt &max,
73                                  bool isSigned);
74 
75   /// Create an `ConstantIntRanges` with the signed minimum and maximum equal
76   /// to `smin` and `smax`, where the unsigned bounds are constructed from the
77   /// signed ones if they correspond to a contigious range of bit patterns when
78   /// viewed as unsigned values and are left at [0, int_max()] otherwise.
79   static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax);
80 
81   /// Create an `ConstantIntRanges` with the unsigned minimum and maximum equal
82   /// to `umin` and `umax` and the signed part equal to `umin` and `umax`
83   /// unless the sign bit changes between the minimum and maximum.
84   static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax);
85 
86   /// Returns the union (computed separately for signed and unsigned bounds)
87   /// of this range and `other`.
88   ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const;
89 
90   /// Returns the intersection (computed separately for signed and unsigned
91   /// bounds) of this range and `other`.
92   ConstantIntRanges intersection(const ConstantIntRanges &other) const;
93 
94   /// If either the signed or unsigned interpretations of the range
95   /// indicate that the value it bounds is a constant, return that constant
96   /// value.
97   std::optional<APInt> getConstantValue() const;
98 
99   friend raw_ostream &operator<<(raw_ostream &os,
100                                  const ConstantIntRanges &range);
101 
102 private:
103   APInt uminVal, umaxVal, sminVal, smaxVal;
104 };
105 
106 raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
107 
108 /// This lattice value represents the integer range of an SSA value.
109 class IntegerValueRange {
110 public:
111   /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
112   /// range that is used to mark the value as unable to be analyzed further,
113   /// where `t` is the type of `value`.
114   static IntegerValueRange getMaxRange(Value value);
115 
116   /// Create an integer value range lattice value.
IntegerValueRange(ConstantIntRanges value)117   IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
118 
119   /// Create an integer value range lattice value.
120   IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
value(std::move (value))121       : value(std::move(value)) {}
122 
123   /// Whether the range is uninitialized. This happens when the state hasn't
124   /// been set during the analysis.
isUninitialized()125   bool isUninitialized() const { return !value.has_value(); }
126 
127   /// Get the known integer value range.
getValue()128   const ConstantIntRanges &getValue() const {
129     assert(!isUninitialized());
130     return *value;
131   }
132 
133   /// Compare two ranges.
134   bool operator==(const IntegerValueRange &rhs) const {
135     return value == rhs.value;
136   }
137 
138   /// Compute the least upper bound of two ranges.
join(const IntegerValueRange & lhs,const IntegerValueRange & rhs)139   static IntegerValueRange join(const IntegerValueRange &lhs,
140                                 const IntegerValueRange &rhs) {
141     if (lhs.isUninitialized())
142       return rhs;
143     if (rhs.isUninitialized())
144       return lhs;
145     return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
146   }
147 
148   /// Print the integer value range.
print(raw_ostream & os)149   void print(raw_ostream &os) const { os << value; }
150 
151 private:
152   /// The known integer value range.
153   std::optional<ConstantIntRanges> value;
154 };
155 
156 raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);
157 
158 /// The type of the `setResultRanges` callback provided to ops implementing
159 /// InferIntRangeInterface. It should be called once for each integer result
160 /// value and be passed the ConstantIntRanges corresponding to that value.
161 using SetIntRangeFn =
162     llvm::function_ref<void(Value, const ConstantIntRanges &)>;
163 
164 /// Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
165 /// This is the `setResultRanges` callback for the IntegerValueRange based
166 /// interface method.
167 using SetIntLatticeFn =
168     llvm::function_ref<void(Value, const IntegerValueRange &)>;
169 
170 class InferIntRangeInterface;
171 
172 namespace intrange::detail {
173 /// Default implementation of `inferResultRanges` which dispatches to the
174 /// `inferResultRangesFromOptional`.
175 void defaultInferResultRanges(InferIntRangeInterface interface,
176                               ArrayRef<IntegerValueRange> argRanges,
177                               SetIntLatticeFn setResultRanges);
178 
179 /// Default implementation of `inferResultRangesFromOptional` which dispatches
180 /// to the `inferResultRanges`.
181 void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
182                                           ArrayRef<ConstantIntRanges> argRanges,
183                                           SetIntRangeFn setResultRanges);
184 } // end namespace intrange::detail
185 } // end namespace mlir
186 
187 #include "mlir/Interfaces/InferIntRangeInterface.h.inc"
188 
189 #endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
190