1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a simple and efficient mechanism for performing general 10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's 11 // include/llvm/IR/PatternMatch.h. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_IR_MATCHERS_H 16 #define MLIR_IR_MATCHERS_H 17 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/OpDefinition.h" 21 #include "mlir/Interfaces/InferIntRangeInterface.h" 22 23 namespace mlir { 24 25 namespace detail { 26 27 /// The matcher that matches a certain kind of Attribute and binds the value 28 /// inside the Attribute. 29 template < 30 typename AttrClass, 31 // Require AttrClass to be a derived class from Attribute and get its 32 // value type 33 typename ValueType = typename std::enable_if_t< 34 std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType, 35 // Require the ValueType is not void 36 typename = std::enable_if_t<!std::is_void<ValueType>::value>> 37 struct attr_value_binder { 38 ValueType *bind_value; 39 40 /// Creates a matcher instance that binds the value to bv if match succeeds. 41 attr_value_binder(ValueType *bv) : bind_value(bv) {} 42 43 bool match(Attribute attr) { 44 if (auto intAttr = llvm::dyn_cast<AttrClass>(attr)) { 45 *bind_value = intAttr.getValue(); 46 return true; 47 } 48 return false; 49 } 50 }; 51 52 /// The matcher that matches operations that have the `ConstantLike` trait. 53 struct constant_op_matcher { 54 bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); } 55 }; 56 57 /// The matcher that matches operations that have the specified op name. 58 struct NameOpMatcher { 59 NameOpMatcher(StringRef name) : name(name) {} 60 bool match(Operation *op) { return op->getName().getStringRef() == name; } 61 62 StringRef name; 63 }; 64 65 /// The matcher that matches operations that have the specified attribute name. 66 struct AttrOpMatcher { 67 AttrOpMatcher(StringRef attrName) : attrName(attrName) {} 68 bool match(Operation *op) { return op->hasAttr(attrName); } 69 70 StringRef attrName; 71 }; 72 73 /// The matcher that matches operations that have the `ConstantLike` trait, and 74 /// binds the folded attribute value. 75 template <typename AttrT> 76 struct constant_op_binder { 77 AttrT *bind_value; 78 79 /// Creates a matcher instance that binds the constant attribute value to 80 /// bind_value if match succeeds. 81 constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {} 82 /// Creates a matcher instance that doesn't bind if match succeeds. 83 constant_op_binder() : bind_value(nullptr) {} 84 85 bool match(Operation *op) { 86 if (!op->hasTrait<OpTrait::ConstantLike>()) 87 return false; 88 89 // Fold the constant to an attribute. 90 SmallVector<OpFoldResult, 1> foldedOp; 91 LogicalResult result = op->fold(/*operands=*/std::nullopt, foldedOp); 92 (void)result; 93 assert(succeeded(result) && "expected ConstantLike op to be foldable"); 94 95 if (auto attr = llvm::dyn_cast<AttrT>(cast<Attribute>(foldedOp.front()))) { 96 if (bind_value) 97 *bind_value = attr; 98 return true; 99 } 100 return false; 101 } 102 }; 103 104 /// A matcher that matches operations that implement the 105 /// `InferIntRangeInterface` interface, and binds the inferred range. 106 struct infer_int_range_op_binder { 107 IntegerValueRange *bind_value; 108 109 explicit infer_int_range_op_binder(IntegerValueRange *bind_value) 110 : bind_value(bind_value) {} 111 112 bool match(Operation *op) { 113 auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op); 114 if (!inferIntRangeOp) 115 return false; 116 117 // Set the range of all integer operands to the maximal range. 118 SmallVector<IntegerValueRange> argRanges = 119 llvm::map_to_vector(op->getOperands(), IntegerValueRange::getMaxRange); 120 121 // Infer the result result range if possible. 122 bool matched = false; 123 auto setResultRanges = [&](Value value, 124 const IntegerValueRange &argRanges) { 125 if (argRanges.isUninitialized()) 126 return; 127 if (value != op->getResult(0)) 128 return; 129 *bind_value = argRanges; 130 matched = true; 131 }; 132 inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges); 133 return matched; 134 } 135 }; 136 137 /// The matcher that matches operations that have the specified attribute 138 /// name, and binds the attribute value. 139 template <typename AttrT> 140 struct AttrOpBinder { 141 /// Creates a matcher instance that binds the attribute value to 142 /// bind_value if match succeeds. 143 AttrOpBinder(StringRef attrName, AttrT *bindValue) 144 : attrName(attrName), bindValue(bindValue) {} 145 /// Creates a matcher instance that doesn't bind if match succeeds. 146 AttrOpBinder(StringRef attrName) : attrName(attrName), bindValue(nullptr) {} 147 148 bool match(Operation *op) { 149 if (auto attr = op->getAttrOfType<AttrT>(attrName)) { 150 if (bindValue) 151 *bindValue = attr; 152 return true; 153 } 154 return false; 155 } 156 StringRef attrName; 157 AttrT *bindValue; 158 }; 159 160 /// The matcher that matches a constant scalar / vector splat / tensor splat 161 /// float Attribute or Operation and binds the constant float value. 162 struct constant_float_value_binder { 163 FloatAttr::ValueType *bind_value; 164 165 /// Creates a matcher instance that binds the value to bv if match succeeds. 166 constant_float_value_binder(FloatAttr::ValueType *bv) : bind_value(bv) {} 167 168 bool match(Attribute attr) { 169 attr_value_binder<FloatAttr> matcher(bind_value); 170 if (matcher.match(attr)) 171 return true; 172 173 if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) 174 return matcher.match(splatAttr.getSplatValue<Attribute>()); 175 176 return false; 177 } 178 179 bool match(Operation *op) { 180 Attribute attr; 181 if (!constant_op_binder<Attribute>(&attr).match(op)) 182 return false; 183 184 Type type = op->getResult(0).getType(); 185 if (isa<FloatType, VectorType, RankedTensorType>(type)) 186 return match(attr); 187 188 return false; 189 } 190 }; 191 192 /// The matcher that matches a given target constant scalar / vector splat / 193 /// tensor splat float value that fulfills a predicate. 194 struct constant_float_predicate_matcher { 195 bool (*predicate)(const APFloat &); 196 197 bool match(Attribute attr) { 198 APFloat value(APFloat::Bogus()); 199 return constant_float_value_binder(&value).match(attr) && predicate(value); 200 } 201 202 bool match(Operation *op) { 203 APFloat value(APFloat::Bogus()); 204 return constant_float_value_binder(&value).match(op) && predicate(value); 205 } 206 }; 207 208 /// The matcher that matches a constant scalar / vector splat / tensor splat 209 /// integer Attribute or Operation and binds the constant integer value. 210 struct constant_int_value_binder { 211 IntegerAttr::ValueType *bind_value; 212 213 /// Creates a matcher instance that binds the value to bv if match succeeds. 214 constant_int_value_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} 215 216 bool match(Attribute attr) { 217 attr_value_binder<IntegerAttr> matcher(bind_value); 218 if (matcher.match(attr)) 219 return true; 220 221 if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) 222 return matcher.match(splatAttr.getSplatValue<Attribute>()); 223 224 return false; 225 } 226 227 bool match(Operation *op) { 228 Attribute attr; 229 if (!constant_op_binder<Attribute>(&attr).match(op)) 230 return false; 231 232 Type type = op->getResult(0).getType(); 233 if (isa<IntegerType, IndexType, VectorType, RankedTensorType>(type)) 234 return match(attr); 235 236 return false; 237 } 238 }; 239 240 /// The matcher that matches a given target constant scalar / vector splat / 241 /// tensor splat integer value that fulfills a predicate. 242 struct constant_int_predicate_matcher { 243 bool (*predicate)(const APInt &); 244 245 bool match(Attribute attr) { 246 APInt value; 247 return constant_int_value_binder(&value).match(attr) && predicate(value); 248 } 249 250 bool match(Operation *op) { 251 APInt value; 252 return constant_int_value_binder(&value).match(op) && predicate(value); 253 } 254 }; 255 256 /// A matcher that matches a given a constant scalar / vector splat / tensor 257 /// splat integer value or a constant integer range that fulfills a predicate. 258 struct constant_int_range_predicate_matcher { 259 bool (*predicate)(const ConstantIntRanges &); 260 261 bool match(Attribute attr) { 262 APInt value; 263 return constant_int_value_binder(&value).match(attr) && 264 predicate(ConstantIntRanges::constant(value)); 265 } 266 267 bool match(Operation *op) { 268 // Try to match a constant integer value first. 269 APInt value; 270 if (constant_int_value_binder(&value).match(op)) 271 return predicate(ConstantIntRanges::constant(value)); 272 273 // Otherwise, try to match an operation that implements the 274 // `InferIntRangeInterface` interface. 275 IntegerValueRange range; 276 return infer_int_range_op_binder(&range).match(op) && 277 predicate(range.getValue()); 278 } 279 }; 280 281 /// The matcher that matches a certain kind of op. 282 template <typename OpClass> 283 struct op_matcher { 284 bool match(Operation *op) { return isa<OpClass>(op); } 285 }; 286 287 /// Trait to check whether T provides a 'match' method with type 288 /// `MatchTarget` (Value, Operation, or Attribute). 289 template <typename T, typename MatchTarget> 290 using has_compatible_matcher_t = 291 decltype(std::declval<T>().match(std::declval<MatchTarget>())); 292 293 /// Statically switch to a Value matcher. 294 template <typename MatcherClass> 295 std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t, 296 MatcherClass, Value>::value, 297 bool> 298 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { 299 return matcher.match(op->getOperand(idx)); 300 } 301 302 /// Statically switch to an Operation matcher. 303 template <typename MatcherClass> 304 std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t, 305 MatcherClass, Operation *>::value, 306 bool> 307 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { 308 if (auto *defOp = op->getOperand(idx).getDefiningOp()) 309 return matcher.match(defOp); 310 return false; 311 } 312 313 /// Terminal matcher, always returns true. 314 struct AnyValueMatcher { 315 bool match(Value op) const { return true; } 316 }; 317 318 /// Terminal matcher, always returns true. 319 struct AnyCapturedValueMatcher { 320 Value *what; 321 AnyCapturedValueMatcher(Value *what) : what(what) {} 322 bool match(Value op) const { 323 *what = op; 324 return true; 325 } 326 }; 327 328 /// Binds to a specific value and matches it. 329 struct PatternMatcherValue { 330 PatternMatcherValue(Value val) : value(val) {} 331 bool match(Value val) const { return val == value; } 332 Value value; 333 }; 334 335 template <typename TupleT, class CallbackT, std::size_t... Is> 336 constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, 337 std::index_sequence<Is...>) { 338 339 (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)), 340 ...); 341 } 342 343 template <typename... Tys, typename CallbackT> 344 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) { 345 detail::enumerateImpl(tuple, std::forward<CallbackT>(callback), 346 std::make_index_sequence<sizeof...(Tys)>{}); 347 } 348 349 /// RecursivePatternMatcher that composes. 350 template <typename OpType, typename... OperandMatchers> 351 struct RecursivePatternMatcher { 352 RecursivePatternMatcher(OperandMatchers... matchers) 353 : operandMatchers(matchers...) {} 354 bool match(Operation *op) { 355 if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers)) 356 return false; 357 bool res = true; 358 enumerate(operandMatchers, [&](size_t index, auto &matcher) { 359 res &= matchOperandOrValueAtIndex(op, index, matcher); 360 }); 361 return res; 362 } 363 std::tuple<OperandMatchers...> operandMatchers; 364 }; 365 366 } // namespace detail 367 368 /// Matches a constant foldable operation. 369 inline detail::constant_op_matcher m_Constant() { 370 return detail::constant_op_matcher(); 371 } 372 373 /// Matches a named attribute operation. 374 inline detail::AttrOpMatcher m_Attr(StringRef attrName) { 375 return detail::AttrOpMatcher(attrName); 376 } 377 378 /// Matches a named operation. 379 inline detail::NameOpMatcher m_Op(StringRef opName) { 380 return detail::NameOpMatcher(opName); 381 } 382 383 /// Matches a value from a constant foldable operation and writes the value to 384 /// bind_value. 385 template <typename AttrT> 386 inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) { 387 return detail::constant_op_binder<AttrT>(bind_value); 388 } 389 390 /// Matches a named attribute operation and writes the value to bind_value. 391 template <typename AttrT> 392 inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName, 393 AttrT *bindValue) { 394 return detail::AttrOpBinder<AttrT>(attrName, bindValue); 395 } 396 397 /// Matches a constant scalar / vector splat / tensor splat float (both positive 398 /// and negative) zero. 399 inline detail::constant_float_predicate_matcher m_AnyZeroFloat() { 400 return {[](const APFloat &value) { return value.isZero(); }}; 401 } 402 403 /// Matches a constant scalar / vector splat / tensor splat float positive zero. 404 inline detail::constant_float_predicate_matcher m_PosZeroFloat() { 405 return {[](const APFloat &value) { return value.isPosZero(); }}; 406 } 407 408 /// Matches a constant scalar / vector splat / tensor splat float negative zero. 409 inline detail::constant_float_predicate_matcher m_NegZeroFloat() { 410 return {[](const APFloat &value) { return value.isNegZero(); }}; 411 } 412 413 /// Matches a constant scalar / vector splat / tensor splat float ones. 414 inline detail::constant_float_predicate_matcher m_OneFloat() { 415 return {[](const APFloat &value) { 416 return APFloat(value.getSemantics(), 1) == value; 417 }}; 418 } 419 420 /// Matches a constant scalar / vector splat / tensor splat float ones. 421 inline detail::constant_float_predicate_matcher m_NaNFloat() { 422 return {[](const APFloat &value) { return value.isNaN(); }}; 423 } 424 425 /// Matches a constant scalar / vector splat / tensor splat float positive 426 /// infinity. 427 inline detail::constant_float_predicate_matcher m_PosInfFloat() { 428 return {[](const APFloat &value) { 429 return !value.isNegative() && value.isInfinity(); 430 }}; 431 } 432 433 /// Matches a constant scalar / vector splat / tensor splat float negative 434 /// infinity. 435 inline detail::constant_float_predicate_matcher m_NegInfFloat() { 436 return {[](const APFloat &value) { 437 return value.isNegative() && value.isInfinity(); 438 }}; 439 } 440 441 /// Matches a constant scalar / vector splat / tensor splat integer zero. 442 inline detail::constant_int_predicate_matcher m_Zero() { 443 return {[](const APInt &value) { return 0 == value; }}; 444 } 445 446 /// Matches a constant scalar / vector splat / tensor splat integer that is any 447 /// non-zero value. 448 inline detail::constant_int_predicate_matcher m_NonZero() { 449 return {[](const APInt &value) { return 0 != value; }}; 450 } 451 452 /// Matches a constant scalar / vector splat / tensor splat integer or a 453 /// unsigned integer range that does not contain zero. Note that this matcher 454 /// interprets the target value as an unsigned integer. 455 inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU() { 456 return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }}; 457 } 458 459 /// Matches a constant scalar / vector splat / tensor splat integer or a 460 /// signed integer range that does not contain zero. Note that this matcher 461 /// interprets the target value as a signed integer. 462 inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS() { 463 return {[](const ConstantIntRanges &range) { 464 return range.smin().sgt(0) || range.smax().slt(0); 465 }}; 466 } 467 468 /// Matches a constant scalar / vector splat / tensor splat integer or a 469 /// signed integer range that does not contain minus one. Note 470 /// that this matcher interprets the target value as a signed integer. 471 inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS() { 472 return {[](const ConstantIntRanges &range) { 473 return range.smin().sgt(-1) || range.smax().slt(-1); 474 }}; 475 } 476 477 /// Matches a constant scalar / vector splat / tensor splat integer one. 478 inline detail::constant_int_predicate_matcher m_One() { 479 return {[](const APInt &value) { return 1 == value; }}; 480 } 481 482 /// Matches the given OpClass. 483 template <typename OpClass> 484 inline detail::op_matcher<OpClass> m_Op() { 485 return detail::op_matcher<OpClass>(); 486 } 487 488 /// Entry point for matching a pattern over a Value. 489 template <typename Pattern> 490 inline bool matchPattern(Value value, const Pattern &pattern) { 491 assert(value); 492 // TODO: handle other cases 493 if (auto *op = value.getDefiningOp()) 494 return const_cast<Pattern &>(pattern).match(op); 495 return false; 496 } 497 498 /// Entry point for matching a pattern over an Operation. 499 template <typename Pattern> 500 inline bool matchPattern(Operation *op, const Pattern &pattern) { 501 assert(op); 502 return const_cast<Pattern &>(pattern).match(op); 503 } 504 505 /// Entry point for matching a pattern over an Attribute. Returns `false` 506 /// when `attr` is null. 507 template <typename Pattern> 508 inline bool matchPattern(Attribute attr, const Pattern &pattern) { 509 static_assert(llvm::is_detected<detail::has_compatible_matcher_t, Pattern, 510 Attribute>::value, 511 "Pattern does not support matching Attributes"); 512 if (!attr) 513 return false; 514 return const_cast<Pattern &>(pattern).match(attr); 515 } 516 517 /// Matches a constant holding a scalar/vector/tensor float (splat) and 518 /// writes the float value to bind_value. 519 inline detail::constant_float_value_binder 520 m_ConstantFloat(FloatAttr::ValueType *bind_value) { 521 return detail::constant_float_value_binder(bind_value); 522 } 523 524 /// Matches a constant holding a scalar/vector/tensor integer (splat) and 525 /// writes the integer value to bind_value. 526 inline detail::constant_int_value_binder 527 m_ConstantInt(IntegerAttr::ValueType *bind_value) { 528 return detail::constant_int_value_binder(bind_value); 529 } 530 531 template <typename OpType, typename... Matchers> 532 auto m_Op(Matchers... matchers) { 533 return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...); 534 } 535 536 namespace matchers { 537 inline auto m_Any() { return detail::AnyValueMatcher(); } 538 inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); } 539 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); } 540 } // namespace matchers 541 542 } // namespace mlir 543 544 #endif // MLIR_IR_MATCHERS_H 545