xref: /llvm-project/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (revision 092372da15e5165be14cdbb7cac3cf4976fd82d0)
1 //===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for dealing with static values, e.g.,
10 // converting back and forth between Value and OpFoldResult. Such functionality
11 // is used in multiple dialects.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
16 #define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
17 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/OpDefinition.h"
21 #include "mlir/Support/LLVM.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/SmallVectorExtras.h"
24 
25 namespace mlir {
26 
/// Return true if `v` is an IntegerAttr with value `0` or a ConstantIndexOp
/// whose attribute has value `0`.
bool isZeroIndex(OpFoldResult v);
30 
/// Represents a range (offset, size, and stride) where each element of the
/// triple may be dynamic or static. Every member is stored as an OpFoldResult
/// so that it can carry either form.
struct Range {
  /// Offset component of the range.
  OpFoldResult offset;
  /// Size component of the range.
  OpFoldResult size;
  /// Stride component of the range.
  OpFoldResult stride;
};
38 
/// Given an array of Range values, return a tuple of (offsets vector, sizes
/// vector, strides vector) formed by separating out the individual elements
/// of each range.
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
           SmallVector<OpFoldResult>>
getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
45 
/// Helper function to dispatch an OpFoldResult into `staticVec` if:
///   a) it is an IntegerAttr
/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
/// In such dynamic cases, ShapedType::kDynamic is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries
/// that come from an AttrSizedOperandSegments trait.
void dispatchIndexOpFoldResult(OpFoldResult ofr,
                               SmallVectorImpl<Value> &dynamicVec,
                               SmallVectorImpl<int64_t> &staticVec);
55 
/// Helper function to dispatch multiple OpFoldResults according to the
/// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr, ...)` for a
/// single OpFoldResult.
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                SmallVectorImpl<Value> &dynamicVec,
                                SmallVectorImpl<int64_t> &staticVec);
62 
/// Given an OpFoldResult representing a dim size value (*), generates a pair
/// of sizes:
///   * 1st result, static value, contains an int64_t dim size that can be used
///   to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims),
///   * 2nd result, dynamic value, contains OpFoldResult encapsulating the
///   actual dim size (either original or updated input value).
/// For input sizes for which it is possible to extract a constant Attribute,
/// replaces the original size value with an integer attribute (unless it's
/// already a constant Attribute). The 1st return value also becomes the actual
/// integer size (as opposed to ShapedType::kDynamic).
///
/// (*) This hook is usually used when, given input sizes as OpFoldResult,
/// it's required to generate two vectors:
///   * sizes as int64_t to generate a shape,
///   * sizes as OpFoldResult for sizes-like attribute.
/// Please update this comment if you identify other use cases.
std::pair<int64_t, OpFoldResult>
getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b);
81 
82 /// Extract integer values from the assumed ArrayAttr of IntegerAttr.
83 template <typename IntTy>
84 SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
85   return llvm::to_vector(
86       llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy {
87         return cast<IntegerAttr>(a).getInt();
88       }));
89 }
90 
/// Given a value, try to extract a constant Attribute. If this fails, return
/// the original value.
OpFoldResult getAsOpFoldResult(Value val);
/// Given an array of values, try to extract a constant Attribute from each
/// value. For every value where this fails, the original value is returned
/// instead.
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
/// Convert `arrayAttr` to a vector of OpFoldResult.
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
99 
/// Convert an int64_t to an integer attribute of index type and return it as
/// an OpFoldResult.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
/// Convert each int64_t in `values` to an integer attribute of index type and
/// return them as OpFoldResults.
SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
                                                 ArrayRef<int64_t> values);
105 
/// If `ofr` is a constant integer or an IntegerAttr, return the integer.
/// NOTE(review): presumably std::nullopt is returned otherwise -- confirm
/// against the implementation.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
/// If all `ofrs` are constant integers or IntegerAttrs, return the integers.
std::optional<SmallVector<int64_t>>
getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
111 
/// Return true if `ofr` is a constant integer equal to `value`.
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
/// Return true if all of `ofrs` are constant integers equal to `value`.
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
/// Return true if all of `ofrs` are constant integers equal to the
/// corresponding value in `values`.
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
                          ArrayRef<int64_t> values);
120 
/// Return true if `ofr1` and `ofr2` are the same integer constant attribute
/// values or the same SSA value. Ignore integer bitwidth and type mismatches
/// that come from the fact that there is no IndexAttr and that IndexType has
/// no bitwidth.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
/// Array counterpart of `isEqualConstantIntOrValue`, taking two lists of
/// OpFoldResults.
/// NOTE(review): presumably an element-wise comparison that also requires
/// equal lengths -- confirm against the implementation.
bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
                                    ArrayRef<OpFoldResult> ofrs2);
