1 //==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// Contains matchers for matching SelectionDAG nodes and values. 10 /// 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_CODEGEN_SDPATTERNMATCH_H 14 #define LLVM_CODEGEN_SDPATTERNMATCH_H 15 16 #include "llvm/ADT/APInt.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/CodeGen/SelectionDAG.h" 19 #include "llvm/CodeGen/SelectionDAGNodes.h" 20 #include "llvm/CodeGen/TargetLowering.h" 21 22 namespace llvm { 23 namespace SDPatternMatch { 24 25 /// MatchContext can repurpose existing patterns to behave differently under 26 /// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes 27 /// in normal circumstances, but matches VP_ADD nodes under a custom 28 /// VPMatchContext. This design is meant to facilitate code / pattern reusing. 29 class BasicMatchContext { 30 const SelectionDAG *DAG; 31 const TargetLowering *TLI; 32 33 public: 34 explicit BasicMatchContext(const SelectionDAG *DAG) 35 : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {} 36 37 explicit BasicMatchContext(const TargetLowering *TLI) 38 : DAG(nullptr), TLI(TLI) {} 39 40 // A valid MatchContext has to implement the following functions. 41 42 const SelectionDAG *getDAG() const { return DAG; } 43 44 const TargetLowering *getTLI() const { return TLI; } 45 46 /// Return true if N effectively has opcode Opcode. 47 bool match(SDValue N, unsigned Opcode) const { 48 return N->getOpcode() == Opcode; 49 } 50 51 unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); } 52 }; 53 54 template <typename Pattern, typename MatchContext> 55 [[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx, 56 Pattern &&P) { 57 return P.match(Ctx, N); 58 } 59 60 template <typename Pattern, typename MatchContext> 61 [[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx, 62 Pattern &&P) { 63 return sd_context_match(SDValue(N, 0), Ctx, P); 64 } 65 66 template <typename Pattern> 67 [[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) { 68 return sd_context_match(N, BasicMatchContext(DAG), P); 69 } 70 71 template <typename Pattern> 72 [[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) { 73 return sd_context_match(N, BasicMatchContext(DAG), P); 74 } 75 76 template <typename Pattern> 77 [[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) { 78 return sd_match(N, nullptr, P); 79 } 80 81 template <typename Pattern> 82 [[nodiscard]] bool sd_match(SDValue N, Pattern &&P) { 83 return sd_match(N, nullptr, P); 84 } 85 86 // === Utilities === 87 struct Value_match { 88 SDValue MatchVal; 89 90 Value_match() = default; 91 92 explicit Value_match(SDValue Match) : MatchVal(Match) {} 93 94 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 95 if (MatchVal) 96 return MatchVal == N; 97 return N.getNode(); 98 } 99 }; 100 101 /// Match any valid SDValue. 102 inline Value_match m_Value() { return Value_match(); } 103 104 inline Value_match m_Specific(SDValue N) { 105 assert(N); 106 return Value_match(N); 107 } 108 109 struct DeferredValue_match { 110 SDValue &MatchVal; 111 112 explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {} 113 114 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 115 return N == MatchVal; 116 } 117 }; 118 119 /// Similar to m_Specific, but the specific value to match is determined by 120 /// another sub-pattern in the same sd_match() expression. For instance, 121 /// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since 122 /// `X` is not initialized at the time it got copied into `m_Specific`. Instead, 123 /// we should use `m_Add(m_Value(X), m_Deferred(X))`. 124 inline DeferredValue_match m_Deferred(SDValue &V) { 125 return DeferredValue_match(V); 126 } 127 128 struct Opcode_match { 129 unsigned Opcode; 130 131 explicit Opcode_match(unsigned Opc) : Opcode(Opc) {} 132 133 template <typename MatchContext> 134 bool match(const MatchContext &Ctx, SDValue N) { 135 return Ctx.match(N, Opcode); 136 } 137 }; 138 139 inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); } 140 141 inline Opcode_match m_Undef() { return Opcode_match(ISD::UNDEF); } 142 143 template <unsigned NumUses, typename Pattern> struct NUses_match { 144 Pattern P; 145 146 explicit NUses_match(const Pattern &P) : P(P) {} 147 148 template <typename MatchContext> 149 bool match(const MatchContext &Ctx, SDValue N) { 150 // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces 151 // multiple results, hence we check the subsequent pattern here before 152 // checking the number of value users. 153 return P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo()); 154 } 155 }; 156 157 template <typename Pattern> 158 inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) { 159 return NUses_match<1, Pattern>(P); 160 } 161 template <unsigned N, typename Pattern> 162 inline NUses_match<N, Pattern> m_NUses(const Pattern &P) { 163 return NUses_match<N, Pattern>(P); 164 } 165 166 inline NUses_match<1, Value_match> m_OneUse() { 167 return NUses_match<1, Value_match>(m_Value()); 168 } 169 template <unsigned N> inline NUses_match<N, Value_match> m_NUses() { 170 return NUses_match<N, Value_match>(m_Value()); 171 } 172 173 struct Value_bind { 174 SDValue &BindVal; 175 176 explicit Value_bind(SDValue &N) : BindVal(N) {} 177 178 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 179 BindVal = N; 180 return true; 181 } 182 }; 183 184 inline Value_bind m_Value(SDValue &N) { return Value_bind(N); } 185 186 template <typename Pattern, typename PredFuncT> struct TLI_pred_match { 187 Pattern P; 188 PredFuncT PredFunc; 189 190 TLI_pred_match(const PredFuncT &Pred, const Pattern &P) 191 : P(P), PredFunc(Pred) {} 192 193 template <typename MatchContext> 194 bool match(const MatchContext &Ctx, SDValue N) { 195 assert(Ctx.getTLI() && "TargetLowering is required for this pattern."); 196 return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N); 197 } 198 }; 199 200 // Explicit deduction guide. 201 template <typename PredFuncT, typename Pattern> 202 TLI_pred_match(const PredFuncT &Pred, const Pattern &P) 203 -> TLI_pred_match<Pattern, PredFuncT>; 204 205 /// Match legal SDNodes based on the information provided by TargetLowering. 206 template <typename Pattern> inline auto m_LegalOp(const Pattern &P) { 207 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) { 208 return TLI.isOperationLegal(N->getOpcode(), 209 N.getValueType()); 210 }, 211 P}; 212 } 213 214 /// Switch to a different MatchContext for subsequent patterns. 215 template <typename NewMatchContext, typename Pattern> struct SwitchContext { 216 const NewMatchContext &Ctx; 217 Pattern P; 218 219 template <typename OrigMatchContext> 220 bool match(const OrigMatchContext &, SDValue N) { 221 return P.match(Ctx, N); 222 } 223 }; 224 225 template <typename MatchContext, typename Pattern> 226 inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx, 227 Pattern &&P) { 228 return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)}; 229 } 230 231 // === Value type === 232 struct ValueType_bind { 233 EVT &BindVT; 234 235 explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {} 236 237 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 238 BindVT = N.getValueType(); 239 return true; 240 } 241 }; 242 243 /// Retreive the ValueType of the current SDValue. 244 inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); } 245 246 template <typename Pattern, typename PredFuncT> struct ValueType_match { 247 PredFuncT PredFunc; 248 Pattern P; 249 250 ValueType_match(const PredFuncT &Pred, const Pattern &P) 251 : PredFunc(Pred), P(P) {} 252 253 template <typename MatchContext> 254 bool match(const MatchContext &Ctx, SDValue N) { 255 return PredFunc(N.getValueType()) && P.match(Ctx, N); 256 } 257 }; 258 259 // Explicit deduction guide. 260 template <typename PredFuncT, typename Pattern> 261 ValueType_match(const PredFuncT &Pred, const Pattern &P) 262 -> ValueType_match<Pattern, PredFuncT>; 263 264 /// Match a specific ValueType. 265 template <typename Pattern> 266 inline auto m_SpecificVT(EVT RefVT, const Pattern &P) { 267 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P}; 268 } 269 inline auto m_SpecificVT(EVT RefVT) { 270 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()}; 271 } 272 273 inline auto m_Glue() { return m_SpecificVT(MVT::Glue); } 274 inline auto m_OtherVT() { return m_SpecificVT(MVT::Other); } 275 276 /// Match any integer ValueTypes. 277 template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) { 278 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P}; 279 } 280 inline auto m_IntegerVT() { 281 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()}; 282 } 283 284 /// Match any floating point ValueTypes. 285 template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) { 286 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P}; 287 } 288 inline auto m_FloatingPointVT() { 289 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, 290 m_Value()}; 291 } 292 293 /// Match any vector ValueTypes. 294 template <typename Pattern> inline auto m_VectorVT(const Pattern &P) { 295 return ValueType_match{[](EVT VT) { return VT.isVector(); }, P}; 296 } 297 inline auto m_VectorVT() { 298 return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()}; 299 } 300 301 /// Match fixed-length vector ValueTypes. 302 template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) { 303 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P}; 304 } 305 inline auto m_FixedVectorVT() { 306 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, 307 m_Value()}; 308 } 309 310 /// Match scalable vector ValueTypes. 311 template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) { 312 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P}; 313 } 314 inline auto m_ScalableVectorVT() { 315 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, 316 m_Value()}; 317 } 318 319 /// Match legal ValueTypes based on the information provided by TargetLowering. 320 template <typename Pattern> inline auto m_LegalType(const Pattern &P) { 321 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) { 322 return TLI.isTypeLegal(N.getValueType()); 323 }, 324 P}; 325 } 326 327 // === Patterns combinators === 328 template <typename... Preds> struct And { 329 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 330 return true; 331 } 332 }; 333 334 template <typename Pred, typename... Preds> 335 struct And<Pred, Preds...> : And<Preds...> { 336 Pred P; 337 And(const Pred &p, const Preds &...preds) : And<Preds...>(preds...), P(p) {} 338 339 template <typename MatchContext> 340 bool match(const MatchContext &Ctx, SDValue N) { 341 return P.match(Ctx, N) && And<Preds...>::match(Ctx, N); 342 } 343 }; 344 345 template <typename... Preds> struct Or { 346 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 347 return false; 348 } 349 }; 350 351 template <typename Pred, typename... Preds> 352 struct Or<Pred, Preds...> : Or<Preds...> { 353 Pred P; 354 Or(const Pred &p, const Preds &...preds) : Or<Preds...>(preds...), P(p) {} 355 356 template <typename MatchContext> 357 bool match(const MatchContext &Ctx, SDValue N) { 358 return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N); 359 } 360 }; 361 362 template <typename Pred> struct Not { 363 Pred P; 364 365 explicit Not(const Pred &P) : P(P) {} 366 367 template <typename MatchContext> 368 bool match(const MatchContext &Ctx, SDValue N) { 369 return !P.match(Ctx, N); 370 } 371 }; 372 // Explicit deduction guide. 373 template <typename Pred> Not(const Pred &P) -> Not<Pred>; 374 375 /// Match if the inner pattern does NOT match. 376 template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) { 377 return Not{P}; 378 } 379 380 template <typename... Preds> And<Preds...> m_AllOf(const Preds &...preds) { 381 return And<Preds...>(preds...); 382 } 383 384 template <typename... Preds> Or<Preds...> m_AnyOf(const Preds &...preds) { 385 return Or<Preds...>(preds...); 386 } 387 388 template <typename... Preds> auto m_NoneOf(const Preds &...preds) { 389 return m_Unless(m_AnyOf(preds...)); 390 } 391 392 // === Generic node matching === 393 template <unsigned OpIdx, typename... OpndPreds> struct Operands_match { 394 template <typename MatchContext> 395 bool match(const MatchContext &Ctx, SDValue N) { 396 // Returns false if there are more operands than predicates; 397 // Ignores the last two operands if both the Context and the Node are VP 398 return Ctx.getNumOperands(N) == OpIdx; 399 } 400 }; 401 402 template <unsigned OpIdx, typename OpndPred, typename... OpndPreds> 403 struct Operands_match<OpIdx, OpndPred, OpndPreds...> 404 : Operands_match<OpIdx + 1, OpndPreds...> { 405 OpndPred P; 406 407 Operands_match(const OpndPred &p, const OpndPreds &...preds) 408 : Operands_match<OpIdx + 1, OpndPreds...>(preds...), P(p) {} 409 410 template <typename MatchContext> 411 bool match(const MatchContext &Ctx, SDValue N) { 412 if (OpIdx < N->getNumOperands()) 413 return P.match(Ctx, N->getOperand(OpIdx)) && 414 Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N); 415 416 // This is the case where there are more predicates than operands. 417 return false; 418 } 419 }; 420 421 template <typename... OpndPreds> 422 auto m_Node(unsigned Opcode, const OpndPreds &...preds) { 423 return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>(preds...)); 424 } 425 426 /// Provide number of operands that are not chain or glue, as well as the first 427 /// index of such operand. 428 template <bool ExcludeChain> struct EffectiveOperands { 429 unsigned Size = 0; 430 unsigned FirstIndex = 0; 431 432 template <typename MatchContext> 433 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) { 434 const unsigned TotalNumOps = Ctx.getNumOperands(N); 435 FirstIndex = TotalNumOps; 436 for (unsigned I = 0; I < TotalNumOps; ++I) { 437 // Count the number of non-chain and non-glue nodes (we ignore chain 438 // and glue by default) and retreive the operand index offset. 439 EVT VT = N->getOperand(I).getValueType(); 440 if (VT != MVT::Glue && VT != MVT::Other) { 441 ++Size; 442 if (FirstIndex == TotalNumOps) 443 FirstIndex = I; 444 } 445 } 446 } 447 }; 448 449 template <> struct EffectiveOperands<false> { 450 unsigned Size = 0; 451 unsigned FirstIndex = 0; 452 453 template <typename MatchContext> 454 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) 455 : Size(Ctx.getNumOperands(N)) {} 456 }; 457 458 // === Ternary operations === 459 template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false, 460 bool ExcludeChain = false> 461 struct TernaryOpc_match { 462 unsigned Opcode; 463 T0_P Op0; 464 T1_P Op1; 465 T2_P Op2; 466 467 TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1, 468 const T2_P &Op2) 469 : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {} 470 471 template <typename MatchContext> 472 bool match(const MatchContext &Ctx, SDValue N) { 473 if (sd_context_match(N, Ctx, m_Opc(Opcode))) { 474 EffectiveOperands<ExcludeChain> EO(N, Ctx); 475 assert(EO.Size == 3); 476 return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) && 477 Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) || 478 (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) && 479 Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) && 480 Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2)); 481 } 482 483 return false; 484 } 485 }; 486 487 template <typename T0_P, typename T1_P, typename T2_P> 488 inline TernaryOpc_match<T0_P, T1_P, T2_P> 489 m_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) { 490 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SETCC, LHS, RHS, CC); 491 } 492 493 template <typename T0_P, typename T1_P, typename T2_P> 494 inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false> 495 m_c_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) { 496 return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, LHS, RHS, 497 CC); 498 } 499 500 template <typename T0_P, typename T1_P, typename T2_P> 501 inline TernaryOpc_match<T0_P, T1_P, T2_P> 502 m_Select(const T0_P &Cond, const T1_P &T, const T2_P &F) { 503 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SELECT, Cond, T, F); 504 } 505 506 template <typename T0_P, typename T1_P, typename T2_P> 507 inline TernaryOpc_match<T0_P, T1_P, T2_P> 508 m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) { 509 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F); 510 } 511 512 template <typename T0_P, typename T1_P, typename T2_P> 513 inline TernaryOpc_match<T0_P, T1_P, T2_P> 514 m_InsertElt(const T0_P &Vec, const T1_P &Val, const T2_P &Idx) { 515 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::INSERT_VECTOR_ELT, Vec, Val, 516 Idx); 517 } 518 519 template <typename LHS, typename RHS, typename IDX> 520 inline TernaryOpc_match<LHS, RHS, IDX> 521 m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) { 522 return TernaryOpc_match<LHS, RHS, IDX>(ISD::INSERT_SUBVECTOR, Base, Sub, Idx); 523 } 524 525 // === Binary operations === 526 template <typename LHS_P, typename RHS_P, bool Commutable = false, 527 bool ExcludeChain = false> 528 struct BinaryOpc_match { 529 unsigned Opcode; 530 LHS_P LHS; 531 RHS_P RHS; 532 std::optional<SDNodeFlags> Flags; 533 BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R, 534 std::optional<SDNodeFlags> Flgs = std::nullopt) 535 : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {} 536 537 template <typename MatchContext> 538 bool match(const MatchContext &Ctx, SDValue N) { 539 if (sd_context_match(N, Ctx, m_Opc(Opcode))) { 540 EffectiveOperands<ExcludeChain> EO(N, Ctx); 541 assert(EO.Size == 2); 542 if (!((LHS.match(Ctx, N->getOperand(EO.FirstIndex)) && 543 RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) || 544 (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) && 545 RHS.match(Ctx, N->getOperand(EO.FirstIndex))))) 546 return false; 547 548 if (!Flags.has_value()) 549 return true; 550 551 return (*Flags & N->getFlags()) == *Flags; 552 } 553 554 return false; 555 } 556 }; 557 558 /// Matching while capturing mask 559 template <typename T0, typename T1, typename T2> struct SDShuffle_match { 560 T0 Op1; 561 T1 Op2; 562 T2 Mask; 563 564 SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask) 565 : Op1(Op1), Op2(Op2), Mask(Mask) {} 566 567 template <typename MatchContext> 568 bool match(const MatchContext &Ctx, SDValue N) { 569 if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) { 570 return Op1.match(Ctx, I->getOperand(0)) && 571 Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask()); 572 } 573 return false; 574 } 575 }; 576 struct m_Mask { 577 ArrayRef<int> &MaskRef; 578 m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {} 579 bool match(ArrayRef<int> Mask) { 580 MaskRef = Mask; 581 return true; 582 } 583 }; 584 585 struct m_SpecificMask { 586 ArrayRef<int> MaskRef; 587 m_SpecificMask(ArrayRef<int> MaskRef) : MaskRef(MaskRef) {} 588 bool match(ArrayRef<int> Mask) { return MaskRef == Mask; } 589 }; 590 591 template <typename LHS_P, typename RHS_P, typename Pred_t, 592 bool Commutable = false, bool ExcludeChain = false> 593 struct MaxMin_match { 594 using PredType = Pred_t; 595 LHS_P LHS; 596 RHS_P RHS; 597 598 MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {} 599 600 template <typename MatchContext> 601 bool match(const MatchContext &Ctx, SDValue N) { 602 if (sd_context_match(N, Ctx, m_Opc(ISD::SELECT)) || 603 sd_context_match(N, Ctx, m_Opc(ISD::VSELECT))) { 604 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx); 605 assert(EO_SELECT.Size == 3); 606 SDValue Cond = N->getOperand(EO_SELECT.FirstIndex); 607 SDValue TrueValue = N->getOperand(EO_SELECT.FirstIndex + 1); 608 SDValue FalseValue = N->getOperand(EO_SELECT.FirstIndex + 2); 609 610 if (sd_context_match(Cond, Ctx, m_Opc(ISD::SETCC))) { 611 EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx); 612 assert(EO_SETCC.Size == 3); 613 SDValue L = Cond->getOperand(EO_SETCC.FirstIndex); 614 SDValue R = Cond->getOperand(EO_SETCC.FirstIndex + 1); 615 auto *CondNode = 616 cast<CondCodeSDNode>(Cond->getOperand(EO_SETCC.FirstIndex + 2)); 617 618 if ((TrueValue != L || FalseValue != R) && 619 (TrueValue != R || FalseValue != L)) { 620 return false; 621 } 622 623 ISD::CondCode Cond = 624 TrueValue == L ? CondNode->get() 625 : getSetCCInverse(CondNode->get(), L.getValueType()); 626 if (!Pred_t::match(Cond)) { 627 return false; 628 } 629 return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) || 630 (Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L)); 631 } 632 } 633 634 return false; 635 } 636 }; 637 638 // Helper class for identifying signed max predicates. 639 struct smax_pred_ty { 640 static bool match(ISD::CondCode Cond) { 641 return Cond == ISD::CondCode::SETGT || Cond == ISD::CondCode::SETGE; 642 } 643 }; 644 645 // Helper class for identifying unsigned max predicates. 646 struct umax_pred_ty { 647 static bool match(ISD::CondCode Cond) { 648 return Cond == ISD::CondCode::SETUGT || Cond == ISD::CondCode::SETUGE; 649 } 650 }; 651 652 // Helper class for identifying signed min predicates. 653 struct smin_pred_ty { 654 static bool match(ISD::CondCode Cond) { 655 return Cond == ISD::CondCode::SETLT || Cond == ISD::CondCode::SETLE; 656 } 657 }; 658 659 // Helper class for identifying unsigned min predicates. 660 struct umin_pred_ty { 661 static bool match(ISD::CondCode Cond) { 662 return Cond == ISD::CondCode::SETULT || Cond == ISD::CondCode::SETULE; 663 } 664 }; 665 666 template <typename LHS, typename RHS> 667 inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L, 668 const RHS &R) { 669 return BinaryOpc_match<LHS, RHS>(Opc, L, R); 670 } 671 template <typename LHS, typename RHS> 672 inline BinaryOpc_match<LHS, RHS, true> m_c_BinOp(unsigned Opc, const LHS &L, 673 const RHS &R) { 674 return BinaryOpc_match<LHS, RHS, true>(Opc, L, R); 675 } 676 677 template <typename LHS, typename RHS> 678 inline BinaryOpc_match<LHS, RHS, false, true> 679 m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) { 680 return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R); 681 } 682 template <typename LHS, typename RHS> 683 inline BinaryOpc_match<LHS, RHS, true, true> 684 m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) { 685 return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R); 686 } 687 688 // Common binary operations 689 template <typename LHS, typename RHS> 690 inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) { 691 return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R); 692 } 693 694 template <typename LHS, typename RHS> 695 inline BinaryOpc_match<LHS, RHS> m_Sub(const LHS &L, const RHS &R) { 696 return BinaryOpc_match<LHS, RHS>(ISD::SUB, L, R); 697 } 698 699 template <typename LHS, typename RHS> 700 inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) { 701 return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R); 702 } 703 704 template <typename LHS, typename RHS> 705 inline BinaryOpc_match<LHS, RHS, true> m_And(const LHS &L, const RHS &R) { 706 return BinaryOpc_match<LHS, RHS, true>(ISD::AND, L, R); 707 } 708 709 template <typename LHS, typename RHS> 710 inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) { 711 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R); 712 } 713 714 template <typename LHS, typename RHS> 715 inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L, 716 const RHS &R) { 717 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, SDNodeFlags::Disjoint); 718 } 719 720 template <typename LHS, typename RHS> 721 inline auto m_AddLike(const LHS &L, const RHS &R) { 722 return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R)); 723 } 724 725 template <typename LHS, typename RHS> 726 inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) { 727 return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R); 728 } 729 730 template <typename LHS, typename RHS> 731 inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) { 732 return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R); 733 } 734 735 template <typename LHS, typename RHS> 736 inline auto m_SMinLike(const LHS &L, const RHS &R) { 737 return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R), 738 MaxMin_match<LHS, RHS, smin_pred_ty, true>(L, R)); 739 } 740 741 template <typename LHS, typename RHS> 742 inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) { 743 return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R); 744 } 745 746 template <typename LHS, typename RHS> 747 inline auto m_SMaxLike(const LHS &L, const RHS &R) { 748 return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R), 749 MaxMin_match<LHS, RHS, smax_pred_ty, true>(L, R)); 750 } 751 752 template <typename LHS, typename RHS> 753 inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) { 754 return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R); 755 } 756 757 template <typename LHS, typename RHS> 758 inline auto m_UMinLike(const LHS &L, const RHS &R) { 759 return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R), 760 MaxMin_match<LHS, RHS, umin_pred_ty, true>(L, R)); 761 } 762 763 template <typename LHS, typename RHS> 764 inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) { 765 return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R); 766 } 767 768 template <typename LHS, typename RHS> 769 inline auto m_UMaxLike(const LHS &L, const RHS &R) { 770 return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R), 771 MaxMin_match<LHS, RHS, umax_pred_ty, true>(L, R)); 772 } 773 774 template <typename LHS, typename RHS> 775 inline BinaryOpc_match<LHS, RHS> m_UDiv(const LHS &L, const RHS &R) { 776 return BinaryOpc_match<LHS, RHS>(ISD::UDIV, L, R); 777 } 778 template <typename LHS, typename RHS> 779 inline BinaryOpc_match<LHS, RHS> m_SDiv(const LHS &L, const RHS &R) { 780 return BinaryOpc_match<LHS, RHS>(ISD::SDIV, L, R); 781 } 782 783 template <typename LHS, typename RHS> 784 inline BinaryOpc_match<LHS, RHS> m_URem(const LHS &L, const RHS &R) { 785 return BinaryOpc_match<LHS, RHS>(ISD::UREM, L, R); 786 } 787 template <typename LHS, typename RHS> 788 inline BinaryOpc_match<LHS, RHS> m_SRem(const LHS &L, const RHS &R) { 789 return BinaryOpc_match<LHS, RHS>(ISD::SREM, L, R); 790 } 791 792 template <typename LHS, typename RHS> 793 inline BinaryOpc_match<LHS, RHS> m_Shl(const LHS &L, const RHS &R) { 794 return BinaryOpc_match<LHS, RHS>(ISD::SHL, L, R); 795 } 796 797 template <typename LHS, typename RHS> 798 inline BinaryOpc_match<LHS, RHS> m_Sra(const LHS &L, const RHS &R) { 799 return BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R); 800 } 801 template <typename LHS, typename RHS> 802 inline BinaryOpc_match<LHS, RHS> m_Srl(const LHS &L, const RHS &R) { 803 return BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R); 804 } 805 806 template <typename LHS, typename RHS> 807 inline BinaryOpc_match<LHS, RHS> m_Rotl(const LHS &L, const RHS &R) { 808 return BinaryOpc_match<LHS, RHS>(ISD::ROTL, L, R); 809 } 810 811 template <typename LHS, typename RHS> 812 inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) { 813 return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R); 814 } 815 816 template <typename LHS, typename RHS> 817 inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) { 818 return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R); 819 } 820 821 template <typename LHS, typename RHS> 822 inline BinaryOpc_match<LHS, RHS> m_FSub(const LHS &L, const RHS &R) { 823 return BinaryOpc_match<LHS, RHS>(ISD::FSUB, L, R); 824 } 825 826 template <typename LHS, typename RHS> 827 inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) { 828 return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R); 829 } 830 831 template <typename LHS, typename RHS> 832 inline BinaryOpc_match<LHS, RHS> m_FDiv(const LHS &L, const RHS &R) { 833 return BinaryOpc_match<LHS, RHS>(ISD::FDIV, L, R); 834 } 835 836 template <typename LHS, typename RHS> 837 inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) { 838 return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R); 839 } 840 841 template <typename V1_t, typename V2_t> 842 inline BinaryOpc_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2) { 843 return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2); 844 } 845 846 template <typename V1_t, typename V2_t, typename Mask_t> 847 inline SDShuffle_match<V1_t, V2_t, Mask_t> 848 m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) { 849 return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask); 850 } 851 852 template <typename LHS, typename RHS> 853 inline BinaryOpc_match<LHS, RHS> m_ExtractElt(const LHS &Vec, const RHS &Idx) { 854 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_VECTOR_ELT, Vec, Idx); 855 } 856 857 template <typename LHS, typename RHS> 858 inline BinaryOpc_match<LHS, RHS> m_ExtractSubvector(const LHS &Vec, 859 const RHS &Idx) { 860 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_SUBVECTOR, Vec, Idx); 861 } 862 863 // === Unary operations === 864 template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match { 865 unsigned Opcode; 866 Opnd_P Opnd; 867 std::optional<SDNodeFlags> Flags; 868 UnaryOpc_match(unsigned Opc, const Opnd_P &Op, 869 std::optional<SDNodeFlags> Flgs = std::nullopt) 870 : Opcode(Opc), Opnd(Op), Flags(Flgs) {} 871 872 template <typename MatchContext> 873 bool match(const MatchContext &Ctx, SDValue N) { 874 if (sd_context_match(N, Ctx, m_Opc(Opcode))) { 875 EffectiveOperands<ExcludeChain> EO(N, Ctx); 876 assert(EO.Size == 1); 877 if (!Opnd.match(Ctx, N->getOperand(EO.FirstIndex))) 878 return false; 879 if (!Flags.has_value()) 880 return true; 881 882 return (*Flags & N->getFlags()) == *Flags; 883 } 884 885 return false; 886 } 887 }; 888 889 template <typename Opnd> 890 inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) { 891 return UnaryOpc_match<Opnd>(Opc, Op); 892 } 893 template <typename Opnd> 894 inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc, 895 const Opnd &Op) { 896 return UnaryOpc_match<Opnd, true>(Opc, Op); 897 } 898 899 template <typename Opnd> inline UnaryOpc_match<Opnd> m_BitCast(const Opnd &Op) { 900 return UnaryOpc_match<Opnd>(ISD::BITCAST, Op); 901 } 902 903 template <typename Opnd> 904 inline UnaryOpc_match<Opnd> m_BSwap(const Opnd &Op) { 905 return UnaryOpc_match<Opnd>(ISD::BSWAP, Op); 906 } 907 908 template <typename Opnd> 909 inline UnaryOpc_match<Opnd> m_BitReverse(const Opnd &Op) { 910 return UnaryOpc_match<Opnd>(ISD::BITREVERSE, Op); 911 } 912 913 template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) { 914 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op); 915 } 916 917 template <typename Opnd> 918 inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) { 919 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, SDNodeFlags::NonNeg); 920 } 921 922 template <typename Opnd> inline auto m_SExt(const Opnd &Op) { 923 return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op); 924 } 925 926 template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) { 927 return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op); 928 } 929 930 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) { 931 return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op); 932 } 933 934 /// Match a zext or identity 935 /// Allows to peek through optional extensions 936 template <typename Opnd> inline auto m_ZExtOrSelf(const Opnd &Op) { 937 return m_AnyOf(m_ZExt(Op), Op); 938 } 939 940 /// Match a sext or identity 941 /// Allows to peek through optional extensions 942 template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) { 943 return m_AnyOf(m_SExt(Op), Op); 944 } 945 946 template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) { 947 return m_AnyOf(m_SExt(Op), m_NNegZExt(Op)); 948 } 949 950 /// Match a aext or identity 951 /// Allows to peek through optional extensions 952 template <typename Opnd> 953 inline Or<UnaryOpc_match<Opnd>, Opnd> m_AExtOrSelf(const Opnd &Op) { 954 return Or<UnaryOpc_match<Opnd>, Opnd>(m_AnyExt(Op), Op); 955 } 956 957 /// Match a trunc or identity 958 /// Allows to peek through optional truncations 959 template <typename Opnd> 960 inline Or<UnaryOpc_match<Opnd>, Opnd> m_TruncOrSelf(const Opnd &Op) { 961 return Or<UnaryOpc_match<Opnd>, Opnd>(m_Trunc(Op), Op); 962 } 963 964 template <typename Opnd> inline UnaryOpc_match<Opnd> m_VScale(const Opnd &Op) { 965 return UnaryOpc_match<Opnd>(ISD::VSCALE, Op); 966 } 967 968 template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToUI(const Opnd &Op) { 969 return UnaryOpc_match<Opnd>(ISD::FP_TO_UINT, Op); 970 } 971 972 template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToSI(const Opnd &Op) { 973 return UnaryOpc_match<Opnd>(ISD::FP_TO_SINT, Op); 974 } 975 976 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctpop(const Opnd &Op) { 977 return UnaryOpc_match<Opnd>(ISD::CTPOP, Op); 978 } 979 980 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctlz(const Opnd &Op) { 981 return UnaryOpc_match<Opnd>(ISD::CTLZ, Op); 982 } 983 984 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Cttz(const Opnd &Op) { 985 return UnaryOpc_match<Opnd>(ISD::CTTZ, Op); 986 } 987 988 // === Constants === 989 struct ConstantInt_match { 990 APInt *BindVal; 991 992 explicit ConstantInt_match(APInt *V) : BindVal(V) {} 993 994 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 995 // The logics here are similar to that in 996 // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also 997 // treats GlobalAddressSDNode as a constant, which is difficult to turn into 998 // APInt. 999 if (auto *C = dyn_cast_or_null<ConstantSDNode>(N.getNode())) { 1000 if (BindVal) 1001 *BindVal = C->getAPIntValue(); 1002 return true; 1003 } 1004 1005 APInt Discard; 1006 return ISD::isConstantSplatVector(N.getNode(), 1007 BindVal ? *BindVal : Discard); 1008 } 1009 }; 1010 /// Match any interger constants or splat of an integer constant. 1011 inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); } 1012 /// Match any interger constants or splat of an integer constant; return the 1013 /// specific constant or constant splat value. 1014 inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); } 1015 1016 struct SpecificInt_match { 1017 APInt IntVal; 1018 1019 explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {} 1020 1021 template <typename MatchContext> 1022 bool match(const MatchContext &Ctx, SDValue N) { 1023 APInt ConstInt; 1024 if (sd_context_match(N, Ctx, m_ConstInt(ConstInt))) 1025 return APInt::isSameValue(IntVal, ConstInt); 1026 return false; 1027 } 1028 }; 1029 1030 /// Match a specific integer constant or constant splat value. 1031 inline SpecificInt_match m_SpecificInt(APInt V) { 1032 return SpecificInt_match(std::move(V)); 1033 } 1034 inline SpecificInt_match m_SpecificInt(uint64_t V) { 1035 return SpecificInt_match(APInt(64, V)); 1036 } 1037 1038 inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); } 1039 inline SpecificInt_match m_One() { return m_SpecificInt(1U); } 1040 1041 struct AllOnes_match { 1042 1043 AllOnes_match() = default; 1044 1045 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 1046 return isAllOnesOrAllOnesSplat(N); 1047 } 1048 }; 1049 1050 inline AllOnes_match m_AllOnes() { return AllOnes_match(); } 1051 1052 /// Match true boolean value based on the information provided by 1053 /// TargetLowering. 1054 inline auto m_True() { 1055 return TLI_pred_match{ 1056 [](const TargetLowering &TLI, SDValue N) { 1057 APInt ConstVal; 1058 if (sd_match(N, m_ConstInt(ConstVal))) 1059 switch (TLI.getBooleanContents(N.getValueType())) { 1060 case TargetLowering::ZeroOrOneBooleanContent: 1061 return ConstVal.isOne(); 1062 case TargetLowering::ZeroOrNegativeOneBooleanContent: 1063 return ConstVal.isAllOnes(); 1064 case TargetLowering::UndefinedBooleanContent: 1065 return (ConstVal & 0x01) == 1; 1066 } 1067 1068 return false; 1069 }, 1070 m_Value()}; 1071 } 1072 /// Match false boolean value based on the information provided by 1073 /// TargetLowering. 1074 inline auto m_False() { 1075 return TLI_pred_match{ 1076 [](const TargetLowering &TLI, SDValue N) { 1077 APInt ConstVal; 1078 if (sd_match(N, m_ConstInt(ConstVal))) 1079 switch (TLI.getBooleanContents(N.getValueType())) { 1080 case TargetLowering::ZeroOrOneBooleanContent: 1081 case TargetLowering::ZeroOrNegativeOneBooleanContent: 1082 return ConstVal.isZero(); 1083 case TargetLowering::UndefinedBooleanContent: 1084 return (ConstVal & 0x01) == 0; 1085 } 1086 1087 return false; 1088 }, 1089 m_Value()}; 1090 } 1091 1092 struct CondCode_match { 1093 std::optional<ISD::CondCode> CCToMatch; 1094 ISD::CondCode *BindCC = nullptr; 1095 1096 explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {} 1097 1098 explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {} 1099 1100 template <typename MatchContext> bool match(const MatchContext &, SDValue N) { 1101 if (auto *CC = dyn_cast<CondCodeSDNode>(N.getNode())) { 1102 if (CCToMatch && *CCToMatch != CC->get()) 1103 return false; 1104 1105 if (BindCC) 1106 *BindCC = CC->get(); 1107 return true; 1108 } 1109 1110 return false; 1111 } 1112 }; 1113 1114 /// Match any conditional code SDNode. 1115 inline CondCode_match m_CondCode() { return CondCode_match(nullptr); } 1116 /// Match any conditional code SDNode and return its ISD::CondCode value. 1117 inline CondCode_match m_CondCode(ISD::CondCode &CC) { 1118 return CondCode_match(&CC); 1119 } 1120 /// Match a conditional code SDNode with a specific ISD::CondCode. 1121 inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) { 1122 return CondCode_match(CC); 1123 } 1124 1125 /// Match a negate as a sub(0, v) 1126 template <typename ValTy> 1127 inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) { 1128 return m_Sub(m_Zero(), V); 1129 } 1130 1131 /// Match a Not as a xor(v, -1) or xor(-1, v) 1132 template <typename ValTy> 1133 inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) { 1134 return m_Xor(V, m_AllOnes()); 1135 } 1136 1137 } // namespace SDPatternMatch 1138 } // namespace llvm 1139 #endif 1140