xref: /llvm-project/mlir/include/mlir/IR/Matchers.h (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11 // include/llvm/IR/PatternMatch.h.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_IR_MATCHERS_H
16 #define MLIR_IR_MATCHERS_H
17 
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/OpDefinition.h"
21 #include "mlir/Interfaces/InferIntRangeInterface.h"
22 
23 namespace mlir {
24 
25 namespace detail {
26 
27 /// The matcher that matches a certain kind of Attribute and binds the value
28 /// inside the Attribute.
29 template <
30     typename AttrClass,
31     // Require AttrClass to be a derived class from Attribute and get its
32     // value type
33     typename ValueType = typename std::enable_if_t<
34         std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType,
35     // Require the ValueType is not void
36     typename = std::enable_if_t<!std::is_void<ValueType>::value>>
37 struct attr_value_binder {
38   ValueType *bind_value;
39 
40   /// Creates a matcher instance that binds the value to bv if match succeeds.
41   attr_value_binder(ValueType *bv) : bind_value(bv) {}
42 
43   bool match(Attribute attr) {
44     if (auto intAttr = llvm::dyn_cast<AttrClass>(attr)) {
45       *bind_value = intAttr.getValue();
46       return true;
47     }
48     return false;
49   }
50 };
51 
52 /// The matcher that matches operations that have the `ConstantLike` trait.
53 struct constant_op_matcher {
54   bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
55 };
56 
57 /// The matcher that matches operations that have the specified op name.
58 struct NameOpMatcher {
59   NameOpMatcher(StringRef name) : name(name) {}
60   bool match(Operation *op) { return op->getName().getStringRef() == name; }
61 
62   StringRef name;
63 };
64 
65 /// The matcher that matches operations that have the specified attribute name.
66 struct AttrOpMatcher {
67   AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
68   bool match(Operation *op) { return op->hasAttr(attrName); }
69 
70   StringRef attrName;
71 };
72 
73 /// The matcher that matches operations that have the `ConstantLike` trait, and
74 /// binds the folded attribute value.
75 template <typename AttrT>
76 struct constant_op_binder {
77   AttrT *bind_value;
78 
79   /// Creates a matcher instance that binds the constant attribute value to
80   /// bind_value if match succeeds.
81   constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
82   /// Creates a matcher instance that doesn't bind if match succeeds.
83   constant_op_binder() : bind_value(nullptr) {}
84 
85   bool match(Operation *op) {
86     if (!op->hasTrait<OpTrait::ConstantLike>())
87       return false;
88 
89     // Fold the constant to an attribute.
90     SmallVector<OpFoldResult, 1> foldedOp;
91     LogicalResult result = op->fold(/*operands=*/std::nullopt, foldedOp);
92     (void)result;
93     assert(succeeded(result) && "expected ConstantLike op to be foldable");
94 
95     if (auto attr = llvm::dyn_cast<AttrT>(cast<Attribute>(foldedOp.front()))) {
96       if (bind_value)
97         *bind_value = attr;
98       return true;
99     }
100     return false;
101   }
102 };
103 
104 /// A matcher that matches operations that implement the
105 /// `InferIntRangeInterface` interface, and binds the inferred range.
106 struct infer_int_range_op_binder {
107   IntegerValueRange *bind_value;
108 
109   explicit infer_int_range_op_binder(IntegerValueRange *bind_value)
110       : bind_value(bind_value) {}
111 
112   bool match(Operation *op) {
113     auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
114     if (!inferIntRangeOp)
115       return false;
116 
117     // Set the range of all integer operands to the maximal range.
118     SmallVector<IntegerValueRange> argRanges =
119         llvm::map_to_vector(op->getOperands(), IntegerValueRange::getMaxRange);
120 
121     // Infer the result result range if possible.
122     bool matched = false;
123     auto setResultRanges = [&](Value value,
124                                const IntegerValueRange &argRanges) {
125       if (argRanges.isUninitialized())
126         return;
127       if (value != op->getResult(0))
128         return;
129       *bind_value = argRanges;
130       matched = true;
131     };
132     inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
133     return matched;
134   }
135 };
136 
137 /// The matcher that matches operations that have the specified attribute
138 /// name, and binds the attribute value.
139 template <typename AttrT>
140 struct AttrOpBinder {
141   /// Creates a matcher instance that binds the attribute value to
142   /// bind_value if match succeeds.
143   AttrOpBinder(StringRef attrName, AttrT *bindValue)
144       : attrName(attrName), bindValue(bindValue) {}
145   /// Creates a matcher instance that doesn't bind if match succeeds.
146   AttrOpBinder(StringRef attrName) : attrName(attrName), bindValue(nullptr) {}
147 
148   bool match(Operation *op) {
149     if (auto attr = op->getAttrOfType<AttrT>(attrName)) {
150       if (bindValue)
151         *bindValue = attr;
152       return true;
153     }
154     return false;
155   }
156   StringRef attrName;
157   AttrT *bindValue;
158 };
159 
160 /// The matcher that matches a constant scalar / vector splat / tensor splat
161 /// float Attribute or Operation and binds the constant float value.
162 struct constant_float_value_binder {
163   FloatAttr::ValueType *bind_value;
164 
165   /// Creates a matcher instance that binds the value to bv if match succeeds.
166   constant_float_value_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
167 
168   bool match(Attribute attr) {
169     attr_value_binder<FloatAttr> matcher(bind_value);
170     if (matcher.match(attr))
171       return true;
172 
173     if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
174       return matcher.match(splatAttr.getSplatValue<Attribute>());
175 
176     return false;
177   }
178 
179   bool match(Operation *op) {
180     Attribute attr;
181     if (!constant_op_binder<Attribute>(&attr).match(op))
182       return false;
183 
184     Type type = op->getResult(0).getType();
185     if (isa<FloatType, VectorType, RankedTensorType>(type))
186       return match(attr);
187 
188     return false;
189   }
190 };
191 
192 /// The matcher that matches a given target constant scalar / vector splat /
193 /// tensor splat float value that fulfills a predicate.
194 struct constant_float_predicate_matcher {
195   bool (*predicate)(const APFloat &);
196 
197   bool match(Attribute attr) {
198     APFloat value(APFloat::Bogus());
199     return constant_float_value_binder(&value).match(attr) && predicate(value);
200   }
201 
202   bool match(Operation *op) {
203     APFloat value(APFloat::Bogus());
204     return constant_float_value_binder(&value).match(op) && predicate(value);
205   }
206 };
207 
208 /// The matcher that matches a constant scalar / vector splat / tensor splat
209 /// integer Attribute or Operation and binds the constant integer value.
210 struct constant_int_value_binder {
211   IntegerAttr::ValueType *bind_value;
212 
213   /// Creates a matcher instance that binds the value to bv if match succeeds.
214   constant_int_value_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
215 
216   bool match(Attribute attr) {
217     attr_value_binder<IntegerAttr> matcher(bind_value);
218     if (matcher.match(attr))
219       return true;
220 
221     if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
222       return matcher.match(splatAttr.getSplatValue<Attribute>());
223 
224     return false;
225   }
226 
227   bool match(Operation *op) {
228     Attribute attr;
229     if (!constant_op_binder<Attribute>(&attr).match(op))
230       return false;
231 
232     Type type = op->getResult(0).getType();
233     if (isa<IntegerType, IndexType, VectorType, RankedTensorType>(type))
234       return match(attr);
235 
236     return false;
237   }
238 };
239 
240 /// The matcher that matches a given target constant scalar / vector splat /
241 /// tensor splat integer value that fulfills a predicate.
242 struct constant_int_predicate_matcher {
243   bool (*predicate)(const APInt &);
244 
245   bool match(Attribute attr) {
246     APInt value;
247     return constant_int_value_binder(&value).match(attr) && predicate(value);
248   }
249 
250   bool match(Operation *op) {
251     APInt value;
252     return constant_int_value_binder(&value).match(op) && predicate(value);
253   }
254 };
255 
256 /// A matcher that matches a given a constant scalar / vector splat / tensor
257 /// splat integer value or a constant integer range that fulfills a predicate.
258 struct constant_int_range_predicate_matcher {
259   bool (*predicate)(const ConstantIntRanges &);
260 
261   bool match(Attribute attr) {
262     APInt value;
263     return constant_int_value_binder(&value).match(attr) &&
264            predicate(ConstantIntRanges::constant(value));
265   }
266 
267   bool match(Operation *op) {
268     // Try to match a constant integer value first.
269     APInt value;
270     if (constant_int_value_binder(&value).match(op))
271       return predicate(ConstantIntRanges::constant(value));
272 
273     // Otherwise, try to match an operation that implements the
274     // `InferIntRangeInterface` interface.
275     IntegerValueRange range;
276     return infer_int_range_op_binder(&range).match(op) &&
277            predicate(range.getValue());
278   }
279 };
280 
281 /// The matcher that matches a certain kind of op.
282 template <typename OpClass>
283 struct op_matcher {
284   bool match(Operation *op) { return isa<OpClass>(op); }
285 };
286 
/// Detection helper: names the result type of `T::match(MatchTarget)`, and is
/// ill-formed when `T` has no `match` callable with a `MatchTarget` (Value,
/// Operation, or Attribute). Intended for use with `llvm::is_detected`.
template <typename T, typename MatchTarget>
using has_compatible_matcher_t =
    decltype(std::declval<T>().match(std::declval<MatchTarget>()));
292 
293 /// Statically switch to a Value matcher.
294 template <typename MatcherClass>
295 std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
296                                    MatcherClass, Value>::value,
297                  bool>
298 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
299   return matcher.match(op->getOperand(idx));
300 }
301 
302 /// Statically switch to an Operation matcher.
303 template <typename MatcherClass>
304 std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
305                                    MatcherClass, Operation *>::value,
306                  bool>
307 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
308   if (auto *defOp = op->getOperand(idx).getDefiningOp())
309     return matcher.match(defOp);
310   return false;
311 }
312 
313 /// Terminal matcher, always returns true.
314 struct AnyValueMatcher {
315   bool match(Value op) const { return true; }
316 };
317 
318 /// Terminal matcher, always returns true.
319 struct AnyCapturedValueMatcher {
320   Value *what;
321   AnyCapturedValueMatcher(Value *what) : what(what) {}
322   bool match(Value op) const {
323     *what = op;
324     return true;
325   }
326 };
327 
328 /// Binds to a specific value and matches it.
329 struct PatternMatcherValue {
330   PatternMatcherValue(Value val) : value(val) {}
331   bool match(Value val) const { return val == value; }
332   Value value;
333 };
334 
/// Calls `callback(index, element)` once per tuple element, in increasing
/// index order. Each index is passed as a
/// `std::integral_constant<std::size_t, I>`; the index pack comes from the
/// `std::index_sequence` argument.
template <typename TupleType, class Callable, std::size_t... Indices>
constexpr void enumerateImpl(TupleType &&tuple, Callable &&callback,
                             std::index_sequence<Indices...>) {
  // Comma fold: one callback invocation per index, evaluated left to right.
  (callback(std::integral_constant<std::size_t, Indices>{},
            std::get<Indices>(tuple)),
   ...);
}
342 
343 template <typename... Tys, typename CallbackT>
344 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
345   detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
346                         std::make_index_sequence<sizeof...(Tys)>{});
347 }
348 
349 /// RecursivePatternMatcher that composes.
350 template <typename OpType, typename... OperandMatchers>
351 struct RecursivePatternMatcher {
352   RecursivePatternMatcher(OperandMatchers... matchers)
353       : operandMatchers(matchers...) {}
354   bool match(Operation *op) {
355     if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
356       return false;
357     bool res = true;
358     enumerate(operandMatchers, [&](size_t index, auto &matcher) {
359       res &= matchOperandOrValueAtIndex(op, index, matcher);
360     });
361     return res;
362   }
363   std::tuple<OperandMatchers...> operandMatchers;
364 };
365 
366 } // namespace detail
367 
368 /// Matches a constant foldable operation.
369 inline detail::constant_op_matcher m_Constant() {
370   return detail::constant_op_matcher();
371 }
372 
373 /// Matches a named attribute operation.
374 inline detail::AttrOpMatcher m_Attr(StringRef attrName) {
375   return detail::AttrOpMatcher(attrName);
376 }
377 
378 /// Matches a named operation.
379 inline detail::NameOpMatcher m_Op(StringRef opName) {
380   return detail::NameOpMatcher(opName);
381 }
382 
383 /// Matches a value from a constant foldable operation and writes the value to
384 /// bind_value.
385 template <typename AttrT>
386 inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
387   return detail::constant_op_binder<AttrT>(bind_value);
388 }
389 
390 /// Matches a named attribute operation and writes the value to bind_value.
391 template <typename AttrT>
392 inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName,
393                                           AttrT *bindValue) {
394   return detail::AttrOpBinder<AttrT>(attrName, bindValue);
395 }
396 
397 /// Matches a constant scalar / vector splat / tensor splat float (both positive
398 /// and negative) zero.
399 inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
400   return {[](const APFloat &value) { return value.isZero(); }};
401 }
402 
403 /// Matches a constant scalar / vector splat / tensor splat float positive zero.
404 inline detail::constant_float_predicate_matcher m_PosZeroFloat() {
405   return {[](const APFloat &value) { return value.isPosZero(); }};
406 }
407 
408 /// Matches a constant scalar / vector splat / tensor splat float negative zero.
409 inline detail::constant_float_predicate_matcher m_NegZeroFloat() {
410   return {[](const APFloat &value) { return value.isNegZero(); }};
411 }
412 
413 /// Matches a constant scalar / vector splat / tensor splat float ones.
414 inline detail::constant_float_predicate_matcher m_OneFloat() {
415   return {[](const APFloat &value) {
416     return APFloat(value.getSemantics(), 1) == value;
417   }};
418 }
419 
420 /// Matches a constant scalar / vector splat / tensor splat float ones.
421 inline detail::constant_float_predicate_matcher m_NaNFloat() {
422   return {[](const APFloat &value) { return value.isNaN(); }};
423 }
424 
425 /// Matches a constant scalar / vector splat / tensor splat float positive
426 /// infinity.
427 inline detail::constant_float_predicate_matcher m_PosInfFloat() {
428   return {[](const APFloat &value) {
429     return !value.isNegative() && value.isInfinity();
430   }};
431 }
432 
433 /// Matches a constant scalar / vector splat / tensor splat float negative
434 /// infinity.
435 inline detail::constant_float_predicate_matcher m_NegInfFloat() {
436   return {[](const APFloat &value) {
437     return value.isNegative() && value.isInfinity();
438   }};
439 }
440 
441 /// Matches a constant scalar / vector splat / tensor splat integer zero.
442 inline detail::constant_int_predicate_matcher m_Zero() {
443   return {[](const APInt &value) { return 0 == value; }};
444 }
445 
446 /// Matches a constant scalar / vector splat / tensor splat integer that is any
447 /// non-zero value.
448 inline detail::constant_int_predicate_matcher m_NonZero() {
449   return {[](const APInt &value) { return 0 != value; }};
450 }
451 
452 /// Matches a constant scalar / vector splat / tensor splat integer or a
453 /// unsigned integer range that does not contain zero. Note that this matcher
454 /// interprets the target value as an unsigned integer.
455 inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU() {
456   return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};
457 }
458 
459 /// Matches a constant scalar / vector splat / tensor splat integer or a
460 /// signed integer range that does not contain zero. Note that this matcher
461 /// interprets the target value as a signed integer.
462 inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS() {
463   return {[](const ConstantIntRanges &range) {
464     return range.smin().sgt(0) || range.smax().slt(0);
465   }};
466 }
467 
468 /// Matches a constant scalar / vector splat / tensor splat integer or a
469 /// signed integer range that does not contain minus one. Note
470 /// that this matcher interprets the target value as a signed integer.
471 inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS() {
472   return {[](const ConstantIntRanges &range) {
473     return range.smin().sgt(-1) || range.smax().slt(-1);
474   }};
475 }
476 
477 /// Matches a constant scalar / vector splat / tensor splat integer one.
478 inline detail::constant_int_predicate_matcher m_One() {
479   return {[](const APInt &value) { return 1 == value; }};
480 }
481 
482 /// Matches the given OpClass.
483 template <typename OpClass>
484 inline detail::op_matcher<OpClass> m_Op() {
485   return detail::op_matcher<OpClass>();
486 }
487 
488 /// Entry point for matching a pattern over a Value.
489 template <typename Pattern>
490 inline bool matchPattern(Value value, const Pattern &pattern) {
491   assert(value);
492   // TODO: handle other cases
493   if (auto *op = value.getDefiningOp())
494     return const_cast<Pattern &>(pattern).match(op);
495   return false;
496 }
497 
498 /// Entry point for matching a pattern over an Operation.
499 template <typename Pattern>
500 inline bool matchPattern(Operation *op, const Pattern &pattern) {
501   assert(op);
502   return const_cast<Pattern &>(pattern).match(op);
503 }
504 
505 /// Entry point for matching a pattern over an Attribute. Returns `false`
506 /// when `attr` is null.
507 template <typename Pattern>
508 inline bool matchPattern(Attribute attr, const Pattern &pattern) {
509   static_assert(llvm::is_detected<detail::has_compatible_matcher_t, Pattern,
510                                   Attribute>::value,
511                 "Pattern does not support matching Attributes");
512   if (!attr)
513     return false;
514   return const_cast<Pattern &>(pattern).match(attr);
515 }
516 
517 /// Matches a constant holding a scalar/vector/tensor float (splat) and
518 /// writes the float value to bind_value.
519 inline detail::constant_float_value_binder
520 m_ConstantFloat(FloatAttr::ValueType *bind_value) {
521   return detail::constant_float_value_binder(bind_value);
522 }
523 
524 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
525 /// writes the integer value to bind_value.
526 inline detail::constant_int_value_binder
527 m_ConstantInt(IntegerAttr::ValueType *bind_value) {
528   return detail::constant_int_value_binder(bind_value);
529 }
530 
531 template <typename OpType, typename... Matchers>
532 auto m_Op(Matchers... matchers) {
533   return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
534 }
535 
536 namespace matchers {
537 inline auto m_Any() { return detail::AnyValueMatcher(); }
538 inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); }
539 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
540 } // namespace matchers
541 
542 } // namespace mlir
543 
544 #endif // MLIR_IR_MATCHERS_H
545