12458cd27SLei Zhang //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 22458cd27SLei Zhang // 32458cd27SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42458cd27SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 52458cd27SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62458cd27SLei Zhang // 72458cd27SLei Zhang //===----------------------------------------------------------------------===// 82458cd27SLei Zhang 92458cd27SLei Zhang #include "mlir/Interfaces/ViewLikeInterface.h" 102458cd27SLei Zhang 112458cd27SLei Zhang using namespace mlir; 122458cd27SLei Zhang 132458cd27SLei Zhang //===----------------------------------------------------------------------===// 142458cd27SLei Zhang // ViewLike Interfaces 152458cd27SLei Zhang //===----------------------------------------------------------------------===// 162458cd27SLei Zhang 172458cd27SLei Zhang /// Include the definitions of the loop-like interfaces. 182458cd27SLei Zhang #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19a8de412fSNicolas Vasilache 2065b72a78SAlexander Belyaev LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, 2165b72a78SAlexander Belyaev StringRef name, 2265b72a78SAlexander Belyaev unsigned numElements, 23a9733b8aSLorenzo Chelini ArrayRef<int64_t> staticVals, 2465b72a78SAlexander Belyaev ValueRange values) { 25a9733b8aSLorenzo Chelini // Check static and dynamic offsets/sizes/strides does not overflow type. 26a9733b8aSLorenzo Chelini if (staticVals.size() != numElements) 27ee308c99SJacques Pienaar return op->emitError("expected ") << numElements << " " << name 28ee308c99SJacques Pienaar << " values, got " << staticVals.size(); 29a8de412fSNicolas Vasilache unsigned expectedNumDynamicEntries = 30e0b19e95SMehdi Amini llvm::count_if(staticVals, [](int64_t staticVal) { 31a9733b8aSLorenzo Chelini return ShapedType::isDynamic(staticVal); 32a8de412fSNicolas Vasilache }); 33a8de412fSNicolas Vasilache if (values.size() != expectedNumDynamicEntries) 34118a7156SMaheshRavishankar return op->emitError("expected ") 35a8de412fSNicolas Vasilache << expectedNumDynamicEntries << " dynamic " << name << " values"; 36a8de412fSNicolas Vasilache return success(); 37a8de412fSNicolas Vasilache } 38a8de412fSNicolas Vasilache 3962851ea7SUday Bondhugula LogicalResult 4062851ea7SUday Bondhugula mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { 415133673dSNicolas Vasilache std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); 425133673dSNicolas Vasilache // Offsets can come in 2 flavors: 435133673dSNicolas Vasilache // 1. Either single entry (when maxRanks == 1). 445133673dSNicolas Vasilache // 2. Or as an array whose rank must match that of the mixed sizes. 455133673dSNicolas Vasilache // So that the result type is well-formed. 4684124ff8SMehdi Amini if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT 475133673dSNicolas Vasilache op.getMixedOffsets().size() != op.getMixedSizes().size()) 485133673dSNicolas Vasilache return op->emitError( 495133673dSNicolas Vasilache "expected mixed offsets rank to match mixed sizes rank (") 505133673dSNicolas Vasilache << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() 515133673dSNicolas Vasilache << ") so the rank of the result type is well-formed."; 525133673dSNicolas Vasilache // Ranks of mixed sizes and strides must always match so the result type is 535133673dSNicolas Vasilache // well-formed. 545133673dSNicolas Vasilache if (op.getMixedSizes().size() != op.getMixedStrides().size()) 555133673dSNicolas Vasilache return op->emitError( 565133673dSNicolas Vasilache "expected mixed sizes rank to match mixed strides rank (") 575133673dSNicolas Vasilache << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() 585133673dSNicolas Vasilache << ") so the rank of the result type is well-formed."; 595133673dSNicolas Vasilache 60a1f04b70SMatthias Springer if (failed(verifyListOfOperandsOrIntegers( 61a1f04b70SMatthias Springer op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets()))) 62a8de412fSNicolas Vasilache return failure(); 63a1f04b70SMatthias Springer if (failed(verifyListOfOperandsOrIntegers( 64a1f04b70SMatthias Springer op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes()))) 65a8de412fSNicolas Vasilache return failure(); 66a1f04b70SMatthias Springer if (failed(verifyListOfOperandsOrIntegers( 67a1f04b70SMatthias Springer op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides()))) 68a8de412fSNicolas Vasilache return failure(); 691949fe90SRik Huijzer 701949fe90SRik Huijzer for (int64_t offset : op.getStaticOffsets()) { 711949fe90SRik Huijzer if (offset < 0 && !ShapedType::isDynamic(offset)) 721949fe90SRik Huijzer return op->emitError("expected offsets to be non-negative, but got ") 731949fe90SRik Huijzer << offset; 741949fe90SRik Huijzer } 751949fe90SRik Huijzer for (int64_t size : op.getStaticSizes()) { 761949fe90SRik Huijzer if (size < 0 && !ShapedType::isDynamic(size)) 771949fe90SRik Huijzer return op->emitError("expected sizes to be non-negative, but got ") 781949fe90SRik Huijzer << size; 791949fe90SRik Huijzer } 80a8de412fSNicolas Vasilache return success(); 81a8de412fSNicolas Vasilache } 82b6c71c13SNicolas Vasilache 83310deca2SAlexander Belyaev static char getLeftDelimiter(AsmParser::Delimiter delimiter) { 84310deca2SAlexander Belyaev switch (delimiter) { 85310deca2SAlexander Belyaev case AsmParser::Delimiter::Paren: 86310deca2SAlexander Belyaev return '('; 87310deca2SAlexander Belyaev case AsmParser::Delimiter::LessGreater: 88310deca2SAlexander Belyaev return '<'; 89310deca2SAlexander Belyaev case AsmParser::Delimiter::Square: 90310deca2SAlexander Belyaev return '['; 91310deca2SAlexander Belyaev case AsmParser::Delimiter::Braces: 92310deca2SAlexander Belyaev return '{'; 93310deca2SAlexander Belyaev default: 94310deca2SAlexander Belyaev llvm_unreachable("unsupported delimiter"); 95310deca2SAlexander Belyaev } 96310deca2SAlexander Belyaev } 97310deca2SAlexander Belyaev 98310deca2SAlexander Belyaev static char getRightDelimiter(AsmParser::Delimiter delimiter) { 99310deca2SAlexander Belyaev switch (delimiter) { 100310deca2SAlexander Belyaev case AsmParser::Delimiter::Paren: 101310deca2SAlexander Belyaev return ')'; 102310deca2SAlexander Belyaev case AsmParser::Delimiter::LessGreater: 103310deca2SAlexander Belyaev return '>'; 104310deca2SAlexander Belyaev case AsmParser::Delimiter::Square: 105310deca2SAlexander Belyaev return ']'; 106310deca2SAlexander Belyaev case AsmParser::Delimiter::Braces: 107310deca2SAlexander Belyaev return '}'; 108310deca2SAlexander Belyaev default: 109310deca2SAlexander Belyaev llvm_unreachable("unsupported delimiter"); 110310deca2SAlexander Belyaev } 111310deca2SAlexander Belyaev } 112310deca2SAlexander Belyaev 113a2ad3ec7SJeff Niu void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, 114a9733b8aSLorenzo Chelini OperandRange values, 115310deca2SAlexander Belyaev ArrayRef<int64_t> integers, 116*ef4800c9SDiego Caballero ArrayRef<bool> scalableFlags, 117*ef4800c9SDiego Caballero TypeRange valueTypes, 1187a52f791SAndrzej Warzynski AsmParser::Delimiter delimiter) { 119310deca2SAlexander Belyaev char leftDelimiter = getLeftDelimiter(delimiter); 120310deca2SAlexander Belyaev char rightDelimiter = getRightDelimiter(delimiter); 121310deca2SAlexander Belyaev printer << leftDelimiter; 122a2ad3ec7SJeff Niu if (integers.empty()) { 123310deca2SAlexander Belyaev printer << rightDelimiter; 124342d4662SMaheshRavishankar return; 125342d4662SMaheshRavishankar } 126726835cdSAndrzej Warzynski 127ad7ef192SAndrzej Warzynski unsigned dynamicValIdx = 0; 128ad7ef192SAndrzej Warzynski unsigned scalableIndexIdx = 0; 129a9733b8aSLorenzo Chelini llvm::interleaveComma(integers, printer, [&](int64_t integer) { 130*ef4800c9SDiego Caballero if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) 131ad7ef192SAndrzej Warzynski printer << "["; 1322fe4d90cSAlex Zinenko if (ShapedType::isDynamic(integer)) { 133ad7ef192SAndrzej Warzynski printer << values[dynamicValIdx]; 1342fe4d90cSAlex Zinenko if (!valueTypes.empty()) 135ad7ef192SAndrzej Warzynski printer << " : " << valueTypes[dynamicValIdx]; 136ad7ef192SAndrzej Warzynski ++dynamicValIdx; 1372fe4d90cSAlex Zinenko } else { 138a9733b8aSLorenzo Chelini printer << integer; 1392fe4d90cSAlex Zinenko } 140*ef4800c9SDiego Caballero if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) 14114d073b5SAlexander Belyaev printer << "]"; 142ad7ef192SAndrzej Warzynski 143ad7ef192SAndrzej Warzynski scalableIndexIdx++; 144ad7ef192SAndrzej Warzynski }); 14514d073b5SAlexander Belyaev 146310deca2SAlexander Belyaev printer << rightDelimiter; 147c2470810SNicolas Vasilache } 148c2470810SNicolas Vasilache 149a2ad3ec7SJeff Niu ParseResult mlir::parseDynamicIndexList( 150e13d23bcSMarkus Böck OpAsmParser &parser, 151e13d23bcSMarkus Böck SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, 152*ef4800c9SDiego Caballero DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, 153a5b3677dSAndrzej Warzynski SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) { 154b6c71c13SNicolas Vasilache 155a9733b8aSLorenzo Chelini SmallVector<int64_t, 4> integerVals; 156ad7ef192SAndrzej Warzynski SmallVector<bool, 4> scalableVals; 157baca3b38SLorenzo Chelini auto parseIntegerOrValue = [&]() { 158e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand operand; 159b6c71c13SNicolas Vasilache auto res = parser.parseOptionalOperand(operand); 160a5b3677dSAndrzej Warzynski 161ad7ef192SAndrzej Warzynski // When encountering `[`, assume that this is a scalable index. 162ad7ef192SAndrzej Warzynski scalableVals.push_back(parser.parseOptionalLSquare().succeeded()); 163a5b3677dSAndrzej Warzynski 164c8e6ebd7SKazu Hirata if (res.has_value() && succeeded(res.value())) { 165342d4662SMaheshRavishankar values.push_back(operand); 166a9733b8aSLorenzo Chelini integerVals.push_back(ShapedType::kDynamic); 1672fe4d90cSAlex Zinenko if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) 1682fe4d90cSAlex Zinenko return failure(); 169b6c71c13SNicolas Vasilache } else { 170a9733b8aSLorenzo Chelini int64_t integer; 171a9733b8aSLorenzo Chelini if (failed(parser.parseInteger(integer))) 172baca3b38SLorenzo Chelini return failure(); 173a9733b8aSLorenzo Chelini integerVals.push_back(integer); 174b6c71c13SNicolas Vasilache } 175ad7ef192SAndrzej Warzynski 176ad7ef192SAndrzej Warzynski // If this is assumed to be a scalable index, verify that there's a closing 177ad7ef192SAndrzej Warzynski // `]`. 178ad7ef192SAndrzej Warzynski if (scalableVals.back() && parser.parseOptionalRSquare().failed()) 179a5b3677dSAndrzej Warzynski return failure(); 180baca3b38SLorenzo Chelini return success(); 181baca3b38SLorenzo Chelini }; 182310deca2SAlexander Belyaev if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue, 183baca3b38SLorenzo Chelini " in dynamic index list")) 184baca3b38SLorenzo Chelini return parser.emitError(parser.getNameLoc()) 185baca3b38SLorenzo Chelini << "expected SSA value or integer"; 186a9733b8aSLorenzo Chelini integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); 187*ef4800c9SDiego Caballero scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); 188b6c71c13SNicolas Vasilache return success(); 189b6c71c13SNicolas Vasilache } 190b6c71c13SNicolas Vasilache 191ce4f99e7SNicolas Vasilache bool mlir::detail::sameOffsetsSizesAndStrides( 192ce4f99e7SNicolas Vasilache OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, 193ce4f99e7SNicolas Vasilache llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { 194a1f04b70SMatthias Springer if (a.getStaticOffsets().size() != b.getStaticOffsets().size()) 195ce4f99e7SNicolas Vasilache return false; 196a1f04b70SMatthias Springer if (a.getStaticSizes().size() != b.getStaticSizes().size()) 197ce4f99e7SNicolas Vasilache return false; 198a1f04b70SMatthias Springer if (a.getStaticStrides().size() != b.getStaticStrides().size()) 199ce4f99e7SNicolas Vasilache return false; 200ce4f99e7SNicolas Vasilache for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) 201ce4f99e7SNicolas Vasilache if (!cmp(std::get<0>(it), std::get<1>(it))) 202ce4f99e7SNicolas Vasilache return false; 203ce4f99e7SNicolas Vasilache for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) 204ce4f99e7SNicolas Vasilache if (!cmp(std::get<0>(it), std::get<1>(it))) 205ce4f99e7SNicolas Vasilache return false; 206ce4f99e7SNicolas Vasilache for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) 207ce4f99e7SNicolas Vasilache if (!cmp(std::get<0>(it), std::get<1>(it))) 208ce4f99e7SNicolas Vasilache return false; 209ce4f99e7SNicolas Vasilache return true; 210ce4f99e7SNicolas Vasilache } 211b2d1de2dSMatthias Springer 212b2d1de2dSMatthias Springer unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals, 213b2d1de2dSMatthias Springer unsigned idx) { 214b2d1de2dSMatthias Springer return std::count_if(staticVals.begin(), staticVals.begin() + idx, 215b2d1de2dSMatthias Springer [&](int64_t val) { return ShapedType::isDynamic(val); }); 216b2d1de2dSMatthias Springer } 217