1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, 21 StringRef name, 22 unsigned numElements, 23 ArrayRef<int64_t> staticVals, 24 ValueRange values) { 25 // Check static and dynamic offsets/sizes/strides does not overflow type. 26 if (staticVals.size() != numElements) 27 return op->emitError("expected ") << numElements << " " << name 28 << " values, got " << staticVals.size(); 29 unsigned expectedNumDynamicEntries = 30 llvm::count_if(staticVals, [&](int64_t staticVal) { 31 return ShapedType::isDynamic(staticVal); 32 }); 33 if (values.size() != expectedNumDynamicEntries) 34 return op->emitError("expected ") 35 << expectedNumDynamicEntries << " dynamic " << name << " values"; 36 return success(); 37 } 38 39 LogicalResult 40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { 41 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); 42 // Offsets can come in 2 flavors: 43 // 1. Either single entry (when maxRanks == 1). 44 // 2. Or as an array whose rank must match that of the mixed sizes. 45 // So that the result type is well-formed. 46 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT 47 op.getMixedOffsets().size() != op.getMixedSizes().size()) 48 return op->emitError( 49 "expected mixed offsets rank to match mixed sizes rank (") 50 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() 51 << ") so the rank of the result type is well-formed."; 52 // Ranks of mixed sizes and strides must always match so the result type is 53 // well-formed. 54 if (op.getMixedSizes().size() != op.getMixedStrides().size()) 55 return op->emitError( 56 "expected mixed sizes rank to match mixed strides rank (") 57 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() 58 << ") so the rank of the result type is well-formed."; 59 60 if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0], 61 op.static_offsets(), op.offsets()))) 62 return failure(); 63 if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1], 64 op.static_sizes(), op.sizes()))) 65 return failure(); 66 if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2], 67 op.static_strides(), op.strides()))) 68 return failure(); 69 return success(); 70 } 71 72 static char getLeftDelimiter(AsmParser::Delimiter delimiter) { 73 switch (delimiter) { 74 case AsmParser::Delimiter::Paren: 75 return '('; 76 case AsmParser::Delimiter::LessGreater: 77 return '<'; 78 case AsmParser::Delimiter::Square: 79 return '['; 80 case AsmParser::Delimiter::Braces: 81 return '{'; 82 default: 83 llvm_unreachable("unsupported delimiter"); 84 } 85 } 86 87 static char getRightDelimiter(AsmParser::Delimiter delimiter) { 88 switch (delimiter) { 89 case AsmParser::Delimiter::Paren: 90 return ')'; 91 case AsmParser::Delimiter::LessGreater: 92 return '>'; 93 case AsmParser::Delimiter::Square: 94 return ']'; 95 case AsmParser::Delimiter::Braces: 96 return '}'; 97 default: 98 llvm_unreachable("unsupported delimiter"); 99 } 100 } 101 102 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, 103 OperandRange values, 104 ArrayRef<int64_t> integers, 105 TypeRange valueTypes, ArrayRef<bool> scalables, 106 AsmParser::Delimiter delimiter) { 107 char leftDelimiter = getLeftDelimiter(delimiter); 108 char rightDelimiter = getRightDelimiter(delimiter); 109 printer << leftDelimiter; 110 if (integers.empty()) { 111 printer << rightDelimiter; 112 return; 113 } 114 115 unsigned dynamicValIdx = 0; 116 unsigned scalableIndexIdx = 0; 117 llvm::interleaveComma(integers, printer, [&](int64_t integer) { 118 if (!scalables.empty() && scalables[scalableIndexIdx]) 119 printer << "["; 120 if (ShapedType::isDynamic(integer)) { 121 printer << values[dynamicValIdx]; 122 if (!valueTypes.empty()) 123 printer << " : " << valueTypes[dynamicValIdx]; 124 ++dynamicValIdx; 125 } else { 126 printer << integer; 127 } 128 if (!scalables.empty() && scalables[scalableIndexIdx]) 129 printer << "]"; 130 131 scalableIndexIdx++; 132 }); 133 134 printer << rightDelimiter; 135 } 136 137 ParseResult mlir::parseDynamicIndexList( 138 OpAsmParser &parser, 139 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, 140 DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, 141 SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) { 142 143 SmallVector<int64_t, 4> integerVals; 144 SmallVector<bool, 4> scalableVals; 145 auto parseIntegerOrValue = [&]() { 146 OpAsmParser::UnresolvedOperand operand; 147 auto res = parser.parseOptionalOperand(operand); 148 149 // When encountering `[`, assume that this is a scalable index. 150 scalableVals.push_back(parser.parseOptionalLSquare().succeeded()); 151 152 if (res.has_value() && succeeded(res.value())) { 153 values.push_back(operand); 154 integerVals.push_back(ShapedType::kDynamic); 155 if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) 156 return failure(); 157 } else { 158 int64_t integer; 159 if (failed(parser.parseInteger(integer))) 160 return failure(); 161 integerVals.push_back(integer); 162 } 163 164 // If this is assumed to be a scalable index, verify that there's a closing 165 // `]`. 166 if (scalableVals.back() && parser.parseOptionalRSquare().failed()) 167 return failure(); 168 return success(); 169 }; 170 if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue, 171 " in dynamic index list")) 172 return parser.emitError(parser.getNameLoc()) 173 << "expected SSA value or integer"; 174 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); 175 scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); 176 return success(); 177 } 178 179 bool mlir::detail::sameOffsetsSizesAndStrides( 180 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, 181 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { 182 if (a.static_offsets().size() != b.static_offsets().size()) 183 return false; 184 if (a.static_sizes().size() != b.static_sizes().size()) 185 return false; 186 if (a.static_strides().size() != b.static_strides().size()) 187 return false; 188 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) 189 if (!cmp(std::get<0>(it), std::get<1>(it))) 190 return false; 191 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) 192 if (!cmp(std::get<0>(it), std::get<1>(it))) 193 return false; 194 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) 195 if (!cmp(std::get<0>(it), std::get<1>(it))) 196 return false; 197 return true; 198 } 199