xref: /llvm-project/mlir/lib/Interfaces/ViewLikeInterface.cpp (revision baca3b382b09aa2488d3e619478b9c96f67b40b8)
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
21                                                    StringRef name,
22                                                    unsigned numElements,
23                                                    ArrayRef<int64_t> staticVals,
24                                                    ValueRange values) {
25   // Check static and dynamic offsets/sizes/strides does not overflow type.
26   if (staticVals.size() != numElements)
27     return op->emitError("expected ")
28            << numElements << " " << name << " values";
29   unsigned expectedNumDynamicEntries =
30       llvm::count_if(staticVals, [&](int64_t staticVal) {
31         return ShapedType::isDynamic(staticVal);
32       });
33   if (values.size() != expectedNumDynamicEntries)
34     return op->emitError("expected ")
35            << expectedNumDynamicEntries << " dynamic " << name << " values";
36   return success();
37 }
38 
39 LogicalResult
40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
41   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
42   // Offsets can come in 2 flavors:
43   //   1. Either single entry (when maxRanks == 1).
44   //   2. Or as an array whose rank must match that of the mixed sizes.
45   // So that the result type is well-formed.
46   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
47       op.getMixedOffsets().size() != op.getMixedSizes().size())
48     return op->emitError(
49                "expected mixed offsets rank to match mixed sizes rank (")
50            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
51            << ") so the rank of the result type is well-formed.";
52   // Ranks of mixed sizes and strides must always match so the result type is
53   // well-formed.
54   if (op.getMixedSizes().size() != op.getMixedStrides().size())
55     return op->emitError(
56                "expected mixed sizes rank to match mixed strides rank (")
57            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
58            << ") so the rank of the result type is well-formed.";
59 
60   if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0],
61                                             op.static_offsets(), op.offsets())))
62     return failure();
63   if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
64                                             op.static_sizes(), op.sizes())))
65     return failure();
66   if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2],
67                                             op.static_strides(), op.strides())))
68     return failure();
69   return success();
70 }
71 
72 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
73                                  OperandRange values,
74                                  ArrayRef<int64_t> integers) {
75   printer << '[';
76   if (integers.empty()) {
77     printer << "]";
78     return;
79   }
80   unsigned idx = 0;
81   llvm::interleaveComma(integers, printer, [&](int64_t integer) {
82     if (ShapedType::isDynamic(integer))
83       printer << values[idx++];
84     else
85       printer << integer;
86   });
87   printer << ']';
88 }
89 
90 ParseResult mlir::parseDynamicIndexList(
91     OpAsmParser &parser,
92     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
93     DenseI64ArrayAttr &integers) {
94 
95   SmallVector<int64_t, 4> integerVals;
96   auto parseIntegerOrValue = [&]() {
97     OpAsmParser::UnresolvedOperand operand;
98     auto res = parser.parseOptionalOperand(operand);
99     if (res.has_value() && succeeded(res.value())) {
100       values.push_back(operand);
101       integerVals.push_back(ShapedType::kDynamic);
102     } else {
103       int64_t integer;
104       if (failed(parser.parseInteger(integer)))
105         return failure();
106       integerVals.push_back(integer);
107     }
108     return success();
109   };
110   if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square,
111                                      parseIntegerOrValue,
112                                      " in dynamic index list"))
113     return parser.emitError(parser.getNameLoc())
114            << "expected SSA value or integer";
115   integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
116   return success();
117 }
118 
119 bool mlir::detail::sameOffsetsSizesAndStrides(
120     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
121     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
122   if (a.static_offsets().size() != b.static_offsets().size())
123     return false;
124   if (a.static_sizes().size() != b.static_sizes().size())
125     return false;
126   if (a.static_strides().size() != b.static_strides().size())
127     return false;
128   for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
129     if (!cmp(std::get<0>(it), std::get<1>(it)))
130       return false;
131   for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
132     if (!cmp(std::get<0>(it), std::get<1>(it)))
133       return false;
134   for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
135     if (!cmp(std::get<0>(it), std::get<1>(it)))
136       return false;
137   return true;
138 }
139