xref: /llvm-project/mlir/lib/Interfaces/ViewLikeInterface.cpp (revision ef4800c9168ee45ced8295d13ac68f58b4358759)
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
21                                                    StringRef name,
22                                                    unsigned numElements,
23                                                    ArrayRef<int64_t> staticVals,
24                                                    ValueRange values) {
25   // Check static and dynamic offsets/sizes/strides does not overflow type.
26   if (staticVals.size() != numElements)
27     return op->emitError("expected ") << numElements << " " << name
28                                       << " values, got " << staticVals.size();
29   unsigned expectedNumDynamicEntries =
30       llvm::count_if(staticVals, [](int64_t staticVal) {
31         return ShapedType::isDynamic(staticVal);
32       });
33   if (values.size() != expectedNumDynamicEntries)
34     return op->emitError("expected ")
35            << expectedNumDynamicEntries << " dynamic " << name << " values";
36   return success();
37 }
38 
39 LogicalResult
40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
41   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
42   // Offsets can come in 2 flavors:
43   //   1. Either single entry (when maxRanks == 1).
44   //   2. Or as an array whose rank must match that of the mixed sizes.
45   // So that the result type is well-formed.
46   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
47       op.getMixedOffsets().size() != op.getMixedSizes().size())
48     return op->emitError(
49                "expected mixed offsets rank to match mixed sizes rank (")
50            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
51            << ") so the rank of the result type is well-formed.";
52   // Ranks of mixed sizes and strides must always match so the result type is
53   // well-formed.
54   if (op.getMixedSizes().size() != op.getMixedStrides().size())
55     return op->emitError(
56                "expected mixed sizes rank to match mixed strides rank (")
57            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
58            << ") so the rank of the result type is well-formed.";
59 
60   if (failed(verifyListOfOperandsOrIntegers(
61           op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
62     return failure();
63   if (failed(verifyListOfOperandsOrIntegers(
64           op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
65     return failure();
66   if (failed(verifyListOfOperandsOrIntegers(
67           op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
68     return failure();
69 
70   for (int64_t offset : op.getStaticOffsets()) {
71     if (offset < 0 && !ShapedType::isDynamic(offset))
72       return op->emitError("expected offsets to be non-negative, but got ")
73              << offset;
74   }
75   for (int64_t size : op.getStaticSizes()) {
76     if (size < 0 && !ShapedType::isDynamic(size))
77       return op->emitError("expected sizes to be non-negative, but got ")
78              << size;
79   }
80   return success();
81 }
82 
83 static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
84   switch (delimiter) {
85   case AsmParser::Delimiter::Paren:
86     return '(';
87   case AsmParser::Delimiter::LessGreater:
88     return '<';
89   case AsmParser::Delimiter::Square:
90     return '[';
91   case AsmParser::Delimiter::Braces:
92     return '{';
93   default:
94     llvm_unreachable("unsupported delimiter");
95   }
96 }
97 
98 static char getRightDelimiter(AsmParser::Delimiter delimiter) {
99   switch (delimiter) {
100   case AsmParser::Delimiter::Paren:
101     return ')';
102   case AsmParser::Delimiter::LessGreater:
103     return '>';
104   case AsmParser::Delimiter::Square:
105     return ']';
106   case AsmParser::Delimiter::Braces:
107     return '}';
108   default:
109     llvm_unreachable("unsupported delimiter");
110   }
111 }
112 
113 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
114                                  OperandRange values,
115                                  ArrayRef<int64_t> integers,
116                                  ArrayRef<bool> scalableFlags,
117                                  TypeRange valueTypes,
118                                  AsmParser::Delimiter delimiter) {
119   char leftDelimiter = getLeftDelimiter(delimiter);
120   char rightDelimiter = getRightDelimiter(delimiter);
121   printer << leftDelimiter;
122   if (integers.empty()) {
123     printer << rightDelimiter;
124     return;
125   }
126 
127   unsigned dynamicValIdx = 0;
128   unsigned scalableIndexIdx = 0;
129   llvm::interleaveComma(integers, printer, [&](int64_t integer) {
130     if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
131       printer << "[";
132     if (ShapedType::isDynamic(integer)) {
133       printer << values[dynamicValIdx];
134       if (!valueTypes.empty())
135         printer << " : " << valueTypes[dynamicValIdx];
136       ++dynamicValIdx;
137     } else {
138       printer << integer;
139     }
140     if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
141       printer << "]";
142 
143     scalableIndexIdx++;
144   });
145 
146   printer << rightDelimiter;
147 }
148 
149 ParseResult mlir::parseDynamicIndexList(
150     OpAsmParser &parser,
151     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
152     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
153     SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
154 
155   SmallVector<int64_t, 4> integerVals;
156   SmallVector<bool, 4> scalableVals;
157   auto parseIntegerOrValue = [&]() {
158     OpAsmParser::UnresolvedOperand operand;
159     auto res = parser.parseOptionalOperand(operand);
160 
161     // When encountering `[`, assume that this is a scalable index.
162     scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
163 
164     if (res.has_value() && succeeded(res.value())) {
165       values.push_back(operand);
166       integerVals.push_back(ShapedType::kDynamic);
167       if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
168         return failure();
169     } else {
170       int64_t integer;
171       if (failed(parser.parseInteger(integer)))
172         return failure();
173       integerVals.push_back(integer);
174     }
175 
176     // If this is assumed to be a scalable index, verify that there's a closing
177     // `]`.
178     if (scalableVals.back() && parser.parseOptionalRSquare().failed())
179       return failure();
180     return success();
181   };
182   if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
183                                      " in dynamic index list"))
184     return parser.emitError(parser.getNameLoc())
185            << "expected SSA value or integer";
186   integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
187   scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
188   return success();
189 }
190 
191 bool mlir::detail::sameOffsetsSizesAndStrides(
192     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
193     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
194   if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
195     return false;
196   if (a.getStaticSizes().size() != b.getStaticSizes().size())
197     return false;
198   if (a.getStaticStrides().size() != b.getStaticStrides().size())
199     return false;
200   for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
201     if (!cmp(std::get<0>(it), std::get<1>(it)))
202       return false;
203   for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
204     if (!cmp(std::get<0>(it), std::get<1>(it)))
205       return false;
206   for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
207     if (!cmp(std::get<0>(it), std::get<1>(it)))
208       return false;
209   return true;
210 }
211 
212 unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
213                                                    unsigned idx) {
214   return std::count_if(staticVals.begin(), staticVals.begin() + idx,
215                        [&](int64_t val) { return ShapedType::isDynamic(val); });
216 }
217