1 //===- MatchInterfaces.h - Transform Dialect Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
11
12 #include <optional>
13 #include <type_traits>
14
15 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "llvm/ADT/STLExtras.h"
18
19 namespace mlir {
20 namespace transform {
21 class MatchOpInterface;
22
23 namespace detail {
24 /// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
25 /// first operand.
26 template <typename OpTy>
matchOptionalOperation(OpTy op,TransformResults & results,TransformState & state)27 DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
28 TransformResults &results,
29 TransformState &state) {
30 if constexpr (std::is_same_v<
31 typename llvm::function_traits<
32 decltype(&OpTy::matchOperation)>::template arg_t<0>,
33 Operation *>) {
34 return op.matchOperation(nullptr, results, state);
35 } else {
36 return op.matchOperation(std::nullopt, results, state);
37 }
38 }
39 } // namespace detail
40
41 template <typename OpTy>
42 class AtMostOneOpMatcherOpTrait
43 : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
44 template <typename T>
45 using has_get_operand_handle =
46 decltype(std::declval<T &>().getOperandHandle());
47 template <typename T>
48 using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
49 std::declval<Operation *>(), std::declval<TransformResults &>(),
50 std::declval<TransformState &>()));
51 template <typename T>
52 using has_match_operation_optional =
53 decltype(std::declval<T &>().matchOperation(
54 std::declval<std::optional<Operation *>>(),
55 std::declval<TransformResults &>(),
56 std::declval<TransformState &>()));
57
58 public:
verifyTrait(Operation * op)59 static LogicalResult verifyTrait(Operation *op) {
60 static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
61 "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
62 "operation type to have the getOperandHandle() method");
63 static_assert(
64 llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
65 llvm::is_detected<has_match_operation_optional, OpTy>::value,
66 "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
67 "type to have either the matchOperation(Operation *, TransformResults "
68 "&, TransformState &) or the matchOperation(std::optional<Operation*>, "
69 "TransformResults &, TransformState &) method");
70
71 // This must be a dynamic assert because interface registration is dynamic.
72 assert(
73 isa<MatchOpInterface>(op) &&
74 "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
75 "operations with MatchOpInterface");
76 Value operandHandle = cast<OpTy>(op).getOperandHandle();
77 if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
78 return op->emitError() << "AtMostOneOpMatcherOpTrait/"
79 "SingleOpMatchOpTrait requires the op handle "
80 "to be of TransformHandleTypeInterface";
81 }
82
83 return success();
84 }
85
apply(TransformRewriter & rewriter,TransformResults & results,TransformState & state)86 DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
87 TransformResults &results,
88 TransformState &state) {
89 Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
90 auto payload = state.getPayloadOps(operandHandle);
91 if (!llvm::hasNItemsOrLess(payload, 1)) {
92 return emitDefiniteFailure(this->getOperation()->getLoc())
93 << "AtMostOneOpMatcherOpTrait requires the operand handle to "
94 "point to at most one payload op";
95 }
96 if (payload.empty()) {
97 return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
98 results, state);
99 }
100 return cast<OpTy>(this->getOperation())
101 .matchOperation(*payload.begin(), results, state);
102 }
103
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)104 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
105 onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
106 producesHandle(this->getOperation()->getOpResults(), effects);
107 onlyReadsPayload(effects);
108 }
109 };
110
111 template <typename OpTy>
112 class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {
113
114 public:
apply(TransformRewriter & rewriter,TransformResults & results,TransformState & state)115 DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
116 TransformResults &results,
117 TransformState &state) {
118 Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
119 auto payload = state.getPayloadOps(operandHandle);
120 if (!llvm::hasSingleElement(payload)) {
121 return emitDefiniteFailure(this->getOperation()->getLoc())
122 << "SingleOpMatchOpTrait requires the operand handle to point to "
123 "a single payload op";
124 }
125 return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
126 rewriter, results, state);
127 }
128 };
129
130 template <typename OpTy>
131 class SingleValueMatcherOpTrait
132 : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
133 public:
verifyTrait(Operation * op)134 static LogicalResult verifyTrait(Operation *op) {
135 // This must be a dynamic assert because interface registration is
136 // dynamic.
137 assert(isa<MatchOpInterface>(op) &&
138 "SingleValueMatchOpTrait is only available on operations with "
139 "MatchOpInterface");
140
141 Value operandHandle = cast<OpTy>(op).getOperandHandle();
142 if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
143 return op->emitError() << "SingleValueMatchOpTrait requires an operand "
144 "of TransformValueHandleTypeInterface";
145 }
146
147 return success();
148 }
149
apply(TransformRewriter & rewriter,TransformResults & results,TransformState & state)150 DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
151 TransformResults &results,
152 TransformState &state) {
153 Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
154 auto payload = state.getPayloadValues(operandHandle);
155 if (!llvm::hasSingleElement(payload)) {
156 return emitDefiniteFailure(this->getOperation()->getLoc())
157 << "SingleValueMatchOpTrait requires the value handle to point "
158 "to a single payload value";
159 }
160
161 return cast<OpTy>(this->getOperation())
162 .matchValue(*payload.begin(), results, state);
163 }
164
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)165 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
166 onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
167 producesHandle(this->getOperation()->getOpResults(), effects);
168 onlyReadsPayload(effects);
169 }
170 };
171
172 //===----------------------------------------------------------------------===//
173 // Printing/parsing for positional specification matchers
174 //===----------------------------------------------------------------------===//
175
176 /// Parses a positional index specification for transform match operations.
177 /// The following forms are accepted:
178 ///
179 /// - `all`: sets `isAll` and returns;
180 /// - comma-separated-integer-list: populates `rawDimList` with the values;
181 /// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
182 /// with the values and sets `isInverted`.
183 ParseResult parseTransformMatchDims(OpAsmParser &parser,
184 DenseI64ArrayAttr &rawDimList,
185 UnitAttr &isInverted, UnitAttr &isAll);
186
187 /// Prints a positional index specification for transform match operations.
188 void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
189 DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
190 UnitAttr isAll);
191
192 //===----------------------------------------------------------------------===//
193 // Utilities for positional specification matchers
194 //===----------------------------------------------------------------------===//
195
196 /// Checks if the positional specification defined is valid and reports errors
197 /// otherwise.
198 LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
199 bool inverted, bool all);
200
201 /// Populates `result` with the positional identifiers relative to `maxNumber`.
202 /// If `isAll` is set, the result will contain all numbers from `0` to
203 /// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
204 /// values from `rawList` are are interpreted as counting backwards from
205 /// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
206 /// numbers remain as is. If `isInverted` is set, populates `result` with those
207 /// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
208 /// `rawList`. If `rawList` contains values that are greater than or equal to
209 /// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
210 /// given location. `maxNumber` must be positive. If `rawList` contains
211 /// duplicate numbers or numbers that become duplicate after negative value
212 /// remapping, emits a silenceable error.
213 DiagnosedSilenceableFailure
214 expandTargetSpecification(Location loc, bool isAll, bool isInverted,
215 ArrayRef<int64_t> rawList, int64_t maxNumber,
216 SmallVectorImpl<int64_t> &result);
217
218 } // namespace transform
219 } // namespace mlir
220
221 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h.inc"
222
223 #endif // MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
224