xref: /llvm-project/mlir/lib/Dialect/Utils/StaticValueUtils.cpp (revision 092372da15e5165be14cdbb7cac3cf4976fd82d0)
1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StaticValueUtils.h"
10 #include "mlir/IR/Matchers.h"
11 #include "mlir/Support/LLVM.h"
12 #include "llvm/ADT/APSInt.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/MathExtras.h"
15 
16 namespace mlir {
17 
18 bool isZeroIndex(OpFoldResult v) {
19   if (!v)
20     return false;
21   std::optional<int64_t> constint = getConstantIntValue(v);
22   if (!constint)
23     return false;
24   return *constint == 0;
25 }
26 
27 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
28            SmallVector<OpFoldResult>>
29 getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
30   SmallVector<OpFoldResult> offsets, sizes, strides;
31   offsets.reserve(ranges.size());
32   sizes.reserve(ranges.size());
33   strides.reserve(ranges.size());
34   for (const auto &[offset, size, stride] : ranges) {
35     offsets.push_back(offset);
36     sizes.push_back(size);
37     strides.push_back(stride);
38   }
39   return std::make_tuple(offsets, sizes, strides);
40 }
41 
42 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
43 ///   a) it is an IntegerAttr
44 /// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
45 /// In such dynamic cases, a copy of the `sentinel` value is also pushed to
46 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
47 /// come from an AttrSizedOperandSegments trait.
48 void dispatchIndexOpFoldResult(OpFoldResult ofr,
49                                SmallVectorImpl<Value> &dynamicVec,
50                                SmallVectorImpl<int64_t> &staticVec) {
51   auto v = llvm::dyn_cast_if_present<Value>(ofr);
52   if (!v) {
53     APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
54     staticVec.push_back(apInt.getSExtValue());
55     return;
56   }
57   dynamicVec.push_back(v);
58   staticVec.push_back(ShapedType::kDynamic);
59 }
60 
61 std::pair<int64_t, OpFoldResult>
62 getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) {
63   int64_t tileSizeForShape =
64       getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
65 
66   OpFoldResult tileSizeOfrSimplified =
67       (tileSizeForShape != ShapedType::kDynamic)
68           ? b.getIndexAttr(tileSizeForShape)
69           : tileSizeOfr;
70 
71   return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
72                                           tileSizeOfrSimplified);
73 }
74 
75 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
76                                 SmallVectorImpl<Value> &dynamicVec,
77                                 SmallVectorImpl<int64_t> &staticVec) {
78   for (OpFoldResult ofr : ofrs)
79     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
80 }
81 
82 /// Given a value, try to extract a constant Attribute. If this fails, return
83 /// the original value.
84 OpFoldResult getAsOpFoldResult(Value val) {
85   if (!val)
86     return OpFoldResult();
87   Attribute attr;
88   if (matchPattern(val, m_Constant(&attr)))
89     return attr;
90   return val;
91 }
92 
93 /// Given an array of values, try to extract a constant Attribute from each
94 /// value. If this fails, return the original value.
95 SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
96   return llvm::to_vector(
97       llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
98 }
99 
100 /// Convert `arrayAttr` to a vector of OpFoldResult.
101 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
102   SmallVector<OpFoldResult> res;
103   res.reserve(arrayAttr.size());
104   for (Attribute a : arrayAttr)
105     res.push_back(a);
106   return res;
107 }
108 
109 OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) {
110   return IntegerAttr::get(IndexType::get(ctx), val);
111 }
112 
113 SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
114                                                  ArrayRef<int64_t> values) {
115   return llvm::to_vector(llvm::map_range(
116       values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
117 }
118 
119 /// If ofr is a constant integer or an IntegerAttr, return the integer.
120 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
121   // Case 1: Check for Constant integer.
122   if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
123     APSInt intVal;
124     if (matchPattern(val, m_ConstantInt(&intVal)))
125       return intVal.getSExtValue();
126     return std::nullopt;
127   }
128   // Case 2: Check for IntegerAttr.
129   Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
130   if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
131     return intAttr.getValue().getSExtValue();
132   return std::nullopt;
133 }
134 
135 std::optional<SmallVector<int64_t>>
136 getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
137   bool failed = false;
138   SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
139     auto cv = getConstantIntValue(ofr);
140     if (!cv.has_value())
141       failed = true;
142     return cv.value_or(0);
143   });
144   if (failed)
145     return std::nullopt;
146   return res;
147 }
148 
149 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
150   auto val = getConstantIntValue(ofr);
151   return val && *val == value;
152 }
153 
154 bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
155   return llvm::all_of(
156       ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
157 }
158 
159 bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
160                           ArrayRef<int64_t> values) {
161   if (ofrs.size() != values.size())
162     return false;
163   std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
164   return constOfrs && llvm::equal(constOfrs.value(), values);
165 }
166 
167 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
168 /// or the same SSA value.
169 /// Ignore integer bitwidth and type mismatch that come from the fact there is
170 /// no IndexAttr and that IndexType has no bitwidth.
171 bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
172   auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
173   if (cst1 && cst2 && *cst1 == *cst2)
174     return true;
175   auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
176        v2 = llvm::dyn_cast_if_present<Value>(ofr2);
177   return v1 && v1 == v2;
178 }
179 
180 bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
181                                     ArrayRef<OpFoldResult> ofrs2) {
182   if (ofrs1.size() != ofrs2.size())
183     return false;
184   for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
185     if (!isEqualConstantIntOrValue(ofr1, ofr2))
186       return false;
187   return true;
188 }
189 
190 /// Return a vector of OpFoldResults with the same size a staticValues, but all
191 /// elements for which ShapedType::isDynamic is true, will be replaced by
192 /// dynamicValues.
193 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
194                                          ValueRange dynamicValues,
195                                          MLIRContext *context) {
196   SmallVector<OpFoldResult> res;
197   res.reserve(staticValues.size());
198   unsigned numDynamic = 0;
199   unsigned count = static_cast<unsigned>(staticValues.size());
200   for (unsigned idx = 0; idx < count; ++idx) {
201     int64_t value = staticValues[idx];
202     res.push_back(ShapedType::isDynamic(value)
203                       ? OpFoldResult{dynamicValues[numDynamic++]}
204                       : OpFoldResult{IntegerAttr::get(
205                             IntegerType::get(context, 64), staticValues[idx])});
206   }
207   return res;
208 }
209 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
210                                          ValueRange dynamicValues, Builder &b) {
211   return getMixedValues(staticValues, dynamicValues, b.getContext());
212 }
213 
214 /// Decompose a vector of mixed static or dynamic values into the corresponding
215 /// pair of arrays. This is the inverse function of `getMixedValues`.
216 std::pair<SmallVector<int64_t>, SmallVector<Value>>
217 decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
218   SmallVector<int64_t> staticValues;
219   SmallVector<Value> dynamicValues;
220   for (const auto &it : mixedValues) {
221     if (auto attr = dyn_cast<Attribute>(it)) {
222       staticValues.push_back(cast<IntegerAttr>(attr).getInt());
223     } else {
224       staticValues.push_back(ShapedType::kDynamic);
225       dynamicValues.push_back(cast<Value>(it));
226     }
227   }
228   return {staticValues, dynamicValues};
229 }
230 
231 /// Helper to sort `values` according to matching `keys`.
232 template <typename K, typename V>
233 static SmallVector<V>
234 getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values,
235                          llvm::function_ref<bool(K, K)> compare) {
236   if (keys.empty())
237     return SmallVector<V>{values};
238   assert(keys.size() == values.size() && "unexpected mismatching sizes");
239   auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
240   std::sort(indices.begin(), indices.end(),
241             [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
242   SmallVector<V> res;
243   res.reserve(values.size());
244   for (int64_t i = 0, e = indices.size(); i < e; ++i)
245     res.push_back(values[indices[i]]);
246   return res;
247 }
248 
249 SmallVector<Value>
250 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
251                      llvm::function_ref<bool(Attribute, Attribute)> compare) {
252   return getValuesSortedByKeyImpl(keys, values, compare);
253 }
254 
255 SmallVector<OpFoldResult>
256 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
257                      llvm::function_ref<bool(Attribute, Attribute)> compare) {
258   return getValuesSortedByKeyImpl(keys, values, compare);
259 }
260 
261 SmallVector<int64_t>
262 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
263                      llvm::function_ref<bool(Attribute, Attribute)> compare) {
264   return getValuesSortedByKeyImpl(keys, values, compare);
265 }
266 
267 /// Return the number of iterations for a loop with a lower bound `lb`, upper
268 /// bound `ub` and step `step`.
269 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
270                                          OpFoldResult step) {
271   if (lb == ub)
272     return 0;
273 
274   std::optional<int64_t> lbConstant = getConstantIntValue(lb);
275   if (!lbConstant)
276     return std::nullopt;
277   std::optional<int64_t> ubConstant = getConstantIntValue(ub);
278   if (!ubConstant)
279     return std::nullopt;
280   std::optional<int64_t> stepConstant = getConstantIntValue(step);
281   if (!stepConstant)
282     return std::nullopt;
283 
284   return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
285 }
286 
287 bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
288   return llvm::none_of(sizesOrOffsets, [](int64_t value) {
289     return !ShapedType::isDynamic(value) && value < 0;
290   });
291 }
292 
293 bool hasValidStrides(SmallVector<int64_t> strides) {
294   return llvm::none_of(strides, [](int64_t value) {
295     return !ShapedType::isDynamic(value) && value == 0;
296   });
297 }
298 
299 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
300                                    bool onlyNonNegative, bool onlyNonZero) {
301   bool valuesChanged = false;
302   for (OpFoldResult &ofr : ofrs) {
303     if (isa<Attribute>(ofr))
304       continue;
305     Attribute attr;
306     if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
307       // Note: All ofrs have index type.
308       if (onlyNonNegative && *getConstantIntValue(attr) < 0)
309         continue;
310       if (onlyNonZero && *getConstantIntValue(attr) == 0)
311         continue;
312       ofr = attr;
313       valuesChanged = true;
314     }
315   }
316   return success(valuesChanged);
317 }
318 
319 LogicalResult
320 foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
321   return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
322                               /*onlyNonZero=*/false);
323 }
324 
325 LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
326   return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
327                               /*onlyNonZero=*/true);
328 }
329 
330 } // namespace mlir
331