1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/StaticValueUtils.h" 10 #include "mlir/IR/Matchers.h" 11 #include "mlir/Support/LLVM.h" 12 #include "llvm/ADT/APSInt.h" 13 #include "llvm/ADT/STLExtras.h" 14 #include "llvm/Support/MathExtras.h" 15 16 namespace mlir { 17 18 bool isZeroIndex(OpFoldResult v) { 19 if (!v) 20 return false; 21 std::optional<int64_t> constint = getConstantIntValue(v); 22 if (!constint) 23 return false; 24 return *constint == 0; 25 } 26 27 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, 28 SmallVector<OpFoldResult>> 29 getOffsetsSizesAndStrides(ArrayRef<Range> ranges) { 30 SmallVector<OpFoldResult> offsets, sizes, strides; 31 offsets.reserve(ranges.size()); 32 sizes.reserve(ranges.size()); 33 strides.reserve(ranges.size()); 34 for (const auto &[offset, size, stride] : ranges) { 35 offsets.push_back(offset); 36 sizes.push_back(size); 37 strides.push_back(stride); 38 } 39 return std::make_tuple(offsets, sizes, strides); 40 } 41 42 /// Helper function to dispatch an OpFoldResult into `staticVec` if: 43 /// a) it is an IntegerAttr 44 /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. 45 /// In such dynamic cases, a copy of the `sentinel` value is also pushed to 46 /// `staticVec`. This is useful to extract mixed static and dynamic entries that 47 /// come from an AttrSizedOperandSegments trait. 48 void dispatchIndexOpFoldResult(OpFoldResult ofr, 49 SmallVectorImpl<Value> &dynamicVec, 50 SmallVectorImpl<int64_t> &staticVec) { 51 auto v = llvm::dyn_cast_if_present<Value>(ofr); 52 if (!v) { 53 APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue(); 54 staticVec.push_back(apInt.getSExtValue()); 55 return; 56 } 57 dynamicVec.push_back(v); 58 staticVec.push_back(ShapedType::kDynamic); 59 } 60 61 std::pair<int64_t, OpFoldResult> 62 getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) { 63 int64_t tileSizeForShape = 64 getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic); 65 66 OpFoldResult tileSizeOfrSimplified = 67 (tileSizeForShape != ShapedType::kDynamic) 68 ? b.getIndexAttr(tileSizeForShape) 69 : tileSizeOfr; 70 71 return std::pair<int64_t, OpFoldResult>(tileSizeForShape, 72 tileSizeOfrSimplified); 73 } 74 75 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, 76 SmallVectorImpl<Value> &dynamicVec, 77 SmallVectorImpl<int64_t> &staticVec) { 78 for (OpFoldResult ofr : ofrs) 79 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec); 80 } 81 82 /// Given a value, try to extract a constant Attribute. If this fails, return 83 /// the original value. 84 OpFoldResult getAsOpFoldResult(Value val) { 85 if (!val) 86 return OpFoldResult(); 87 Attribute attr; 88 if (matchPattern(val, m_Constant(&attr))) 89 return attr; 90 return val; 91 } 92 93 /// Given an array of values, try to extract a constant Attribute from each 94 /// value. If this fails, return the original value. 95 SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) { 96 return llvm::to_vector( 97 llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); })); 98 } 99 100 /// Convert `arrayAttr` to a vector of OpFoldResult. 101 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) { 102 SmallVector<OpFoldResult> res; 103 res.reserve(arrayAttr.size()); 104 for (Attribute a : arrayAttr) 105 res.push_back(a); 106 return res; 107 } 108 109 OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) { 110 return IntegerAttr::get(IndexType::get(ctx), val); 111 } 112 113 SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx, 114 ArrayRef<int64_t> values) { 115 return llvm::to_vector(llvm::map_range( 116 values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); })); 117 } 118 119 /// If ofr is a constant integer or an IntegerAttr, return the integer. 120 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { 121 // Case 1: Check for Constant integer. 122 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) { 123 APSInt intVal; 124 if (matchPattern(val, m_ConstantInt(&intVal))) 125 return intVal.getSExtValue(); 126 return std::nullopt; 127 } 128 // Case 2: Check for IntegerAttr. 129 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr); 130 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) 131 return intAttr.getValue().getSExtValue(); 132 return std::nullopt; 133 } 134 135 std::optional<SmallVector<int64_t>> 136 getConstantIntValues(ArrayRef<OpFoldResult> ofrs) { 137 bool failed = false; 138 SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) { 139 auto cv = getConstantIntValue(ofr); 140 if (!cv.has_value()) 141 failed = true; 142 return cv.value_or(0); 143 }); 144 if (failed) 145 return std::nullopt; 146 return res; 147 } 148 149 bool isConstantIntValue(OpFoldResult ofr, int64_t value) { 150 auto val = getConstantIntValue(ofr); 151 return val && *val == value; 152 } 153 154 bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) { 155 return llvm::all_of( 156 ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); }); 157 } 158 159 bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs, 160 ArrayRef<int64_t> values) { 161 if (ofrs.size() != values.size()) 162 return false; 163 std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs); 164 return constOfrs && llvm::equal(constOfrs.value(), values); 165 } 166 167 /// Return true if ofr1 and ofr2 are the same integer constant attribute values 168 /// or the same SSA value. 169 /// Ignore integer bitwidth and type mismatch that come from the fact there is 170 /// no IndexAttr and that IndexType has no bitwidth. 171 bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { 172 auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2); 173 if (cst1 && cst2 && *cst1 == *cst2) 174 return true; 175 auto v1 = llvm::dyn_cast_if_present<Value>(ofr1), 176 v2 = llvm::dyn_cast_if_present<Value>(ofr2); 177 return v1 && v1 == v2; 178 } 179 180 bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1, 181 ArrayRef<OpFoldResult> ofrs2) { 182 if (ofrs1.size() != ofrs2.size()) 183 return false; 184 for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2)) 185 if (!isEqualConstantIntOrValue(ofr1, ofr2)) 186 return false; 187 return true; 188 } 189 190 /// Return a vector of OpFoldResults with the same size a staticValues, but all 191 /// elements for which ShapedType::isDynamic is true, will be replaced by 192 /// dynamicValues. 193 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, 194 ValueRange dynamicValues, 195 MLIRContext *context) { 196 SmallVector<OpFoldResult> res; 197 res.reserve(staticValues.size()); 198 unsigned numDynamic = 0; 199 unsigned count = static_cast<unsigned>(staticValues.size()); 200 for (unsigned idx = 0; idx < count; ++idx) { 201 int64_t value = staticValues[idx]; 202 res.push_back(ShapedType::isDynamic(value) 203 ? OpFoldResult{dynamicValues[numDynamic++]} 204 : OpFoldResult{IntegerAttr::get( 205 IntegerType::get(context, 64), staticValues[idx])}); 206 } 207 return res; 208 } 209 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, 210 ValueRange dynamicValues, Builder &b) { 211 return getMixedValues(staticValues, dynamicValues, b.getContext()); 212 } 213 214 /// Decompose a vector of mixed static or dynamic values into the corresponding 215 /// pair of arrays. This is the inverse function of `getMixedValues`. 216 std::pair<SmallVector<int64_t>, SmallVector<Value>> 217 decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) { 218 SmallVector<int64_t> staticValues; 219 SmallVector<Value> dynamicValues; 220 for (const auto &it : mixedValues) { 221 if (auto attr = dyn_cast<Attribute>(it)) { 222 staticValues.push_back(cast<IntegerAttr>(attr).getInt()); 223 } else { 224 staticValues.push_back(ShapedType::kDynamic); 225 dynamicValues.push_back(cast<Value>(it)); 226 } 227 } 228 return {staticValues, dynamicValues}; 229 } 230 231 /// Helper to sort `values` according to matching `keys`. 232 template <typename K, typename V> 233 static SmallVector<V> 234 getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values, 235 llvm::function_ref<bool(K, K)> compare) { 236 if (keys.empty()) 237 return SmallVector<V>{values}; 238 assert(keys.size() == values.size() && "unexpected mismatching sizes"); 239 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size())); 240 std::sort(indices.begin(), indices.end(), 241 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); }); 242 SmallVector<V> res; 243 res.reserve(values.size()); 244 for (int64_t i = 0, e = indices.size(); i < e; ++i) 245 res.push_back(values[indices[i]]); 246 return res; 247 } 248 249 SmallVector<Value> 250 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values, 251 llvm::function_ref<bool(Attribute, Attribute)> compare) { 252 return getValuesSortedByKeyImpl(keys, values, compare); 253 } 254 255 SmallVector<OpFoldResult> 256 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values, 257 llvm::function_ref<bool(Attribute, Attribute)> compare) { 258 return getValuesSortedByKeyImpl(keys, values, compare); 259 } 260 261 SmallVector<int64_t> 262 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values, 263 llvm::function_ref<bool(Attribute, Attribute)> compare) { 264 return getValuesSortedByKeyImpl(keys, values, compare); 265 } 266 267 /// Return the number of iterations for a loop with a lower bound `lb`, upper 268 /// bound `ub` and step `step`. 269 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, 270 OpFoldResult step) { 271 if (lb == ub) 272 return 0; 273 274 std::optional<int64_t> lbConstant = getConstantIntValue(lb); 275 if (!lbConstant) 276 return std::nullopt; 277 std::optional<int64_t> ubConstant = getConstantIntValue(ub); 278 if (!ubConstant) 279 return std::nullopt; 280 std::optional<int64_t> stepConstant = getConstantIntValue(step); 281 if (!stepConstant) 282 return std::nullopt; 283 284 return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant); 285 } 286 287 bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) { 288 return llvm::none_of(sizesOrOffsets, [](int64_t value) { 289 return !ShapedType::isDynamic(value) && value < 0; 290 }); 291 } 292 293 bool hasValidStrides(SmallVector<int64_t> strides) { 294 return llvm::none_of(strides, [](int64_t value) { 295 return !ShapedType::isDynamic(value) && value == 0; 296 }); 297 } 298 299 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, 300 bool onlyNonNegative, bool onlyNonZero) { 301 bool valuesChanged = false; 302 for (OpFoldResult &ofr : ofrs) { 303 if (isa<Attribute>(ofr)) 304 continue; 305 Attribute attr; 306 if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) { 307 // Note: All ofrs have index type. 308 if (onlyNonNegative && *getConstantIntValue(attr) < 0) 309 continue; 310 if (onlyNonZero && *getConstantIntValue(attr) == 0) 311 continue; 312 ofr = attr; 313 valuesChanged = true; 314 } 315 } 316 return success(valuesChanged); 317 } 318 319 LogicalResult 320 foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) { 321 return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true, 322 /*onlyNonZero=*/false); 323 } 324 325 LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) { 326 return foldDynamicIndexList(strides, /*onlyNonNegative=*/false, 327 /*onlyNonZero=*/true); 328 } 329 330 } // namespace mlir 331