1 //===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines utilities for dealing with static values, e.g., 10 // converting back and forth between Value and OpFoldResult. Such functionality 11 // is used in multiple dialects. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H 16 #define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H 17 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinAttributes.h" 20 #include "mlir/IR/OpDefinition.h" 21 #include "mlir/Support/LLVM.h" 22 #include "llvm/ADT/SmallVector.h" 23 #include "llvm/ADT/SmallVectorExtras.h" 24 25 namespace mlir { 26 27 /// Return true if `v` is an IntegerAttr with value `0` of a ConstantIndexOp 28 /// with attribute with value `0`. 29 bool isZeroIndex(OpFoldResult v); 30 31 /// Represents a range (offset, size, and stride) where each element of the 32 /// triple may be dynamic or static. 33 struct Range { 34 OpFoldResult offset; 35 OpFoldResult size; 36 OpFoldResult stride; 37 }; 38 39 /// Given an array of Range values, return a tuple of (offset vector, sizes 40 /// vector, and strides vector) formed by separating out the individual 41 /// elements of each range. 42 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, 43 SmallVector<OpFoldResult>> 44 getOffsetsSizesAndStrides(ArrayRef<Range> ranges); 45 46 /// Helper function to dispatch an OpFoldResult into `staticVec` if: 47 /// a) it is an IntegerAttr 48 /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. 49 /// In such dynamic cases, ShapedType::kDynamic is also pushed to 50 /// `staticVec`. This is useful to extract mixed static and dynamic entries 51 /// that come from an AttrSizedOperandSegments trait. 52 void dispatchIndexOpFoldResult(OpFoldResult ofr, 53 SmallVectorImpl<Value> &dynamicVec, 54 SmallVectorImpl<int64_t> &staticVec); 55 56 /// Helper function to dispatch multiple OpFoldResults according to the 57 /// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single 58 /// OpFoldResult. 59 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, 60 SmallVectorImpl<Value> &dynamicVec, 61 SmallVectorImpl<int64_t> &staticVec); 62 63 /// Given OpFoldResult representing dim size value (*), generates a pair of 64 /// sizes: 65 /// * 1st result, static value, contains an int64_t dim size that can be used 66 /// to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims), 67 /// * 2nd result, dynamic value, contains OpFoldResult encapsulating the 68 /// actual dim size (either original or updated input value). 69 /// For input sizes for which it is possible to extract a constant Attribute, 70 /// replaces the original size value with an integer attribute (unless it's 71 /// already a constant Attribute). The 1st return value also becomes the actual 72 /// integer size (as opposed ShapedType::kDynamic). 73 /// 74 /// (*) This hook is usually used when, given input sizes as OpFoldResult, 75 /// it's required to generate two vectors: 76 /// * sizes as int64_t to generate a shape, 77 /// * sizes as OpFoldResult for sizes-like attribute. 78 /// Please update this comment if you identify other use cases. 79 std::pair<int64_t, OpFoldResult> 80 getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b); 81 82 /// Extract integer values from the assumed ArrayAttr of IntegerAttr. 83 template <typename IntTy> 84 SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) { 85 return llvm::to_vector( 86 llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy { 87 return cast<IntegerAttr>(a).getInt(); 88 })); 89 } 90 91 /// Given a value, try to extract a constant Attribute. If this fails, return 92 /// the original value. 93 OpFoldResult getAsOpFoldResult(Value val); 94 /// Given an array of values, try to extract a constant Attribute from each 95 /// value. If this fails, return the original value. 96 SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values); 97 /// Convert `arrayAttr` to a vector of OpFoldResult. 98 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr); 99 100 /// Convert int64_t to integer attributes of index type and return them as 101 /// OpFoldResult. 102 OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val); 103 SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx, 104 ArrayRef<int64_t> values); 105 106 /// If ofr is a constant integer or an IntegerAttr, return the integer. 107 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr); 108 /// If all ofrs are constant integers or IntegerAttrs, return the integers. 109 std::optional<SmallVector<int64_t>> 110 getConstantIntValues(ArrayRef<OpFoldResult> ofrs); 111 112 /// Return true if `ofr` is constant integer equal to `value`. 113 bool isConstantIntValue(OpFoldResult ofr, int64_t value); 114 /// Return true if all of `ofrs` are constant integers equal to `value`. 115 bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value); 116 /// Return true if all of `ofrs` are constant integers equal to the 117 /// corresponding value in `values`. 118 bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs, 119 ArrayRef<int64_t> values); 120 121 /// Return true if ofr1 and ofr2 are the same integer constant attribute 122 /// values or the same SSA value. Ignore integer bitwitdh and type mismatch 123 /// that come from the fact there is no IndexAttr and that IndexType have no 124 /// bitwidth. 125 bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); 126 bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1, 127 ArrayRef<OpFoldResult> ofrs2); 128 129 // To convert an OpFoldResult to a Value of index type, see: 130 // mlir/include/mlir/Dialect/Arith/Utils/Utils.h 131 // TODO: find a better common landing place. 132 // 133 // Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 134 // OpFoldResult ofr); 135 136 // To convert an OpFoldResult to a Value of index type, see: 137 // mlir/include/mlir/Dialect/Arith/Utils/Utils.h 138 // TODO: find a better common landing place. 139 // 140 // SmallVector<Value> 141 // getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 142 // ArrayRef<OpFoldResult> valueOrAttrVec); 143 144 /// Return a vector of OpFoldResults with the same size a staticValues, but 145 /// all elements for which ShapedType::isDynamic is true, will be replaced by 146 /// dynamicValues. 147 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, 148 ValueRange dynamicValues, 149 MLIRContext *context); 150 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, 151 ValueRange dynamicValues, Builder &b); 152 153 /// Decompose a vector of mixed static or dynamic values into the 154 /// corresponding pair of arrays. This is the inverse function of 155 /// `getMixedValues`. 156 std::pair<SmallVector<int64_t>, SmallVector<Value>> 157 decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues); 158 159 /// Helper to sort `values` according to matching `keys`. 160 SmallVector<Value> 161 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values, 162 llvm::function_ref<bool(Attribute, Attribute)> compare); 163 SmallVector<OpFoldResult> 164 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values, 165 llvm::function_ref<bool(Attribute, Attribute)> compare); 166 SmallVector<int64_t> 167 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values, 168 llvm::function_ref<bool(Attribute, Attribute)> compare); 169 170 /// Helper function to check whether the passed in `sizes` or `offsets` are 171 /// valid. This can be used to re-check whether dimensions are still valid 172 /// after constant folding the dynamic dimensions. 173 bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets); 174 175 /// Helper function to check whether the passed in `strides` are valid. This 176 /// can be used to re-check whether dimensions are still valid after constant 177 /// folding the dynamic dimensions. 178 bool hasValidStrides(SmallVector<int64_t> strides); 179 180 /// Returns "success" when any of the elements in `ofrs` is a constant value. In 181 /// that case the value is replaced by an attribute. Returns "failure" when no 182 /// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only 183 /// non-negative and non-zero constant values are folded respectively. 184 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, 185 bool onlyNonNegative = false, 186 bool onlyNonZero = false); 187 188 /// Returns "success" when any of the elements in `offsetsOrSizes` is a 189 /// constant value. In that case the value is replaced by an attribute. Returns 190 /// "failure" when no folding happened. Invalid values are not folded to avoid 191 /// canonicalization crashes. 192 LogicalResult 193 foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes); 194 195 /// Returns "success" when any of the elements in `strides` is a constant 196 /// value. In that case the value is replaced by an attribute. Returns 197 /// "failure" when no folding happened. Invalid values are not folded to avoid 198 /// canonicalization crashes. 199 LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides); 200 201 /// Return the number of iterations for a loop with a lower bound `lb`, upper 202 /// bound `ub` and step `step`. 203 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, 204 OpFoldResult step); 205 206 /// Idiomatic saturated operations on values like offsets, sizes, and strides. 207 struct SaturatedInteger { 208 static SaturatedInteger wrap(int64_t v) { 209 return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0} 210 : SaturatedInteger{false, v}; 211 } 212 int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; } 213 FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) { 214 if (saturated && !other.saturated) 215 return other; 216 if (!saturated && !other.saturated && v != other.v) 217 return failure(); 218 return *this; 219 } 220 bool operator==(SaturatedInteger other) { 221 return (saturated && other.saturated) || 222 (!saturated && !other.saturated && v == other.v); 223 } 224 bool operator!=(SaturatedInteger other) { return !(*this == other); } 225 SaturatedInteger operator+(SaturatedInteger other) { 226 if (saturated || other.saturated) 227 return SaturatedInteger{true, 0}; 228 return SaturatedInteger{false, other.v + v}; 229 } 230 SaturatedInteger operator*(SaturatedInteger other) { 231 // Multiplication with 0 is always 0. 232 if (!other.saturated && other.v == 0) 233 return SaturatedInteger{false, 0}; 234 if (!saturated && v == 0) 235 return SaturatedInteger{false, 0}; 236 // Otherwise, if this or the other integer is dynamic, so is the result. 237 if (saturated || other.saturated) 238 return SaturatedInteger{true, 0}; 239 return SaturatedInteger{false, other.v * v}; 240 } 241 bool saturated = true; 242 int64_t v = 0; 243 }; 244 245 } // namespace mlir 246 247 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H 248