xref: /llvm-project/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (revision 29a925abb660104b413b15075b3a19793825f57e)
//===- ValueBoundsOpInterface.h - Value Bounds ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
10 #define MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
11 
12 #include "mlir/Analysis/FlatLinearValueConstraints.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/Value.h"
16 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Support/ExtensibleRTTI.h"
20 
21 #include <queue>
22 
23 namespace mlir {
24 class OffsetSizeAndStrideOpInterface;
25 
26 /// A hyperrectangular slice, represented as a list of offsets, sizes and
27 /// strides.
28 class HyperrectangularSlice {
29 public:
30   HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
31                         ArrayRef<OpFoldResult> sizes,
32                         ArrayRef<OpFoldResult> strides);
33 
34   /// Create a hyperrectangular slice with unit strides.
35   HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
36                         ArrayRef<OpFoldResult> sizes);
37 
38   /// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`.
39   HyperrectangularSlice(OffsetSizeAndStrideOpInterface op);
40 
getMixedOffsets()41   ArrayRef<OpFoldResult> getMixedOffsets() const { return mixedOffsets; }
getMixedSizes()42   ArrayRef<OpFoldResult> getMixedSizes() const { return mixedSizes; }
getMixedStrides()43   ArrayRef<OpFoldResult> getMixedStrides() const { return mixedStrides; }
44 
45 private:
46   SmallVector<OpFoldResult> mixedOffsets;
47   SmallVector<OpFoldResult> mixedSizes;
48   SmallVector<OpFoldResult> mixedStrides;
49 };
50 
51 using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
52 
53 /// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
54 /// constraint system and mapping of constrained variables to index-typed
55 /// values or dimension sizes of shaped values.
56 ///
57 /// Interface implementations of `ValueBoundsOpInterface` use `addBounds` to
58 /// insert constraints about their results and/or region block arguments into
59 /// the constraint set in the form of an AffineExpr. When a bound should be
60 /// expressed in terms of another value/dimension, `getExpr` can be used to
61 /// retrieve an AffineExpr that represents the specified value/dimension.
62 ///
63 /// When a value/dimension is retrieved for the first time through `getExpr`,
64 /// it is added to an internal worklist. See `computeBound` for more details.
65 ///
66 /// Note: Any modification of existing IR invalides the data stored in this
67 /// class. Adding new operations is allowed.
68 class ValueBoundsConstraintSet
69     : public llvm::RTTIExtends<ValueBoundsConstraintSet, llvm::RTTIRoot> {
70 protected:
71   /// Helper class that builds a bound for a shaped value dimension or
72   /// index-typed value.
73   class BoundBuilder {
74   public:
75     /// Specify a dimension, assuming that the underlying value is a shaped
76     /// value.
77     BoundBuilder &operator[](int64_t dim);
78 
79     // These overloaded operators add lower/upper/equality bounds.
80     void operator<(AffineExpr expr);
81     void operator<=(AffineExpr expr);
82     void operator>(AffineExpr expr);
83     void operator>=(AffineExpr expr);
84     void operator==(AffineExpr expr);
85     void operator<(OpFoldResult ofr);
86     void operator<=(OpFoldResult ofr);
87     void operator>(OpFoldResult ofr);
88     void operator>=(OpFoldResult ofr);
89     void operator==(OpFoldResult ofr);
90     void operator<(int64_t i);
91     void operator<=(int64_t i);
92     void operator>(int64_t i);
93     void operator>=(int64_t i);
94     void operator==(int64_t i);
95 
96   protected:
97     friend class ValueBoundsConstraintSet;
BoundBuilder(ValueBoundsConstraintSet & cstr,Value value)98     BoundBuilder(ValueBoundsConstraintSet &cstr, Value value)
99         : cstr(cstr), value(value) {}
100 
101   private:
102     BoundBuilder(const BoundBuilder &) = delete;
103     BoundBuilder &operator=(const BoundBuilder &) = delete;
104     bool operator==(const BoundBuilder &) = delete;
105     bool operator!=(const BoundBuilder &) = delete;
106 
107     ValueBoundsConstraintSet &cstr;
108     Value value;
109     std::optional<int64_t> dim;
110   };
111 
112 public:
113   static char ID;
114 
115   /// A variable that can be added to the constraint set as a "column". The
116   /// value bounds infrastructure can compute bounds for variables and compare
117   /// two variables.
118   ///
119   /// Internally, a variable is represented as an affine map and operands.
120   class Variable {
121   public:
122     /// Construct a variable for an index-typed attribute or SSA value.
123     Variable(OpFoldResult ofr);
124 
125     /// Construct a variable for an index-typed SSA value.
126     Variable(Value indexValue);
127 
128     /// Construct a variable for a dimension of a shaped value.
129     Variable(Value shapedValue, int64_t dim);
130 
131     /// Construct a variable for an index-typed attribute/SSA value or for a
132     /// dimension of a shaped value. A non-null dimension must be provided if
133     /// and only if `ofr` is a shaped value.
134     Variable(OpFoldResult ofr, std::optional<int64_t> dim);
135 
136     /// Construct a variable for a map and its operands.
137     Variable(AffineMap map, ArrayRef<Variable> mapOperands);
138     Variable(AffineMap map, ArrayRef<Value> mapOperands);
139 
getContext()140     MLIRContext *getContext() const { return map.getContext(); }
141 
142   private:
143     friend class ValueBoundsConstraintSet;
144     AffineMap map;
145     ValueDimList mapOperands;
146   };
147 
148   /// The stop condition when traversing the backward slice of a shaped value/
149   /// index-type value. The traversal continues until the stop condition
150   /// evaluates to "true" for a value.
151   ///
152   /// The first parameter of the function is the shaped value/index-typed
153   /// value. The second parameter is the dimension in case of a shaped value.
154   /// The third parameter is this constraint set.
155   using StopConditionFn = std::function<bool(
156       Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
157 
158   /// Compute a bound for the given variable. The computed bound is stored in
159   /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
160   /// operand is either an index-type SSA value or a shaped value and a
161   /// dimension.
162   ///
163   /// The bound is computed in terms of values/dimensions for which
164   /// `stopCondition` evaluates to "true". To that end, the backward slice
165   /// (reverse use-def chain) of the given value is visited in a worklist-driven
166   /// manner and the constraint set is populated according to
167   /// `ValueBoundsOpInterface` for each visited value.
168   ///
169   /// By default, lower/equal bounds are closed and upper bounds are open. If
170   /// `closedUB` is set to "true", upper bounds are also closed.
171   static LogicalResult
172   computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
173                presburger::BoundType type, const Variable &var,
174                StopConditionFn stopCondition, bool closedUB = false);
175 
176   /// Compute a bound in terms of the values/dimensions in `dependencies`. The
177   /// computed bound consists of only constant terms and dependent values (or
178   /// dimension sizes thereof).
179   static LogicalResult
180   computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
181                         presburger::BoundType type, const Variable &var,
182                         ValueDimList dependencies, bool closedUB = false);
183 
184   /// Compute a bound in that is independent of all values in `independencies`.
185   ///
186   /// Independencies are the opposite of dependencies. The computed bound does
187   /// not contain any SSA values that are part of `independencies`. E.g., this
188   /// function can be used to make ops hoistable from loops. To that end, ops
189   /// must be made independent of loop induction variables (in the case of "for"
190   /// loops). Loop induction variables are the independencies; they may not
191   /// appear in the computed bound.
192   static LogicalResult
193   computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
194                           presburger::BoundType type, const Variable &var,
195                           ValueRange independencies, bool closedUB = false);
196 
197   /// Compute a constant bound for the given variable.
198   ///
199   /// This function traverses the backward slice of the given operands in a
200   /// worklist-driven manner until `stopCondition` evaluates to "true". The
201   /// constraint set is populated according to `ValueBoundsOpInterface` for each
202   /// visited value. (No constraints are added for values for which the stop
203   /// condition evaluates to "true".)
204   ///
205   /// The stop condition is optional: If none is specified, the backward slice
206   /// is traversed in a breadth-first manner until a constant bound could be
207   /// computed.
208   ///
209   /// By default, lower/equal bounds are closed and upper bounds are open. If
210   /// `closedUB` is set to "true", upper bounds are also closed.
211   static FailureOr<int64_t>
212   computeConstantBound(presburger::BoundType type, const Variable &var,
213                        StopConditionFn stopCondition = nullptr,
214                        bool closedUB = false);
215 
216   /// Compute a constant delta between the given two values. Return "failure"
217   /// if a constant delta could not be determined.
218   ///
219   /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
220   /// index-typed.
221   static FailureOr<int64_t>
222   computeConstantDelta(Value value1, Value value2,
223                        std::optional<int64_t> dim1 = std::nullopt,
224                        std::optional<int64_t> dim2 = std::nullopt);
225 
226   /// Traverse the IR starting from the given value/dim and populate constraints
227   /// as long as the stop condition holds. Also process all values/dims that are
228   /// already on the worklist.
229   void populateConstraints(Value value, std::optional<int64_t> dim);
230 
231   /// Comparison operator for `ValueBoundsConstraintSet::compare`.
232   enum ComparisonOperator { LT, LE, EQ, GT, GE };
233 
234   /// Populate constraints for lhs/rhs (until the stop condition is met). Then,
235   /// try to prove that, based on the current state of this constraint set
236   /// (i.e., without analyzing additional IR or adding new constraints), the
237   /// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
238   ///
239   /// Return "true" if the specified relation between the two values/dims was
240   /// proven to hold. Return "false" if the specified relation could not be
241   /// proven. This could be because the specified relation does in fact not hold
242   /// or because there is not enough information in the constraint set. In other
243   /// words, if we do not know for sure, this function returns "false".
244   bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp,
245                           const Variable &rhs);
246 
247   /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
248   /// specified relation could not be proven. This could be because the
249   /// specified relation does in fact not hold or because there is not enough
250   /// information in the constraint set. In other words, if we do not know for
251   /// sure, this function returns "false".
252   ///
253   /// This function keeps traversing the backward slice of lhs/rhs until could
254   /// prove the relation or until it ran out of IR.
255   static bool compare(const Variable &lhs, ComparisonOperator cmp,
256                       const Variable &rhs);
257 
258   /// Compute whether the given variables are equal. Return "failure" if
259   /// equality could not be determined.
260   static FailureOr<bool> areEqual(const Variable &var1, const Variable &var2);
261 
262   /// Return "true" if the given slices are guaranteed to be overlapping.
263   /// Return "false" if the given slices are guaranteed to be non-overlapping.
264   /// Return "failure" if unknown.
265   ///
266   /// Slices are overlapping if for all dimensions:
267   /// *      offset1 + size1 * stride1 <= offset2
268   /// * and  offset2 + size2 * stride2 <= offset1
269   ///
270   /// Slice are non-overlapping if the above constraint is not satisfied for
271   /// at least one dimension.
272   static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
273                                               HyperrectangularSlice slice1,
274                                               HyperrectangularSlice slice2);
275 
276   /// Return "true" if the given slices are guaranteed to be equivalent.
277   /// Return "false" if the given slices are guaranteed to be non-equivalent.
278   /// Return "failure" if unknown.
279   ///
280   /// Slices are equivalent if their offsets, sizes and strices are equal.
281   static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
282                                              HyperrectangularSlice slice1,
283                                              HyperrectangularSlice slice2);
284 
285   /// Add a bound for the given index-typed value or shaped value. This function
286   /// returns a builder that adds the bound.
bound(Value value)287   BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
288 
289   /// Return an expression that represents the given index-typed value or shaped
290   /// value dimension. If this value/dimension was not used so far, it is added
291   /// to the worklist.
292   ///
293   /// `dim` must be `nullopt` if and only if the given value is of index type.
294   AffineExpr getExpr(Value value, std::optional<int64_t> dim = std::nullopt);
295 
296   /// Return an expression that represents a constant or index-typed SSA value.
297   /// In case of a value, if this value was not used so far, it is added to the
298   /// worklist.
299   AffineExpr getExpr(OpFoldResult ofr);
300 
301   /// Return an expression that represents a constant.
302   AffineExpr getExpr(int64_t constant);
303 
304   /// Debugging only: Dump the constraint set and the column-to-value/dim
305   /// mapping to llvm::errs.
306   void dump() const;
307 
308 protected:
309   /// Dimension identifier to indicate a value is index-typed. This is used for
310   /// internal data structures/API only.
311   static constexpr int64_t kIndexValue = -1;
312 
313   /// An index-typed value or the dimension of a shaped-type value.
314   using ValueDim = std::pair<Value, int64_t>;
315 
316   ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition,
317                            bool addConservativeSemiAffineBounds = false);
318 
319   /// Return "true" if, based on the current state of the constraint system,
320   /// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
321   /// could not be proven. This could be because the specified relation does in
322   /// fact not hold or because there is not enough information in the constraint
323   /// set. In other words, if we do not know for sure, this function returns
324   /// "false".
325   ///
326   /// This function does not analyze any IR and does not populate any additional
327   /// constraints.
328   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
329 
330   /// Given an affine map with a single result (and map operands), add a new
331   /// column to the constraint set that represents the result of the map.
332   /// Traverse additional IR starting from the map operands as needed (as long
333   /// as the stop condition is not satisfied). Also process all values/dims that
334   /// are already on the worklist. Return the position of the newly added
335   /// column.
336   int64_t populateConstraints(AffineMap map, ValueDimList mapOperands);
337 
338   /// Iteratively process all elements on the worklist until an index-typed
339   /// value or shaped value meets `stopCondition`. Such values are not processed
340   /// any further.
341   void processWorklist();
342 
343   /// Bound the given column in the underlying constraint set by the given
344   /// expression.
345   void addBound(presburger::BoundType type, int64_t pos, AffineExpr expr);
346 
347   /// Return the column position of the given value/dimension. Asserts that the
348   /// value/dimension exists in the constraint set.
349   int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
350 
351   /// Return an affine expression that represents column `pos` in the constraint
352   /// set.
353   AffineExpr getPosExpr(int64_t pos);
354 
355   /// Return "true" if the given value/dim is mapped (i.e., has a corresponding
356   /// column in the constraint system).
357   bool isMapped(Value value, std::optional<int64_t> dim = std::nullopt) const;
358 
359   /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
360   /// "false", a dimension is added. The value/dimension is added to the
361   /// worklist if `addToWorklist` is set.
362   ///
363   /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
364   /// cannot be multiplied. Furthermore, bounds can only be queried for
365   /// dimensions but not for symbols.
366   int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
367                  bool addToWorklist = true);
368 
369   /// Insert an anonymous column into the constraint set. The column is not
370   /// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
371   /// is added.
372   ///
373   /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
374   /// cannot be multiplied. Furthermore, bounds can only be queried for
375   /// dimensions but not for symbols.
376   int64_t insert(bool isSymbol = true);
377 
378   /// Insert the given affine map and its bound operands as a new column in the
379   /// constraint system. Return the position of the new column. Any operands
380   /// that were not analyzed yet are put on the worklist.
381   int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
382   int64_t insert(const Variable &var, bool isSymbol = true);
383 
384   /// Project out the given column in the constraint set.
385   void projectOut(int64_t pos);
386 
387   /// Project out all columns for which the condition holds.
388   void projectOut(function_ref<bool(ValueDim)> condition);
389 
390   void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
391 
392   /// Mapping of columns to values/shape dimensions.
393   SmallVector<std::optional<ValueDim>> positionToValueDim;
394   /// Reverse mapping of values/shape dimensions to columns.
395   DenseMap<ValueDim, int64_t> valueDimToPosition;
396 
397   /// Worklist of values/shape dimensions that have not been processed yet.
398   std::queue<int64_t> worklist;
399 
400   /// Constraint system of equalities and inequalities.
401   FlatLinearConstraints cstr;
402 
403   /// Builder for constructing affine expressions.
404   Builder builder;
405 
406   /// The current stop condition function.
407   StopConditionFn stopCondition = nullptr;
408 
409   /// Should conservative bounds be added for semi-affine expressions.
410   bool addConservativeSemiAffineBounds = false;
411 };
412 
413 } // namespace mlir
414 
415 #include "mlir/Interfaces/ValueBoundsOpInterface.h.inc"
416 
417 #endif // MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
418