xref: /llvm-project/mlir/lib/Dialect/Transform/Interfaces/MatchInterfaces.cpp (revision b7b337fb91f9b0538fcc4467ffca7c6c71192bc9)
1 //===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // Printing and parsing for match ops.
15 //===----------------------------------------------------------------------===//
16 
17 /// Keyword syntax for positional specification inversion.
18 constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
19 
20 /// Keyword syntax for full inclusion in positional specification.
21 constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
22 
parseTransformMatchDims(OpAsmParser & parser,DenseI64ArrayAttr & rawDimList,UnitAttr & isInverted,UnitAttr & isAll)23 ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
24                                                DenseI64ArrayAttr &rawDimList,
25                                                UnitAttr &isInverted,
26                                                UnitAttr &isAll) {
27   Builder &builder = parser.getBuilder();
28   if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
29     rawDimList = builder.getDenseI64ArrayAttr({});
30     isInverted = nullptr;
31     isAll = builder.getUnitAttr();
32     return success();
33   }
34 
35   isAll = nullptr;
36   isInverted = nullptr;
37   if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
38     isInverted = builder.getUnitAttr();
39   }
40 
41   if (isInverted) {
42     if (parser.parseLParen().failed())
43       return failure();
44   }
45 
46   SmallVector<int64_t> values;
47   ParseResult listResult = parser.parseCommaSeparatedList(
48       [&]() { return parser.parseInteger(values.emplace_back()); });
49   if (listResult.failed())
50     return failure();
51 
52   rawDimList = builder.getDenseI64ArrayAttr(values);
53 
54   if (isInverted) {
55     if (parser.parseRParen().failed())
56       return failure();
57   }
58   return success();
59 }
60 
printTransformMatchDims(OpAsmPrinter & printer,Operation * op,DenseI64ArrayAttr rawDimList,UnitAttr isInverted,UnitAttr isAll)61 void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
62                                         DenseI64ArrayAttr rawDimList,
63                                         UnitAttr isInverted, UnitAttr isAll) {
64   if (isAll) {
65     printer << kDimAllKeyword;
66     return;
67   }
68   if (isInverted) {
69     printer << kDimExceptKeyword << "(";
70   }
71   llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
72                         [&](int64_t value) { printer << value; });
73   if (isInverted) {
74     printer << ")";
75   }
76 }
77 
verifyTransformMatchDimsOp(Operation * op,ArrayRef<int64_t> raw,bool inverted,bool all)78 LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
79                                                     ArrayRef<int64_t> raw,
80                                                     bool inverted, bool all) {
81   if (all) {
82     if (inverted) {
83       return op->emitOpError()
84              << "cannot request both 'all' and 'inverted' values in the list";
85     }
86     if (!raw.empty()) {
87       return op->emitOpError()
88              << "cannot both request 'all' and specific values in the list";
89     }
90   }
91   if (!all && raw.empty()) {
92     return op->emitOpError() << "must request specific values in the list if "
93                                 "'all' is not specified";
94   }
95   SmallVector<int64_t> rawVector = llvm::to_vector(raw);
96   auto *it = llvm::unique(rawVector);
97   if (it != rawVector.end())
98     return op->emitOpError() << "expected the listed values to be unique";
99 
100   return success();
101 }
102 
expandTargetSpecification(Location loc,bool isAll,bool isInverted,ArrayRef<int64_t> rawList,int64_t maxNumber,SmallVectorImpl<int64_t> & result)103 DiagnosedSilenceableFailure transform::expandTargetSpecification(
104     Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
105     int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
106   assert(maxNumber > 0 && "expected size to be positive");
107   assert(!(isAll && isInverted) && "cannot invert all");
108   if (isAll) {
109     result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
110     return DiagnosedSilenceableFailure::success();
111   }
112 
113   SmallVector<int64_t> expanded;
114   llvm::SmallDenseSet<int64_t> visited;
115   expanded.reserve(rawList.size());
116   SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
117   for (int64_t raw : rawList) {
118     int64_t updated = raw < 0 ? maxNumber + raw : raw;
119     if (updated >= maxNumber) {
120       return emitSilenceableFailure(loc)
121              << "position overflow " << updated << " (updated from " << raw
122              << ") for maximum " << maxNumber;
123     }
124     if (updated < 0) {
125       return emitSilenceableFailure(loc) << "position underflow " << updated
126                                          << " (updated from " << raw << ")";
127     }
128     if (!visited.insert(updated).second) {
129       return emitSilenceableFailure(loc) << "repeated position " << updated
130                                          << " (updated from " << raw << ")";
131     }
132     target.push_back(updated);
133   }
134 
135   if (!isInverted)
136     return DiagnosedSilenceableFailure::success();
137 
138   result.reserve(result.size() + (maxNumber - expanded.size()));
139   for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
140     if (llvm::is_contained(expanded, candidate))
141       continue;
142     result.push_back(candidate);
143   }
144 
145   return DiagnosedSilenceableFailure::success();
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // Generated interface implementation.
150 //===----------------------------------------------------------------------===//
151 
152 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.cpp.inc"
153