xref: /llvm-project/mlir/include/mlir/Dialect/Transform/Interfaces/MatchInterfaces.h (revision 2c1ae801e1b66a09a15028ae4ba614e0911eec00)
1 //===- MatchInterfaces.h - Transform Dialect Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
11 
12 #include <optional>
13 #include <type_traits>
14 
15 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "llvm/ADT/STLExtras.h"
18 
19 namespace mlir {
20 namespace transform {
21 class MatchOpInterface;
22 
23 namespace detail {
24 /// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
25 /// first operand.
26 template <typename OpTy>
matchOptionalOperation(OpTy op,TransformResults & results,TransformState & state)27 DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
28                                                    TransformResults &results,
29                                                    TransformState &state) {
30   if constexpr (std::is_same_v<
31                     typename llvm::function_traits<
32                         decltype(&OpTy::matchOperation)>::template arg_t<0>,
33                     Operation *>) {
34     return op.matchOperation(nullptr, results, state);
35   } else {
36     return op.matchOperation(std::nullopt, results, state);
37   }
38 }
39 } // namespace detail
40 
41 template <typename OpTy>
42 class AtMostOneOpMatcherOpTrait
43     : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
44   template <typename T>
45   using has_get_operand_handle =
46       decltype(std::declval<T &>().getOperandHandle());
47   template <typename T>
48   using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
49       std::declval<Operation *>(), std::declval<TransformResults &>(),
50       std::declval<TransformState &>()));
51   template <typename T>
52   using has_match_operation_optional =
53       decltype(std::declval<T &>().matchOperation(
54           std::declval<std::optional<Operation *>>(),
55           std::declval<TransformResults &>(),
56           std::declval<TransformState &>()));
57 
58 public:
verifyTrait(Operation * op)59   static LogicalResult verifyTrait(Operation *op) {
60     static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
61                   "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
62                   "operation type to have the getOperandHandle() method");
63     static_assert(
64         llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
65             llvm::is_detected<has_match_operation_optional, OpTy>::value,
66         "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
67         "type to have either the matchOperation(Operation *, TransformResults "
68         "&, TransformState &) or the matchOperation(std::optional<Operation*>, "
69         "TransformResults &, TransformState &) method");
70 
71     // This must be a dynamic assert because interface registration is dynamic.
72     assert(
73         isa<MatchOpInterface>(op) &&
74         "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
75         "operations with MatchOpInterface");
76     Value operandHandle = cast<OpTy>(op).getOperandHandle();
77     if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
78       return op->emitError() << "AtMostOneOpMatcherOpTrait/"
79                                 "SingleOpMatchOpTrait requires the op handle "
80                                 "to be of TransformHandleTypeInterface";
81     }
82 
83     return success();
84   }
85 
apply(TransformRewriter & rewriter,TransformResults & results,TransformState & state)86   DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
87                                     TransformResults &results,
88                                     TransformState &state) {
89     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
90     auto payload = state.getPayloadOps(operandHandle);
91     if (!llvm::hasNItemsOrLess(payload, 1)) {
92       return emitDefiniteFailure(this->getOperation()->getLoc())
93              << "AtMostOneOpMatcherOpTrait requires the operand handle to "
94                 "point to at most one payload op";
95     }
96     if (payload.empty()) {
97       return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
98                                             results, state);
99     }
100     return cast<OpTy>(this->getOperation())
101         .matchOperation(*payload.begin(), results, state);
102   }
103 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)104   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
105     onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
106     producesHandle(this->getOperation()->getOpResults(), effects);
107     onlyReadsPayload(effects);
108   }
109 };
110 
111 template <typename OpTy>
112 class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {
113 
114 public:
apply(TransformRewriter & rewriter,TransformResults & results,TransformState & state)115   DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
116                                     TransformResults &results,
117                                     TransformState &state) {
118     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
119     auto payload = state.getPayloadOps(operandHandle);
120     if (!llvm::hasSingleElement(payload)) {
121       return emitDefiniteFailure(this->getOperation()->getLoc())
122              << "SingleOpMatchOpTrait requires the operand handle to point to "
123                 "a single payload op";
124     }
125     return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
126         rewriter, results, state);
127   }
128 };
129 
130 template <typename OpTy>
131 class SingleValueMatcherOpTrait
132     : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
133 public:
verifyTrait(Operation * op)134   static LogicalResult verifyTrait(Operation *op) {
135     // This must be a dynamic assert because interface registration is
136     // dynamic.
137     assert(isa<MatchOpInterface>(op) &&
138            "SingleValueMatchOpTrait is only available on operations with "
139            "MatchOpInterface");
140 
141     Value operandHandle = cast<OpTy>(op).getOperandHandle();
142     if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
143       return op->emitError() << "SingleValueMatchOpTrait requires an operand "
144                                 "of TransformValueHandleTypeInterface";
145     }
146 
147     return success();
148   }
149 
apply(TransformRewriter & rewriter,TransformResults & results,TransformState & state)150   DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
151                                     TransformResults &results,
152                                     TransformState &state) {
153     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
154     auto payload = state.getPayloadValues(operandHandle);
155     if (!llvm::hasSingleElement(payload)) {
156       return emitDefiniteFailure(this->getOperation()->getLoc())
157              << "SingleValueMatchOpTrait requires the value handle to point "
158                 "to a single payload value";
159     }
160 
161     return cast<OpTy>(this->getOperation())
162         .matchValue(*payload.begin(), results, state);
163   }
164 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)165   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
166     onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
167     producesHandle(this->getOperation()->getOpResults(), effects);
168     onlyReadsPayload(effects);
169   }
170 };
171 
172 //===----------------------------------------------------------------------===//
173 // Printing/parsing for positional specification matchers
174 //===----------------------------------------------------------------------===//
175 
/// Parses a positional index specification for transform match operations.
/// The following forms are accepted:
///
///  - `all`: sets `isAll` and returns;
///  - comma-separated-integer-list: populates `rawDimList` with the values;
///  - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
///    with the values and sets `isInverted`.
///
/// `printTransformMatchDims` takes the same `rawDimList`, `isInverted`, and
/// `isAll` attributes.
ParseResult parseTransformMatchDims(OpAsmParser &parser,
                                    DenseI64ArrayAttr &rawDimList,
                                    UnitAttr &isInverted, UnitAttr &isAll);
