xref: /llvm-project/mlir/lib/Interfaces/ViewLikeInterface.cpp (revision 399638f98cdc2405c6a2e85f3cbba175fabaf858)
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 LogicalResult mlir::verifyListOfOperandsOrIntegers(
21     Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
22     ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
23   /// Check static and dynamic offsets/sizes/strides does not overflow type.
24   if (attr.size() != numElements)
25     return op->emitError("expected ")
26            << numElements << " " << name << " values";
27   unsigned expectedNumDynamicEntries =
28       llvm::count_if(attr.getValue(), [&](Attribute attr) {
29         return isDynamic(attr.cast<IntegerAttr>().getInt());
30       });
31   if (values.size() != expectedNumDynamicEntries)
32     return op->emitError("expected ")
33            << expectedNumDynamicEntries << " dynamic " << name << " values";
34   return success();
35 }
36 
37 LogicalResult
38 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
39   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
40   // Offsets can come in 2 flavors:
41   //   1. Either single entry (when maxRanks == 1).
42   //   2. Or as an array whose rank must match that of the mixed sizes.
43   // So that the result type is well-formed.
44   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
45       op.getMixedOffsets().size() != op.getMixedSizes().size())
46     return op->emitError(
47                "expected mixed offsets rank to match mixed sizes rank (")
48            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
49            << ") so the rank of the result type is well-formed.";
50   // Ranks of mixed sizes and strides must always match so the result type is
51   // well-formed.
52   if (op.getMixedSizes().size() != op.getMixedStrides().size())
53     return op->emitError(
54                "expected mixed sizes rank to match mixed strides rank (")
55            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
56            << ") so the rank of the result type is well-formed.";
57 
58   if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0],
59                                             op.static_offsets(), op.offsets(),
60                                             ShapedType::isDynamic)))
61     return failure();
62   if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
63                                             op.static_sizes(), op.sizes(),
64                                             ShapedType::isDynamic)))
65     return failure();
66   if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2],
67                                             op.static_strides(), op.strides(),
68                                             ShapedType::isDynamic)))
69     return failure();
70   return success();
71 }
72 
73 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
74                                  OperandRange values, ArrayAttr integers,
75                                  int64_t dynVal) {
76   printer << '[';
77   if (integers.empty()) {
78     printer << "]";
79     return;
80   }
81   unsigned idx = 0;
82   llvm::interleaveComma(integers, printer, [&](Attribute a) {
83     int64_t val = a.cast<IntegerAttr>().getInt();
84     if (val == dynVal)
85       printer << values[idx++];
86     else
87       printer << val;
88   });
89   printer << ']';
90 }
91 
92 ParseResult mlir::parseDynamicIndexList(
93     OpAsmParser &parser,
94     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
95     ArrayAttr &integers, int64_t dynVal) {
96   if (failed(parser.parseLSquare()))
97     return failure();
98   // 0-D.
99   if (succeeded(parser.parseOptionalRSquare())) {
100     integers = parser.getBuilder().getArrayAttr({});
101     return success();
102   }
103 
104   SmallVector<int64_t, 4> attrVals;
105   while (true) {
106     OpAsmParser::UnresolvedOperand operand;
107     auto res = parser.parseOptionalOperand(operand);
108     if (res.has_value() && succeeded(res.value())) {
109       values.push_back(operand);
110       attrVals.push_back(dynVal);
111     } else {
112       IntegerAttr attr;
113       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
114         return parser.emitError(parser.getNameLoc())
115                << "expected SSA value or integer";
116       attrVals.push_back(attr.getInt());
117     }
118 
119     if (succeeded(parser.parseOptionalComma()))
120       continue;
121     if (failed(parser.parseRSquare()))
122       return failure();
123     break;
124   }
125   integers = parser.getBuilder().getI64ArrayAttr(attrVals);
126   return success();
127 }
128 
129 bool mlir::detail::sameOffsetsSizesAndStrides(
130     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
131     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
132   if (a.static_offsets().size() != b.static_offsets().size())
133     return false;
134   if (a.static_sizes().size() != b.static_sizes().size())
135     return false;
136   if (a.static_strides().size() != b.static_strides().size())
137     return false;
138   for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
139     if (!cmp(std::get<0>(it), std::get<1>(it)))
140       return false;
141   for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
142     if (!cmp(std::get<0>(it), std::get<1>(it)))
143       return false;
144   for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
145     if (!cmp(std::get<0>(it), std::get<1>(it)))
146       return false;
147   return true;
148 }
149 
150 SmallVector<OpFoldResult, 4>
151 mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues,
152                      const int64_t dynamicValueIndicator) {
153   SmallVector<OpFoldResult, 4> res;
154   res.reserve(staticValues.size());
155   unsigned numDynamic = 0;
156   unsigned count = static_cast<unsigned>(staticValues.size());
157   for (unsigned idx = 0; idx < count; ++idx) {
158     APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
159     res.push_back(value.getSExtValue() == dynamicValueIndicator
160                       ? OpFoldResult{dynamicValues[numDynamic++]}
161                       : OpFoldResult{staticValues[idx]});
162   }
163   return res;
164 }
165 
166 SmallVector<OpFoldResult, 4>
167 mlir::getMixedStridesOrOffsets(ArrayAttr staticValues,
168                                ValueRange dynamicValues) {
169   return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamic);
170 }
171 
172 SmallVector<OpFoldResult, 4> mlir::getMixedSizes(ArrayAttr staticValues,
173                                                  ValueRange dynamicValues) {
174   return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamic);
175 }
176 
177 std::pair<ArrayAttr, SmallVector<Value>>
178 mlir::decomposeMixedValues(Builder &b,
179                            const SmallVectorImpl<OpFoldResult> &mixedValues,
180                            const int64_t dynamicValueIndicator) {
181   SmallVector<int64_t> staticValues;
182   SmallVector<Value> dynamicValues;
183   for (const auto &it : mixedValues) {
184     if (it.is<Attribute>()) {
185       staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
186     } else {
187       staticValues.push_back(dynamicValueIndicator);
188       dynamicValues.push_back(it.get<Value>());
189     }
190   }
191   return {b.getI64ArrayAttr(staticValues), dynamicValues};
192 }
193 
194 std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
195     OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
196   return decomposeMixedValues(b, mixedValues, ShapedType::kDynamic);
197 }
198 
199 std::pair<ArrayAttr, SmallVector<Value>>
200 mlir::decomposeMixedSizes(OpBuilder &b,
201                           const SmallVectorImpl<OpFoldResult> &mixedValues) {
202   return decomposeMixedValues(b, mixedValues, ShapedType::kDynamic);
203 }
204