1 //===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a simple and efficient mechanism for performing general 10 // tree-based pattern matches on the VPlan values and recipes, based on 11 // LLVM's IR pattern matchers. 12 // 13 // Currently it provides generic matchers for unary and binary VPInstructions, 14 // and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond, 15 // m_BranchOnCount to match specific VPInstructions. 16 // TODO: Add missing matchers for additional opcodes and recipes as needed. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H 21 #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H 22 23 #include "VPlan.h" 24 25 namespace llvm { 26 namespace VPlanPatternMatch { 27 28 template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) { 29 return const_cast<Pattern &>(P).match(V); 30 } 31 32 template <typename Pattern> bool match(VPUser *U, const Pattern &P) { 33 auto *R = dyn_cast<VPRecipeBase>(U); 34 return R && match(R, P); 35 } 36 37 template <typename Class> struct class_match { 38 template <typename ITy> bool match(ITy *V) { return isa<Class>(V); } 39 }; 40 41 /// Match an arbitrary VPValue and ignore it. 42 inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); } 43 44 template <typename Class> struct bind_ty { 45 Class *&VR; 46 47 bind_ty(Class *&V) : VR(V) {} 48 49 template <typename ITy> bool match(ITy *V) { 50 if (auto *CV = dyn_cast<Class>(V)) { 51 VR = CV; 52 return true; 53 } 54 return false; 55 } 56 }; 57 58 /// Match a specified integer value or vector of all elements of that 59 /// value. \p BitWidth optionally specifies the bitwidth the matched constant 60 /// must have. If it is 0, the matched constant can have any bitwidth. 61 template <unsigned BitWidth = 0> struct specific_intval { 62 APInt Val; 63 64 specific_intval(APInt V) : Val(std::move(V)) {} 65 66 bool match(VPValue *VPV) { 67 if (!VPV->isLiveIn()) 68 return false; 69 Value *V = VPV->getLiveInIRValue(); 70 const auto *CI = dyn_cast<ConstantInt>(V); 71 if (!CI && V->getType()->isVectorTy()) 72 if (const auto *C = dyn_cast<Constant>(V)) 73 CI = dyn_cast_or_null<ConstantInt>( 74 C->getSplatValue(/*AllowPoison=*/false)); 75 if (!CI) 76 return false; 77 78 assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) && 79 "Trying the match constant with unexpected bitwidth."); 80 return APInt::isSameValue(CI->getValue(), Val); 81 } 82 }; 83 84 inline specific_intval<0> m_SpecificInt(uint64_t V) { 85 return specific_intval<0>(APInt(64, V)); 86 } 87 88 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); } 89 90 /// Matching combinators 91 template <typename LTy, typename RTy> struct match_combine_or { 92 LTy L; 93 RTy R; 94 95 match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} 96 97 template <typename ITy> bool match(ITy *V) { 98 if (L.match(V)) 99 return true; 100 if (R.match(V)) 101 return true; 102 return false; 103 } 104 }; 105 106 template <typename LTy, typename RTy> 107 inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) { 108 return match_combine_or<LTy, RTy>(L, R); 109 } 110 111 /// Match a VPValue, capturing it if we match. 112 inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; } 113 114 namespace detail { 115 116 /// A helper to match an opcode against multiple recipe types. 117 template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {}; 118 119 template <unsigned Opcode, typename RecipeTy> 120 struct MatchRecipeAndOpcode<Opcode, RecipeTy> { 121 static bool match(const VPRecipeBase *R) { 122 auto *DefR = dyn_cast<RecipeTy>(R); 123 return DefR && DefR->getOpcode() == Opcode; 124 } 125 }; 126 127 template <unsigned Opcode, typename RecipeTy, typename... RecipeTys> 128 struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> { 129 static bool match(const VPRecipeBase *R) { 130 return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) || 131 MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R); 132 } 133 }; 134 } // namespace detail 135 136 template <typename Op0_t, unsigned Opcode, typename... RecipeTys> 137 struct UnaryRecipe_match { 138 Op0_t Op0; 139 140 UnaryRecipe_match(Op0_t Op0) : Op0(Op0) {} 141 142 bool match(const VPValue *V) { 143 auto *DefR = V->getDefiningRecipe(); 144 return DefR && match(DefR); 145 } 146 147 bool match(const VPRecipeBase *R) { 148 if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R)) 149 return false; 150 assert(R->getNumOperands() == 1 && 151 "recipe with matched opcode does not have 1 operands"); 152 return Op0.match(R->getOperand(0)); 153 } 154 }; 155 156 template <typename Op0_t, unsigned Opcode> 157 using UnaryVPInstruction_match = 158 UnaryRecipe_match<Op0_t, Opcode, VPInstruction>; 159 160 template <typename Op0_t, unsigned Opcode> 161 using AllUnaryRecipe_match = 162 UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe, 163 VPWidenCastRecipe, VPInstruction>; 164 165 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative, 166 typename... RecipeTys> 167 struct BinaryRecipe_match { 168 Op0_t Op0; 169 Op1_t Op1; 170 171 BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} 172 173 bool match(const VPValue *V) { 174 auto *DefR = V->getDefiningRecipe(); 175 return DefR && match(DefR); 176 } 177 178 bool match(const VPSingleDefRecipe *R) { 179 return match(static_cast<const VPRecipeBase *>(R)); 180 } 181 182 bool match(const VPRecipeBase *R) { 183 if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R)) 184 return false; 185 assert(R->getNumOperands() == 2 && 186 "recipe with matched opcode does not have 2 operands"); 187 if (Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1))) 188 return true; 189 return Commutative && Op0.match(R->getOperand(1)) && 190 Op1.match(R->getOperand(0)); 191 } 192 }; 193 194 template <typename Op0_t, typename Op1_t, unsigned Opcode> 195 using BinaryVPInstruction_match = 196 BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false, 197 VPInstruction>; 198 199 template <typename Op0_t, typename Op1_t, unsigned Opcode, 200 bool Commutative = false> 201 using AllBinaryRecipe_match = 202 BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe, 203 VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>; 204 205 template <unsigned Opcode, typename Op0_t> 206 inline UnaryVPInstruction_match<Op0_t, Opcode> 207 m_VPInstruction(const Op0_t &Op0) { 208 return UnaryVPInstruction_match<Op0_t, Opcode>(Op0); 209 } 210 211 template <unsigned Opcode, typename Op0_t, typename Op1_t> 212 inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode> 213 m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) { 214 return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1); 215 } 216 217 template <typename Op0_t> 218 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not> 219 m_Not(const Op0_t &Op0) { 220 return m_VPInstruction<VPInstruction::Not>(Op0); 221 } 222 223 template <typename Op0_t> 224 inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond> 225 m_BranchOnCond(const Op0_t &Op0) { 226 return m_VPInstruction<VPInstruction::BranchOnCond>(Op0); 227 } 228 229 template <typename Op0_t, typename Op1_t> 230 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask> 231 m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) { 232 return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1); 233 } 234 235 template <typename Op0_t, typename Op1_t> 236 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount> 237 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { 238 return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1); 239 } 240 241 template <unsigned Opcode, typename Op0_t> 242 inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) { 243 return AllUnaryRecipe_match<Op0_t, Opcode>(Op0); 244 } 245 246 template <typename Op0_t> 247 inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc> 248 m_Trunc(const Op0_t &Op0) { 249 return m_Unary<Instruction::Trunc, Op0_t>(Op0); 250 } 251 252 template <typename Op0_t> 253 inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) { 254 return m_Unary<Instruction::ZExt, Op0_t>(Op0); 255 } 256 257 template <typename Op0_t> 258 inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) { 259 return m_Unary<Instruction::SExt, Op0_t>(Op0); 260 } 261 262 template <typename Op0_t> 263 inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>, 264 AllUnaryRecipe_match<Op0_t, Instruction::SExt>> 265 m_ZExtOrSExt(const Op0_t &Op0) { 266 return m_CombineOr(m_ZExt(Op0), m_SExt(Op0)); 267 } 268 269 template <unsigned Opcode, typename Op0_t, typename Op1_t, 270 bool Commutative = false> 271 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative> 272 m_Binary(const Op0_t &Op0, const Op1_t &Op1) { 273 return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1); 274 } 275 276 template <typename Op0_t, typename Op1_t> 277 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul> 278 m_Mul(const Op0_t &Op0, const Op1_t &Op1) { 279 return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1); 280 } 281 282 template <typename Op0_t, typename Op1_t> 283 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul, 284 /* Commutative =*/true> 285 m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) { 286 return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1); 287 } 288 289 /// Match a binary OR operation. Note that while conceptually the operands can 290 /// be matched commutatively, \p Commutative defaults to false in line with the 291 /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative 292 /// version of the matcher. 293 template <typename Op0_t, typename Op1_t, bool Commutative = false> 294 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative> 295 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { 296 return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1); 297 } 298 299 template <typename Op0_t, typename Op1_t> 300 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, 301 /*Commutative*/ true> 302 m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { 303 return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1); 304 } 305 306 template <typename Op0_t, typename Op1_t> 307 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd> 308 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) { 309 return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1); 310 } 311 312 struct VPCanonicalIVPHI_match { 313 bool match(const VPValue *V) { 314 auto *DefR = V->getDefiningRecipe(); 315 return DefR && match(DefR); 316 } 317 318 bool match(const VPRecipeBase *R) { return isa<VPCanonicalIVPHIRecipe>(R); } 319 }; 320 321 inline VPCanonicalIVPHI_match m_CanonicalIV() { 322 return VPCanonicalIVPHI_match(); 323 } 324 325 template <typename Op0_t, typename Op1_t> struct VPScalarIVSteps_match { 326 Op0_t Op0; 327 Op1_t Op1; 328 329 VPScalarIVSteps_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} 330 331 bool match(const VPValue *V) { 332 auto *DefR = V->getDefiningRecipe(); 333 return DefR && match(DefR); 334 } 335 336 bool match(const VPRecipeBase *R) { 337 if (!isa<VPScalarIVStepsRecipe>(R)) 338 return false; 339 assert(R->getNumOperands() == 2 && 340 "VPScalarIVSteps must have exactly 2 operands"); 341 return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1)); 342 } 343 }; 344 345 template <typename Op0_t, typename Op1_t> 346 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0, 347 const Op1_t &Op1) { 348 return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1); 349 } 350 351 } // namespace VPlanPatternMatch 352 } // namespace llvm 353 354 #endif 355