1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, 21 StringRef name, 22 unsigned numElements, 23 ArrayRef<int64_t> staticVals, 24 ValueRange values) { 25 // Check static and dynamic offsets/sizes/strides does not overflow type. 26 if (staticVals.size() != numElements) 27 return op->emitError("expected ") << numElements << " " << name 28 << " values, got " << staticVals.size(); 29 unsigned expectedNumDynamicEntries = 30 llvm::count_if(staticVals, [](int64_t staticVal) { 31 return ShapedType::isDynamic(staticVal); 32 }); 33 if (values.size() != expectedNumDynamicEntries) 34 return op->emitError("expected ") 35 << expectedNumDynamicEntries << " dynamic " << name << " values"; 36 return success(); 37 } 38 39 LogicalResult 40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { 41 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); 42 // Offsets can come in 2 flavors: 43 // 1. Either single entry (when maxRanks == 1). 44 // 2. Or as an array whose rank must match that of the mixed sizes. 45 // So that the result type is well-formed. 46 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT 47 op.getMixedOffsets().size() != op.getMixedSizes().size()) 48 return op->emitError( 49 "expected mixed offsets rank to match mixed sizes rank (") 50 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() 51 << ") so the rank of the result type is well-formed."; 52 // Ranks of mixed sizes and strides must always match so the result type is 53 // well-formed. 54 if (op.getMixedSizes().size() != op.getMixedStrides().size()) 55 return op->emitError( 56 "expected mixed sizes rank to match mixed strides rank (") 57 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() 58 << ") so the rank of the result type is well-formed."; 59 60 if (failed(verifyListOfOperandsOrIntegers( 61 op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets()))) 62 return failure(); 63 if (failed(verifyListOfOperandsOrIntegers( 64 op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes()))) 65 return failure(); 66 if (failed(verifyListOfOperandsOrIntegers( 67 op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides()))) 68 return failure(); 69 70 for (int64_t offset : op.getStaticOffsets()) { 71 if (offset < 0 && !ShapedType::isDynamic(offset)) 72 return op->emitError("expected offsets to be non-negative, but got ") 73 << offset; 74 } 75 for (int64_t size : op.getStaticSizes()) { 76 if (size < 0 && !ShapedType::isDynamic(size)) 77 return op->emitError("expected sizes to be non-negative, but got ") 78 << size; 79 } 80 return success(); 81 } 82 83 static char getLeftDelimiter(AsmParser::Delimiter delimiter) { 84 switch (delimiter) { 85 case AsmParser::Delimiter::Paren: 86 return '('; 87 case AsmParser::Delimiter::LessGreater: 88 return '<'; 89 case AsmParser::Delimiter::Square: 90 return '['; 91 case AsmParser::Delimiter::Braces: 92 return '{'; 93 default: 94 llvm_unreachable("unsupported delimiter"); 95 } 96 } 97 98 static char getRightDelimiter(AsmParser::Delimiter delimiter) { 99 switch (delimiter) { 100 case AsmParser::Delimiter::Paren: 101 return ')'; 102 case AsmParser::Delimiter::LessGreater: 103 return '>'; 104 case AsmParser::Delimiter::Square: 105 return ']'; 106 case AsmParser::Delimiter::Braces: 107 return '}'; 108 default: 109 llvm_unreachable("unsupported delimiter"); 110 } 111 } 112 113 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, 114 OperandRange values, 115 ArrayRef<int64_t> integers, 116 ArrayRef<bool> scalableFlags, 117 TypeRange valueTypes, 118 AsmParser::Delimiter delimiter) { 119 char leftDelimiter = getLeftDelimiter(delimiter); 120 char rightDelimiter = getRightDelimiter(delimiter); 121 printer << leftDelimiter; 122 if (integers.empty()) { 123 printer << rightDelimiter; 124 return; 125 } 126 127 unsigned dynamicValIdx = 0; 128 unsigned scalableIndexIdx = 0; 129 llvm::interleaveComma(integers, printer, [&](int64_t integer) { 130 if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) 131 printer << "["; 132 if (ShapedType::isDynamic(integer)) { 133 printer << values[dynamicValIdx]; 134 if (!valueTypes.empty()) 135 printer << " : " << valueTypes[dynamicValIdx]; 136 ++dynamicValIdx; 137 } else { 138 printer << integer; 139 } 140 if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) 141 printer << "]"; 142 143 scalableIndexIdx++; 144 }); 145 146 printer << rightDelimiter; 147 } 148 149 ParseResult mlir::parseDynamicIndexList( 150 OpAsmParser &parser, 151 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, 152 DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, 153 SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) { 154 155 SmallVector<int64_t, 4> integerVals; 156 SmallVector<bool, 4> scalableVals; 157 auto parseIntegerOrValue = [&]() { 158 OpAsmParser::UnresolvedOperand operand; 159 auto res = parser.parseOptionalOperand(operand); 160 161 // When encountering `[`, assume that this is a scalable index. 162 scalableVals.push_back(parser.parseOptionalLSquare().succeeded()); 163 164 if (res.has_value() && succeeded(res.value())) { 165 values.push_back(operand); 166 integerVals.push_back(ShapedType::kDynamic); 167 if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) 168 return failure(); 169 } else { 170 int64_t integer; 171 if (failed(parser.parseInteger(integer))) 172 return failure(); 173 integerVals.push_back(integer); 174 } 175 176 // If this is assumed to be a scalable index, verify that there's a closing 177 // `]`. 178 if (scalableVals.back() && parser.parseOptionalRSquare().failed()) 179 return failure(); 180 return success(); 181 }; 182 if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue, 183 " in dynamic index list")) 184 return parser.emitError(parser.getNameLoc()) 185 << "expected SSA value or integer"; 186 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); 187 scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); 188 return success(); 189 } 190 191 bool mlir::detail::sameOffsetsSizesAndStrides( 192 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, 193 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { 194 if (a.getStaticOffsets().size() != b.getStaticOffsets().size()) 195 return false; 196 if (a.getStaticSizes().size() != b.getStaticSizes().size()) 197 return false; 198 if (a.getStaticStrides().size() != b.getStaticStrides().size()) 199 return false; 200 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) 201 if (!cmp(std::get<0>(it), std::get<1>(it))) 202 return false; 203 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) 204 if (!cmp(std::get<0>(it), std::get<1>(it))) 205 return false; 206 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) 207 if (!cmp(std::get<0>(it), std::get<1>(it))) 208 return false; 209 return true; 210 } 211 212 unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals, 213 unsigned idx) { 214 return std::count_if(staticVals.begin(), staticVals.begin() + idx, 215 [&](int64_t val) { return ShapedType::isDynamic(val); }); 216 } 217