1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, 21 StringRef name, 22 unsigned numElements, 23 ArrayRef<int64_t> staticVals, 24 ValueRange values) { 25 // Check static and dynamic offsets/sizes/strides does not overflow type. 26 if (staticVals.size() != numElements) 27 return op->emitError("expected ") 28 << numElements << " " << name << " values"; 29 unsigned expectedNumDynamicEntries = 30 llvm::count_if(staticVals, [&](int64_t staticVal) { 31 return ShapedType::isDynamic(staticVal); 32 }); 33 if (values.size() != expectedNumDynamicEntries) 34 return op->emitError("expected ") 35 << expectedNumDynamicEntries << " dynamic " << name << " values"; 36 return success(); 37 } 38 39 LogicalResult 40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { 41 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); 42 // Offsets can come in 2 flavors: 43 // 1. Either single entry (when maxRanks == 1). 44 // 2. Or as an array whose rank must match that of the mixed sizes. 45 // So that the result type is well-formed. 46 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT 47 op.getMixedOffsets().size() != op.getMixedSizes().size()) 48 return op->emitError( 49 "expected mixed offsets rank to match mixed sizes rank (") 50 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() 51 << ") so the rank of the result type is well-formed."; 52 // Ranks of mixed sizes and strides must always match so the result type is 53 // well-formed. 54 if (op.getMixedSizes().size() != op.getMixedStrides().size()) 55 return op->emitError( 56 "expected mixed sizes rank to match mixed strides rank (") 57 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() 58 << ") so the rank of the result type is well-formed."; 59 60 if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0], 61 op.static_offsets(), op.offsets()))) 62 return failure(); 63 if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1], 64 op.static_sizes(), op.sizes()))) 65 return failure(); 66 if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2], 67 op.static_strides(), op.strides()))) 68 return failure(); 69 return success(); 70 } 71 72 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, 73 OperandRange values, 74 ArrayRef<int64_t> integers) { 75 printer << '['; 76 if (integers.empty()) { 77 printer << "]"; 78 return; 79 } 80 unsigned idx = 0; 81 llvm::interleaveComma(integers, printer, [&](int64_t integer) { 82 if (ShapedType::isDynamic(integer)) 83 printer << values[idx++]; 84 else 85 printer << integer; 86 }); 87 printer << ']'; 88 } 89 90 ParseResult mlir::parseDynamicIndexList( 91 OpAsmParser &parser, 92 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, 93 DenseI64ArrayAttr &integers) { 94 95 SmallVector<int64_t, 4> integerVals; 96 auto parseIntegerOrValue = [&]() { 97 OpAsmParser::UnresolvedOperand operand; 98 auto res = parser.parseOptionalOperand(operand); 99 if (res.has_value() && succeeded(res.value())) { 100 values.push_back(operand); 101 integerVals.push_back(ShapedType::kDynamic); 102 } else { 103 int64_t integer; 104 if (failed(parser.parseInteger(integer))) 105 return failure(); 106 integerVals.push_back(integer); 107 } 108 return success(); 109 }; 110 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square, 111 parseIntegerOrValue, 112 " in dynamic index list")) 113 return parser.emitError(parser.getNameLoc()) 114 << "expected SSA value or integer"; 115 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); 116 return success(); 117 } 118 119 bool mlir::detail::sameOffsetsSizesAndStrides( 120 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, 121 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { 122 if (a.static_offsets().size() != b.static_offsets().size()) 123 return false; 124 if (a.static_sizes().size() != b.static_sizes().size()) 125 return false; 126 if (a.static_strides().size() != b.static_strides().size()) 127 return false; 128 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) 129 if (!cmp(std::get<0>(it), std::get<1>(it))) 130 return false; 131 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) 132 if (!cmp(std::get<0>(it), std::get<1>(it))) 133 return false; 134 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) 135 if (!cmp(std::get<0>(it), std::get<1>(it))) 136 return false; 137 return true; 138 } 139