186 
/// Prints a positional index specification for transform match operations.
/// Takes the same `rawDimList`, `isInverted`, and `isAll` attributes that
/// `parseTransformMatchDims` populates.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
                             DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
                             UnitAttr isAll);
191 
192 //===----------------------------------------------------------------------===//
193 // Utilities for positional specification matchers
194 //===----------------------------------------------------------------------===//
195 
/// Checks whether the positional specification given by `raw`, `inverted`,
/// and `all` is valid, and reports errors otherwise.
/// NOTE(review): `op` is presumably the operation used to anchor the emitted
/// diagnostics; confirm against the implementation.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
                                         bool inverted, bool all);
200 
/// Populates `result` with the positional identifiers relative to `maxNumber`.
/// If `isAll` is set, the result will contain all numbers from `0` to
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
/// values from `rawList` are interpreted as counting backwards from
/// `maxNumber`, i.e., `-1` is interpreted as `maxNumber - 1`, while
/// non-negative numbers remain as is. If `isInverted` is set, populates
/// `result` with those values from the `0` to `maxNumber - 1` inclusive range
/// that don't appear in `rawList`. If `rawList` contains values that are
/// greater than or equal to `maxNumber` or less than `-maxNumber`, produces a
/// silenceable error at the given location `loc`. `maxNumber` must be
/// positive. If `rawList` contains duplicate numbers or numbers that become
/// duplicate after negative value remapping, emits a silenceable error.
DiagnosedSilenceableFailure
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
                          ArrayRef<int64_t> rawList, int64_t maxNumber,
                          SmallVectorImpl<int64_t> &result);
217 
218 } // namespace transform
219 } // namespace mlir
220 
221 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h.inc"
222 
223 #endif // MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
224