xref: /llvm-project/mlir/lib/Interfaces/ViewLikeInterface.cpp (revision c73053c26e6e17d23e10857f86075417e7fa166e)
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
21                                                    StringRef name,
22                                                    unsigned numElements,
23                                                    ArrayRef<int64_t> staticVals,
24                                                    ValueRange values) {
25   // Check static and dynamic offsets/sizes/strides does not overflow type.
26   if (staticVals.size() != numElements)
27     return op->emitError("expected ") << numElements << " " << name
28                                       << " values, got " << staticVals.size();
29   unsigned expectedNumDynamicEntries =
30       llvm::count_if(staticVals, [&](int64_t staticVal) {
31         return ShapedType::isDynamic(staticVal);
32       });
33   if (values.size() != expectedNumDynamicEntries)
34     return op->emitError("expected ")
35            << expectedNumDynamicEntries << " dynamic " << name << " values";
36   return success();
37 }
38 
39 LogicalResult
40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
41   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
42   // Offsets can come in 2 flavors:
43   //   1. Either single entry (when maxRanks == 1).
44   //   2. Or as an array whose rank must match that of the mixed sizes.
45   // So that the result type is well-formed.
46   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
47       op.getMixedOffsets().size() != op.getMixedSizes().size())
48     return op->emitError(
49                "expected mixed offsets rank to match mixed sizes rank (")
50            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
51            << ") so the rank of the result type is well-formed.";
52   // Ranks of mixed sizes and strides must always match so the result type is
53   // well-formed.
54   if (op.getMixedSizes().size() != op.getMixedStrides().size())
55     return op->emitError(
56                "expected mixed sizes rank to match mixed strides rank (")
57            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
58            << ") so the rank of the result type is well-formed.";
59 
60   if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0],
61                                             op.static_offsets(), op.offsets())))
62     return failure();
63   if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
64                                             op.static_sizes(), op.sizes())))
65     return failure();
66   if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2],
67                                             op.static_strides(), op.strides())))
68     return failure();
69   return success();
70 }
71 
72 static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
73   switch (delimiter) {
74   case AsmParser::Delimiter::Paren:
75     return '(';
76   case AsmParser::Delimiter::LessGreater:
77     return '<';
78   case AsmParser::Delimiter::Square:
79     return '[';
80   case AsmParser::Delimiter::Braces:
81     return '{';
82   default:
83     llvm_unreachable("unsupported delimiter");
84   }
85 }
86 
87 static char getRightDelimiter(AsmParser::Delimiter delimiter) {
88   switch (delimiter) {
89   case AsmParser::Delimiter::Paren:
90     return ')';
91   case AsmParser::Delimiter::LessGreater:
92     return '>';
93   case AsmParser::Delimiter::Square:
94     return ']';
95   case AsmParser::Delimiter::Braces:
96     return '}';
97   default:
98     llvm_unreachable("unsupported delimiter");
99   }
100 }
101 
102 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
103                                  OperandRange values,
104                                  ArrayRef<int64_t> integers,
105                                  TypeRange valueTypes, ArrayRef<bool> scalables,
106                                  AsmParser::Delimiter delimiter) {
107   char leftDelimiter = getLeftDelimiter(delimiter);
108   char rightDelimiter = getRightDelimiter(delimiter);
109   printer << leftDelimiter;
110   if (integers.empty()) {
111     printer << rightDelimiter;
112     return;
113   }
114 
115   unsigned dynamicValIdx = 0;
116   unsigned scalableIndexIdx = 0;
117   llvm::interleaveComma(integers, printer, [&](int64_t integer) {
118     if (!scalables.empty() && scalables[scalableIndexIdx])
119       printer << "[";
120     if (ShapedType::isDynamic(integer)) {
121       printer << values[dynamicValIdx];
122       if (!valueTypes.empty())
123         printer << " : " << valueTypes[dynamicValIdx];
124       ++dynamicValIdx;
125     } else {
126       printer << integer;
127     }
128     if (!scalables.empty() && scalables[scalableIndexIdx])
129       printer << "]";
130 
131     scalableIndexIdx++;
132   });
133 
134   printer << rightDelimiter;
135 }
136 
137 ParseResult mlir::parseDynamicIndexList(
138     OpAsmParser &parser,
139     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
140     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
141     SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
142 
143   SmallVector<int64_t, 4> integerVals;
144   SmallVector<bool, 4> scalableVals;
145   auto parseIntegerOrValue = [&]() {
146     OpAsmParser::UnresolvedOperand operand;
147     auto res = parser.parseOptionalOperand(operand);
148 
149     // When encountering `[`, assume that this is a scalable index.
150     scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
151 
152     if (res.has_value() && succeeded(res.value())) {
153       values.push_back(operand);
154       integerVals.push_back(ShapedType::kDynamic);
155       if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
156         return failure();
157     } else {
158       int64_t integer;
159       if (failed(parser.parseInteger(integer)))
160         return failure();
161       integerVals.push_back(integer);
162     }
163 
164     // If this is assumed to be a scalable index, verify that there's a closing
165     // `]`.
166     if (scalableVals.back() && parser.parseOptionalRSquare().failed())
167       return failure();
168     return success();
169   };
170   if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
171                                      " in dynamic index list"))
172     return parser.emitError(parser.getNameLoc())
173            << "expected SSA value or integer";
174   integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
175   scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
176   return success();
177 }
178 
179 bool mlir::detail::sameOffsetsSizesAndStrides(
180     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
181     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
182   if (a.static_offsets().size() != b.static_offsets().size())
183     return false;
184   if (a.static_sizes().size() != b.static_sizes().size())
185     return false;
186   if (a.static_strides().size() != b.static_strides().size())
187     return false;
188   for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
189     if (!cmp(std::get<0>(it), std::get<1>(it)))
190       return false;
191   for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
192     if (!cmp(std::get<0>(it), std::get<1>(it)))
193       return false;
194   for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
195     if (!cmp(std::get<0>(it), std::get<1>(it)))
196       return false;
197   return true;
198 }
199