xref: /llvm-project/mlir/lib/Interfaces/ViewLikeInterface.cpp (revision ef4800c9168ee45ced8295d13ac68f58b4358759)
12458cd27SLei Zhang //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
22458cd27SLei Zhang //
32458cd27SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42458cd27SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
52458cd27SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62458cd27SLei Zhang //
72458cd27SLei Zhang //===----------------------------------------------------------------------===//
82458cd27SLei Zhang 
92458cd27SLei Zhang #include "mlir/Interfaces/ViewLikeInterface.h"
102458cd27SLei Zhang 
112458cd27SLei Zhang using namespace mlir;
122458cd27SLei Zhang 
132458cd27SLei Zhang //===----------------------------------------------------------------------===//
142458cd27SLei Zhang // ViewLike Interfaces
152458cd27SLei Zhang //===----------------------------------------------------------------------===//
162458cd27SLei Zhang 
/// Include the definitions of the view-like interfaces.
182458cd27SLei Zhang #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19a8de412fSNicolas Vasilache 
2065b72a78SAlexander Belyaev LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
2165b72a78SAlexander Belyaev                                                    StringRef name,
2265b72a78SAlexander Belyaev                                                    unsigned numElements,
23a9733b8aSLorenzo Chelini                                                    ArrayRef<int64_t> staticVals,
2465b72a78SAlexander Belyaev                                                    ValueRange values) {
25a9733b8aSLorenzo Chelini   // Check static and dynamic offsets/sizes/strides does not overflow type.
26a9733b8aSLorenzo Chelini   if (staticVals.size() != numElements)
27ee308c99SJacques Pienaar     return op->emitError("expected ") << numElements << " " << name
28ee308c99SJacques Pienaar                                       << " values, got " << staticVals.size();
29a8de412fSNicolas Vasilache   unsigned expectedNumDynamicEntries =
30e0b19e95SMehdi Amini       llvm::count_if(staticVals, [](int64_t staticVal) {
31a9733b8aSLorenzo Chelini         return ShapedType::isDynamic(staticVal);
32a8de412fSNicolas Vasilache       });
33a8de412fSNicolas Vasilache   if (values.size() != expectedNumDynamicEntries)
34118a7156SMaheshRavishankar     return op->emitError("expected ")
35a8de412fSNicolas Vasilache            << expectedNumDynamicEntries << " dynamic " << name << " values";
36a8de412fSNicolas Vasilache   return success();
37a8de412fSNicolas Vasilache }
38a8de412fSNicolas Vasilache 
3962851ea7SUday Bondhugula LogicalResult
4062851ea7SUday Bondhugula mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
415133673dSNicolas Vasilache   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
425133673dSNicolas Vasilache   // Offsets can come in 2 flavors:
435133673dSNicolas Vasilache   //   1. Either single entry (when maxRanks == 1).
445133673dSNicolas Vasilache   //   2. Or as an array whose rank must match that of the mixed sizes.
455133673dSNicolas Vasilache   // So that the result type is well-formed.
4684124ff8SMehdi Amini   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
475133673dSNicolas Vasilache       op.getMixedOffsets().size() != op.getMixedSizes().size())
485133673dSNicolas Vasilache     return op->emitError(
495133673dSNicolas Vasilache                "expected mixed offsets rank to match mixed sizes rank (")
505133673dSNicolas Vasilache            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
515133673dSNicolas Vasilache            << ") so the rank of the result type is well-formed.";
525133673dSNicolas Vasilache   // Ranks of mixed sizes and strides must always match so the result type is
535133673dSNicolas Vasilache   // well-formed.
545133673dSNicolas Vasilache   if (op.getMixedSizes().size() != op.getMixedStrides().size())
555133673dSNicolas Vasilache     return op->emitError(
565133673dSNicolas Vasilache                "expected mixed sizes rank to match mixed strides rank (")
575133673dSNicolas Vasilache            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
585133673dSNicolas Vasilache            << ") so the rank of the result type is well-formed.";
595133673dSNicolas Vasilache 
60a1f04b70SMatthias Springer   if (failed(verifyListOfOperandsOrIntegers(
61a1f04b70SMatthias Springer           op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
62a8de412fSNicolas Vasilache     return failure();
63a1f04b70SMatthias Springer   if (failed(verifyListOfOperandsOrIntegers(
64a1f04b70SMatthias Springer           op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
65a8de412fSNicolas Vasilache     return failure();
66a1f04b70SMatthias Springer   if (failed(verifyListOfOperandsOrIntegers(
67a1f04b70SMatthias Springer           op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
68a8de412fSNicolas Vasilache     return failure();
691949fe90SRik Huijzer 
701949fe90SRik Huijzer   for (int64_t offset : op.getStaticOffsets()) {
711949fe90SRik Huijzer     if (offset < 0 && !ShapedType::isDynamic(offset))
721949fe90SRik Huijzer       return op->emitError("expected offsets to be non-negative, but got ")
731949fe90SRik Huijzer              << offset;
741949fe90SRik Huijzer   }
751949fe90SRik Huijzer   for (int64_t size : op.getStaticSizes()) {
761949fe90SRik Huijzer     if (size < 0 && !ShapedType::isDynamic(size))
771949fe90SRik Huijzer       return op->emitError("expected sizes to be non-negative, but got ")
781949fe90SRik Huijzer              << size;
791949fe90SRik Huijzer   }
80a8de412fSNicolas Vasilache   return success();
81a8de412fSNicolas Vasilache }
82b6c71c13SNicolas Vasilache 
83310deca2SAlexander Belyaev static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
84310deca2SAlexander Belyaev   switch (delimiter) {
85310deca2SAlexander Belyaev   case AsmParser::Delimiter::Paren:
86310deca2SAlexander Belyaev     return '(';
87310deca2SAlexander Belyaev   case AsmParser::Delimiter::LessGreater:
88310deca2SAlexander Belyaev     return '<';
89310deca2SAlexander Belyaev   case AsmParser::Delimiter::Square:
90310deca2SAlexander Belyaev     return '[';
91310deca2SAlexander Belyaev   case AsmParser::Delimiter::Braces:
92310deca2SAlexander Belyaev     return '{';
93310deca2SAlexander Belyaev   default:
94310deca2SAlexander Belyaev     llvm_unreachable("unsupported delimiter");
95310deca2SAlexander Belyaev   }
96310deca2SAlexander Belyaev }
97310deca2SAlexander Belyaev 
98310deca2SAlexander Belyaev static char getRightDelimiter(AsmParser::Delimiter delimiter) {
99310deca2SAlexander Belyaev   switch (delimiter) {
100310deca2SAlexander Belyaev   case AsmParser::Delimiter::Paren:
101310deca2SAlexander Belyaev     return ')';
102310deca2SAlexander Belyaev   case AsmParser::Delimiter::LessGreater:
103310deca2SAlexander Belyaev     return '>';
104310deca2SAlexander Belyaev   case AsmParser::Delimiter::Square:
105310deca2SAlexander Belyaev     return ']';
106310deca2SAlexander Belyaev   case AsmParser::Delimiter::Braces:
107310deca2SAlexander Belyaev     return '}';
108310deca2SAlexander Belyaev   default:
109310deca2SAlexander Belyaev     llvm_unreachable("unsupported delimiter");
110310deca2SAlexander Belyaev   }
111310deca2SAlexander Belyaev }
112310deca2SAlexander Belyaev 
113a2ad3ec7SJeff Niu void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
114a9733b8aSLorenzo Chelini                                  OperandRange values,
115310deca2SAlexander Belyaev                                  ArrayRef<int64_t> integers,
116*ef4800c9SDiego Caballero                                  ArrayRef<bool> scalableFlags,
117*ef4800c9SDiego Caballero                                  TypeRange valueTypes,
1187a52f791SAndrzej Warzynski                                  AsmParser::Delimiter delimiter) {
119310deca2SAlexander Belyaev   char leftDelimiter = getLeftDelimiter(delimiter);
120310deca2SAlexander Belyaev   char rightDelimiter = getRightDelimiter(delimiter);
121310deca2SAlexander Belyaev   printer << leftDelimiter;
122a2ad3ec7SJeff Niu   if (integers.empty()) {
123310deca2SAlexander Belyaev     printer << rightDelimiter;
124342d4662SMaheshRavishankar     return;
125342d4662SMaheshRavishankar   }
126726835cdSAndrzej Warzynski 
127ad7ef192SAndrzej Warzynski   unsigned dynamicValIdx = 0;
128ad7ef192SAndrzej Warzynski   unsigned scalableIndexIdx = 0;
129a9733b8aSLorenzo Chelini   llvm::interleaveComma(integers, printer, [&](int64_t integer) {
130*ef4800c9SDiego Caballero     if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
131ad7ef192SAndrzej Warzynski       printer << "[";
1322fe4d90cSAlex Zinenko     if (ShapedType::isDynamic(integer)) {
133ad7ef192SAndrzej Warzynski       printer << values[dynamicValIdx];
1342fe4d90cSAlex Zinenko       if (!valueTypes.empty())
135ad7ef192SAndrzej Warzynski         printer << " : " << valueTypes[dynamicValIdx];
136ad7ef192SAndrzej Warzynski       ++dynamicValIdx;
1372fe4d90cSAlex Zinenko     } else {
138a9733b8aSLorenzo Chelini       printer << integer;
1392fe4d90cSAlex Zinenko     }
140*ef4800c9SDiego Caballero     if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
14114d073b5SAlexander Belyaev       printer << "]";
142ad7ef192SAndrzej Warzynski 
143ad7ef192SAndrzej Warzynski     scalableIndexIdx++;
144ad7ef192SAndrzej Warzynski   });
14514d073b5SAlexander Belyaev 
146310deca2SAlexander Belyaev   printer << rightDelimiter;
147c2470810SNicolas Vasilache }
148c2470810SNicolas Vasilache 
149a2ad3ec7SJeff Niu ParseResult mlir::parseDynamicIndexList(
150e13d23bcSMarkus Böck     OpAsmParser &parser,
151e13d23bcSMarkus Böck     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
152*ef4800c9SDiego Caballero     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
153a5b3677dSAndrzej Warzynski     SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
154b6c71c13SNicolas Vasilache 
155a9733b8aSLorenzo Chelini   SmallVector<int64_t, 4> integerVals;
156ad7ef192SAndrzej Warzynski   SmallVector<bool, 4> scalableVals;
157baca3b38SLorenzo Chelini   auto parseIntegerOrValue = [&]() {
158e13d23bcSMarkus Böck     OpAsmParser::UnresolvedOperand operand;
159b6c71c13SNicolas Vasilache     auto res = parser.parseOptionalOperand(operand);
160a5b3677dSAndrzej Warzynski 
161ad7ef192SAndrzej Warzynski     // When encountering `[`, assume that this is a scalable index.
162ad7ef192SAndrzej Warzynski     scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
163a5b3677dSAndrzej Warzynski 
164c8e6ebd7SKazu Hirata     if (res.has_value() && succeeded(res.value())) {
165342d4662SMaheshRavishankar       values.push_back(operand);
166a9733b8aSLorenzo Chelini       integerVals.push_back(ShapedType::kDynamic);
1672fe4d90cSAlex Zinenko       if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
1682fe4d90cSAlex Zinenko         return failure();
169b6c71c13SNicolas Vasilache     } else {
170a9733b8aSLorenzo Chelini       int64_t integer;
171a9733b8aSLorenzo Chelini       if (failed(parser.parseInteger(integer)))
172baca3b38SLorenzo Chelini         return failure();
173a9733b8aSLorenzo Chelini       integerVals.push_back(integer);
174b6c71c13SNicolas Vasilache     }
175ad7ef192SAndrzej Warzynski 
176ad7ef192SAndrzej Warzynski     // If this is assumed to be a scalable index, verify that there's a closing
177ad7ef192SAndrzej Warzynski     // `]`.
178ad7ef192SAndrzej Warzynski     if (scalableVals.back() && parser.parseOptionalRSquare().failed())
179a5b3677dSAndrzej Warzynski       return failure();
180baca3b38SLorenzo Chelini     return success();
181baca3b38SLorenzo Chelini   };
182310deca2SAlexander Belyaev   if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
183baca3b38SLorenzo Chelini                                      " in dynamic index list"))
184baca3b38SLorenzo Chelini     return parser.emitError(parser.getNameLoc())
185baca3b38SLorenzo Chelini            << "expected SSA value or integer";
186a9733b8aSLorenzo Chelini   integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
187*ef4800c9SDiego Caballero   scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
188b6c71c13SNicolas Vasilache   return success();
189b6c71c13SNicolas Vasilache }
190b6c71c13SNicolas Vasilache 
191ce4f99e7SNicolas Vasilache bool mlir::detail::sameOffsetsSizesAndStrides(
192ce4f99e7SNicolas Vasilache     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
193ce4f99e7SNicolas Vasilache     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
194a1f04b70SMatthias Springer   if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
195ce4f99e7SNicolas Vasilache     return false;
196a1f04b70SMatthias Springer   if (a.getStaticSizes().size() != b.getStaticSizes().size())
197ce4f99e7SNicolas Vasilache     return false;
198a1f04b70SMatthias Springer   if (a.getStaticStrides().size() != b.getStaticStrides().size())
199ce4f99e7SNicolas Vasilache     return false;
200ce4f99e7SNicolas Vasilache   for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
201ce4f99e7SNicolas Vasilache     if (!cmp(std::get<0>(it), std::get<1>(it)))
202ce4f99e7SNicolas Vasilache       return false;
203ce4f99e7SNicolas Vasilache   for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
204ce4f99e7SNicolas Vasilache     if (!cmp(std::get<0>(it), std::get<1>(it)))
205ce4f99e7SNicolas Vasilache       return false;
206ce4f99e7SNicolas Vasilache   for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
207ce4f99e7SNicolas Vasilache     if (!cmp(std::get<0>(it), std::get<1>(it)))
208ce4f99e7SNicolas Vasilache       return false;
209ce4f99e7SNicolas Vasilache   return true;
210ce4f99e7SNicolas Vasilache }
211b2d1de2dSMatthias Springer 
212b2d1de2dSMatthias Springer unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
213b2d1de2dSMatthias Springer                                                    unsigned idx) {
214b2d1de2dSMatthias Springer   return std::count_if(staticVals.begin(), staticVals.begin() + idx,
215b2d1de2dSMatthias Springer                        [&](int64_t val) { return ShapedType::isDynamic(val); });
216b2d1de2dSMatthias Springer }
217