128 
129 // To convert an OpFoldResult to a Value of index type, see:
130 //   mlir/include/mlir/Dialect/Arith/Utils/Utils.h
131 // TODO: find a better common landing place.
132 //
133 // Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
134 //                                       OpFoldResult ofr);
135 
// To convert a list of OpFoldResults to Values of index type, see:
137 //   mlir/include/mlir/Dialect/Arith/Utils/Utils.h
138 // TODO: find a better common landing place.
139 //
140 // SmallVector<Value>
141 // getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
142 //                                 ArrayRef<OpFoldResult> valueOrAttrVec);
143 
/// Return a vector of OpFoldResults with the same size as `staticValues`, but
/// with all elements for which ShapedType::isDynamic is true replaced by
/// values taken from `dynamicValues`.
/// Two overloads are provided: one taking the MLIRContext directly and one
/// taking a Builder.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                         ValueRange dynamicValues,
                                         MLIRContext *context);
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                         ValueRange dynamicValues, Builder &b);
152 
/// Decompose a vector of mixed static or dynamic values into the
/// corresponding pair of arrays (static int64_t values first, dynamic Values
/// second). This is the inverse function of `getMixedValues`.
std::pair<SmallVector<int64_t>, SmallVector<Value>>
decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
158 
/// Helper to sort `values` according to matching `keys`. Overloads are
/// provided for `values` holding Value, OpFoldResult, or int64_t elements.
/// `compare` is a binary predicate over two Attributes.
/// NOTE(review): presumably `compare` is the ordering applied to `keys` and
/// `keys`/`values` must have the same length -- confirm against the
/// implementation.
SmallVector<Value>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare);
SmallVector<OpFoldResult>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare);
SmallVector<int64_t>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare);
169 
/// Helper function to check whether the passed-in `sizesOrOffsets` (a list of
/// sizes or a list of offsets) are valid. This can be used to re-check whether
/// dimensions are still valid after constant folding the dynamic dimensions.
/// NOTE(review): the validity criterion itself is defined by the
/// implementation and is not visible in this header.
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
174 
/// Helper function to check whether the passed-in `strides` are valid. This
/// can be used to re-check whether dimensions are still valid after constant
/// folding the dynamic dimensions.
/// NOTE(review): the validity criterion itself is defined by the
/// implementation and is not visible in this header.
bool hasValidStrides(SmallVector<int64_t> strides);
179 
/// Returns "success" when any of the elements in `ofrs` is a constant value. In
/// that case the value is replaced by an attribute (`ofrs` is updated in
/// place). Returns "failure" when no folding happened.
///   * If `onlyNonNegative` is set, only non-negative constant values are
///     folded.
///   * If `onlyNonZero` is set, only non-zero constant values are folded.
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
                                   bool onlyNonNegative = false,
                                   bool onlyNonZero = false);
187 
/// Returns "success" when any of the elements in `offsetsOrSizes` is a
/// constant value. In that case the value is replaced by an attribute
/// (`offsetsOrSizes` is updated in place). Returns "failure" when no folding
/// happened. Invalid values are not folded to avoid canonicalization crashes.
LogicalResult
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
194 
/// Returns "success" when any of the elements in `strides` is a constant
/// value. In that case the value is replaced by an attribute (`strides` is
/// updated in place). Returns "failure" when no folding happened. Invalid
/// values are not folded to avoid canonicalization crashes.
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
200 
/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
/// NOTE(review): the result is optional; presumably it is empty when the trip
/// count cannot be computed as a constant -- confirm against the
/// implementation.
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                         OpFoldResult step);
205 
206 /// Idiomatic saturated operations on values like offsets, sizes, and strides.
207 struct SaturatedInteger {
208   static SaturatedInteger wrap(int64_t v) {
209     return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
210                                       : SaturatedInteger{false, v};
211   }
212   int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
213   FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
214     if (saturated && !other.saturated)
215       return other;
216     if (!saturated && !other.saturated && v != other.v)
217       return failure();
218     return *this;
219   }
220   bool operator==(SaturatedInteger other) {
221     return (saturated && other.saturated) ||
222            (!saturated && !other.saturated && v == other.v);
223   }
224   bool operator!=(SaturatedInteger other) { return !(*this == other); }
225   SaturatedInteger operator+(SaturatedInteger other) {
226     if (saturated || other.saturated)
227       return SaturatedInteger{true, 0};
228     return SaturatedInteger{false, other.v + v};
229   }
230   SaturatedInteger operator*(SaturatedInteger other) {
231     // Multiplication with 0 is always 0.
232     if (!other.saturated && other.v == 0)
233       return SaturatedInteger{false, 0};
234     if (!saturated && v == 0)
235       return SaturatedInteger{false, 0};
236     // Otherwise, if this or the other integer is dynamic, so is the result.
237     if (saturated || other.saturated)
238       return SaturatedInteger{true, 0};
239     return SaturatedInteger{false, other.v * v};
240   }
241   bool saturated = true;
242   int64_t v = 0;
243 };
244 
245 } // namespace mlir
246 
247 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
248