1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 LogicalResult mlir::verifyListOfOperandsOrIntegers( 21 Operation *op, StringRef name, unsigned numElements, ArrayAttr attr, 22 ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) { 23 /// Check static and dynamic offsets/sizes/strides does not overflow type. 24 if (attr.size() != numElements) 25 return op->emitError("expected ") 26 << numElements << " " << name << " values"; 27 unsigned expectedNumDynamicEntries = 28 llvm::count_if(attr.getValue(), [&](Attribute attr) { 29 return isDynamic(attr.cast<IntegerAttr>().getInt()); 30 }); 31 if (values.size() != expectedNumDynamicEntries) 32 return op->emitError("expected ") 33 << expectedNumDynamicEntries << " dynamic " << name << " values"; 34 return success(); 35 } 36 37 LogicalResult 38 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { 39 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); 40 // Offsets can come in 2 flavors: 41 // 1. Either single entry (when maxRanks == 1). 42 // 2. Or as an array whose rank must match that of the mixed sizes. 43 // So that the result type is well-formed. 44 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT 45 op.getMixedOffsets().size() != op.getMixedSizes().size()) 46 return op->emitError( 47 "expected mixed offsets rank to match mixed sizes rank (") 48 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() 49 << ") so the rank of the result type is well-formed."; 50 // Ranks of mixed sizes and strides must always match so the result type is 51 // well-formed. 52 if (op.getMixedSizes().size() != op.getMixedStrides().size()) 53 return op->emitError( 54 "expected mixed sizes rank to match mixed strides rank (") 55 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() 56 << ") so the rank of the result type is well-formed."; 57 58 if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0], 59 op.static_offsets(), op.offsets(), 60 ShapedType::isDynamic))) 61 return failure(); 62 if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1], 63 op.static_sizes(), op.sizes(), 64 ShapedType::isDynamic))) 65 return failure(); 66 if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2], 67 op.static_strides(), op.strides(), 68 ShapedType::isDynamic))) 69 return failure(); 70 return success(); 71 } 72 73 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, 74 OperandRange values, ArrayAttr integers, 75 int64_t dynVal) { 76 printer << '['; 77 if (integers.empty()) { 78 printer << "]"; 79 return; 80 } 81 unsigned idx = 0; 82 llvm::interleaveComma(integers, printer, [&](Attribute a) { 83 int64_t val = a.cast<IntegerAttr>().getInt(); 84 if (val == dynVal) 85 printer << values[idx++]; 86 else 87 printer << val; 88 }); 89 printer << ']'; 90 } 91 92 ParseResult mlir::parseDynamicIndexList( 93 OpAsmParser &parser, 94 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, 95 ArrayAttr &integers, int64_t dynVal) { 96 if (failed(parser.parseLSquare())) 97 return failure(); 98 // 0-D. 99 if (succeeded(parser.parseOptionalRSquare())) { 100 integers = parser.getBuilder().getArrayAttr({}); 101 return success(); 102 } 103 104 SmallVector<int64_t, 4> attrVals; 105 while (true) { 106 OpAsmParser::UnresolvedOperand operand; 107 auto res = parser.parseOptionalOperand(operand); 108 if (res.has_value() && succeeded(res.value())) { 109 values.push_back(operand); 110 attrVals.push_back(dynVal); 111 } else { 112 IntegerAttr attr; 113 if (failed(parser.parseAttribute<IntegerAttr>(attr))) 114 return parser.emitError(parser.getNameLoc()) 115 << "expected SSA value or integer"; 116 attrVals.push_back(attr.getInt()); 117 } 118 119 if (succeeded(parser.parseOptionalComma())) 120 continue; 121 if (failed(parser.parseRSquare())) 122 return failure(); 123 break; 124 } 125 integers = parser.getBuilder().getI64ArrayAttr(attrVals); 126 return success(); 127 } 128 129 bool mlir::detail::sameOffsetsSizesAndStrides( 130 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, 131 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { 132 if (a.static_offsets().size() != b.static_offsets().size()) 133 return false; 134 if (a.static_sizes().size() != b.static_sizes().size()) 135 return false; 136 if (a.static_strides().size() != b.static_strides().size()) 137 return false; 138 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) 139 if (!cmp(std::get<0>(it), std::get<1>(it))) 140 return false; 141 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) 142 if (!cmp(std::get<0>(it), std::get<1>(it))) 143 return false; 144 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) 145 if (!cmp(std::get<0>(it), std::get<1>(it))) 146 return false; 147 return true; 148 } 149 150 SmallVector<OpFoldResult, 4> 151 mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues, 152 const int64_t dynamicValueIndicator) { 153 SmallVector<OpFoldResult, 4> res; 154 res.reserve(staticValues.size()); 155 unsigned numDynamic = 0; 156 unsigned count = static_cast<unsigned>(staticValues.size()); 157 for (unsigned idx = 0; idx < count; ++idx) { 158 APInt value = staticValues[idx].cast<IntegerAttr>().getValue(); 159 res.push_back(value.getSExtValue() == dynamicValueIndicator 160 ? OpFoldResult{dynamicValues[numDynamic++]} 161 : OpFoldResult{staticValues[idx]}); 162 } 163 return res; 164 } 165 166 SmallVector<OpFoldResult, 4> 167 mlir::getMixedStridesOrOffsets(ArrayAttr staticValues, 168 ValueRange dynamicValues) { 169 return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamic); 170 } 171 172 SmallVector<OpFoldResult, 4> mlir::getMixedSizes(ArrayAttr staticValues, 173 ValueRange dynamicValues) { 174 return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamic); 175 } 176 177 std::pair<ArrayAttr, SmallVector<Value>> 178 mlir::decomposeMixedValues(Builder &b, 179 const SmallVectorImpl<OpFoldResult> &mixedValues, 180 const int64_t dynamicValueIndicator) { 181 SmallVector<int64_t> staticValues; 182 SmallVector<Value> dynamicValues; 183 for (const auto &it : mixedValues) { 184 if (it.is<Attribute>()) { 185 staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt()); 186 } else { 187 staticValues.push_back(dynamicValueIndicator); 188 dynamicValues.push_back(it.get<Value>()); 189 } 190 } 191 return {b.getI64ArrayAttr(staticValues), dynamicValues}; 192 } 193 194 std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets( 195 OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) { 196 return decomposeMixedValues(b, mixedValues, ShapedType::kDynamic); 197 } 198 199 std::pair<ArrayAttr, SmallVector<Value>> 200 mlir::decomposeMixedSizes(OpBuilder &b, 201 const SmallVectorImpl<OpFoldResult> &mixedValues) { 202 return decomposeMixedValues(b, mixedValues, ShapedType::kDynamic); 203 } 204