1 //===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a simple and efficient mechanism for performing general 10 // tree-based pattern matches on the VPlan values and recipes, based on 11 // LLVM's IR pattern matchers. 12 // 13 // Currently it provides generic matchers for unary and binary VPInstructions, 14 // and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond, 15 // m_BranchOnCount to match specific VPInstructions. 16 // TODO: Add missing matchers for additional opcodes and recipes as needed. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H 21 #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H 22 23 #include "VPlan.h" 24 25 namespace llvm { 26 namespace VPlanPatternMatch { 27 28 template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) { 29 return P.match(V); 30 } 31 32 template <typename Pattern> bool match(VPUser *U, const Pattern &P) { 33 auto *R = dyn_cast<VPRecipeBase>(U); 34 return R && match(R, P); 35 } 36 37 template <typename Class> struct class_match { 38 template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); } 39 }; 40 41 /// Match an arbitrary VPValue and ignore it. 42 inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); } 43 44 template <typename Class> struct bind_ty { 45 Class *&VR; 46 47 bind_ty(Class *&V) : VR(V) {} 48 49 template <typename ITy> bool match(ITy *V) const { 50 if (auto *CV = dyn_cast<Class>(V)) { 51 VR = CV; 52 return true; 53 } 54 return false; 55 } 56 }; 57 58 /// Match a specified VPValue. 59 struct specificval_ty { 60 const VPValue *Val; 61 62 specificval_ty(const VPValue *V) : Val(V) {} 63 64 bool match(VPValue *VPV) const { return VPV == Val; } 65 }; 66 67 inline specificval_ty m_Specific(const VPValue *VPV) { return VPV; } 68 69 /// Match a specified integer value or vector of all elements of that 70 /// value. \p BitWidth optionally specifies the bitwidth the matched constant 71 /// must have. If it is 0, the matched constant can have any bitwidth. 72 template <unsigned BitWidth = 0> struct specific_intval { 73 APInt Val; 74 75 specific_intval(APInt V) : Val(std::move(V)) {} 76 77 bool match(VPValue *VPV) const { 78 if (!VPV->isLiveIn()) 79 return false; 80 Value *V = VPV->getLiveInIRValue(); 81 if (!V) 82 return false; 83 const auto *CI = dyn_cast<ConstantInt>(V); 84 if (!CI && V->getType()->isVectorTy()) 85 if (const auto *C = dyn_cast<Constant>(V)) 86 CI = dyn_cast_or_null<ConstantInt>( 87 C->getSplatValue(/*AllowPoison=*/false)); 88 if (!CI) 89 return false; 90 91 if (BitWidth != 0 && CI->getBitWidth() != BitWidth) 92 return false; 93 return APInt::isSameValue(CI->getValue(), Val); 94 } 95 }; 96 97 inline specific_intval<0> m_SpecificInt(uint64_t V) { 98 return specific_intval<0>(APInt(64, V)); 99 } 100 101 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); } 102 103 inline specific_intval<1> m_True() { return specific_intval<1>(APInt(64, 1)); } 104 105 /// Matching combinators 106 template <typename LTy, typename RTy> struct match_combine_or { 107 LTy L; 108 RTy R; 109 110 match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} 111 112 template <typename ITy> bool match(ITy *V) const { 113 if (L.match(V)) 114 return true; 115 if (R.match(V)) 116 return true; 117 return false; 118 } 119 }; 120 121 template <typename LTy, typename RTy> 122 inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) { 123 return match_combine_or<LTy, RTy>(L, R); 124 } 125 126 /// Match a VPValue, capturing it if we match. 127 inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; } 128 129 namespace detail { 130 131 /// A helper to match an opcode against multiple recipe types. 132 template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {}; 133 134 template <unsigned Opcode, typename RecipeTy> 135 struct MatchRecipeAndOpcode<Opcode, RecipeTy> { 136 static bool match(const VPRecipeBase *R) { 137 auto *DefR = dyn_cast<RecipeTy>(R); 138 // Check for recipes that do not have opcodes. 139 if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value || 140 std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value || 141 std::is_same<RecipeTy, VPWidenSelectRecipe>::value || 142 std::is_same<RecipeTy, VPDerivedIVRecipe>::value || 143 std::is_same<RecipeTy, VPWidenGEPRecipe>::value) 144 return DefR; 145 else 146 return DefR && DefR->getOpcode() == Opcode; 147 } 148 }; 149 150 template <unsigned Opcode, typename RecipeTy, typename... RecipeTys> 151 struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> { 152 static bool match(const VPRecipeBase *R) { 153 return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) || 154 MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R); 155 } 156 }; 157 template <typename TupleTy, typename Fn, std::size_t... Is> 158 bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) { 159 return (P(std::get<Is>(Ops), Is) && ...); 160 } 161 162 /// Helper to check if predicate \p P holds on all tuple elements in \p Ops 163 template <typename TupleTy, typename Fn> 164 bool all_of_tuple_elements(const TupleTy &Ops, Fn P) { 165 return CheckTupleElements( 166 Ops, P, std::make_index_sequence<std::tuple_size<TupleTy>::value>{}); 167 } 168 } // namespace detail 169 170 template <typename Ops_t, unsigned Opcode, bool Commutative, 171 typename... RecipeTys> 172 struct Recipe_match { 173 Ops_t Ops; 174 175 Recipe_match() : Ops() { 176 static_assert(std::tuple_size<Ops_t>::value == 0 && 177 "constructor can only be used with zero operands"); 178 } 179 Recipe_match(Ops_t Ops) : Ops(Ops) {} 180 template <typename A_t, typename B_t> 181 Recipe_match(A_t A, B_t B) : Ops({A, B}) { 182 static_assert(std::tuple_size<Ops_t>::value == 2 && 183 "constructor can only be used for binary matcher"); 184 } 185 186 bool match(const VPValue *V) const { 187 auto *DefR = V->getDefiningRecipe(); 188 return DefR && match(DefR); 189 } 190 191 bool match(const VPSingleDefRecipe *R) const { 192 return match(static_cast<const VPRecipeBase *>(R)); 193 } 194 195 bool match(const VPRecipeBase *R) const { 196 if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R)) 197 return false; 198 assert(R->getNumOperands() == std::tuple_size<Ops_t>::value && 199 "recipe with matched opcode the expected number of operands"); 200 201 if (detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) { 202 return Op.match(R->getOperand(Idx)); 203 })) 204 return true; 205 206 return Commutative && 207 detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) { 208 return Op.match(R->getOperand(R->getNumOperands() - Idx - 1)); 209 }); 210 } 211 }; 212 213 template <typename Op0_t, unsigned Opcode, typename... RecipeTys> 214 using UnaryRecipe_match = 215 Recipe_match<std::tuple<Op0_t>, Opcode, false, RecipeTys...>; 216 217 template <typename Op0_t, unsigned Opcode> 218 using UnaryVPInstruction_match = 219 UnaryRecipe_match<Op0_t, Opcode, VPInstruction>; 220 221 template <typename Op0_t, unsigned Opcode> 222 using AllUnaryRecipe_match = 223 UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe, 224 VPWidenCastRecipe, VPInstruction>; 225 226 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative, 227 typename... RecipeTys> 228 using BinaryRecipe_match = 229 Recipe_match<std::tuple<Op0_t, Op1_t>, Opcode, Commutative, RecipeTys...>; 230 231 template <typename Op0_t, typename Op1_t, unsigned Opcode> 232 using BinaryVPInstruction_match = 233 BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false, 234 VPInstruction>; 235 236 template <typename Op0_t, typename Op1_t, unsigned Opcode, 237 bool Commutative = false> 238 using AllBinaryRecipe_match = 239 BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe, 240 VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>; 241 242 template <unsigned Opcode, typename Op0_t> 243 inline UnaryVPInstruction_match<Op0_t, Opcode> 244 m_VPInstruction(const Op0_t &Op0) { 245 return UnaryVPInstruction_match<Op0_t, Opcode>(Op0); 246 } 247 248 template <unsigned Opcode, typename Op0_t, typename Op1_t> 249 inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode> 250 m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) { 251 return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1); 252 } 253 254 template <typename Op0_t> 255 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not> 256 m_Not(const Op0_t &Op0) { 257 return m_VPInstruction<VPInstruction::Not>(Op0); 258 } 259 260 template <typename Op0_t> 261 inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond> 262 m_BranchOnCond(const Op0_t &Op0) { 263 return m_VPInstruction<VPInstruction::BranchOnCond>(Op0); 264 } 265 266 template <typename Op0_t, typename Op1_t> 267 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask> 268 m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) { 269 return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1); 270 } 271 272 template <typename Op0_t, typename Op1_t> 273 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount> 274 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { 275 return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1); 276 } 277 278 template <unsigned Opcode, typename Op0_t> 279 inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) { 280 return AllUnaryRecipe_match<Op0_t, Opcode>(Op0); 281 } 282 283 template <typename Op0_t> 284 inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc> 285 m_Trunc(const Op0_t &Op0) { 286 return m_Unary<Instruction::Trunc, Op0_t>(Op0); 287 } 288 289 template <typename Op0_t> 290 inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) { 291 return m_Unary<Instruction::ZExt, Op0_t>(Op0); 292 } 293 294 template <typename Op0_t> 295 inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) { 296 return m_Unary<Instruction::SExt, Op0_t>(Op0); 297 } 298 299 template <typename Op0_t> 300 inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>, 301 AllUnaryRecipe_match<Op0_t, Instruction::SExt>> 302 m_ZExtOrSExt(const Op0_t &Op0) { 303 return m_CombineOr(m_ZExt(Op0), m_SExt(Op0)); 304 } 305 306 template <unsigned Opcode, typename Op0_t, typename Op1_t, 307 bool Commutative = false> 308 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative> 309 m_Binary(const Op0_t &Op0, const Op1_t &Op1) { 310 return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1); 311 } 312 313 template <unsigned Opcode, typename Op0_t, typename Op1_t> 314 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, true> 315 m_c_Binary(const Op0_t &Op0, const Op1_t &Op1) { 316 return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, true>(Op0, Op1); 317 } 318 319 template <typename Op0_t, typename Op1_t> 320 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul> 321 m_Mul(const Op0_t &Op0, const Op1_t &Op1) { 322 return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1); 323 } 324 325 template <typename Op0_t, typename Op1_t> 326 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul, 327 /* Commutative =*/true> 328 m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) { 329 return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1); 330 } 331 332 /// Match a binary OR operation. Note that while conceptually the operands can 333 /// be matched commutatively, \p Commutative defaults to false in line with the 334 /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative 335 /// version of the matcher. 336 template <typename Op0_t, typename Op1_t, bool Commutative = false> 337 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative> 338 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { 339 return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1); 340 } 341 342 template <typename Op0_t, typename Op1_t> 343 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, 344 /*Commutative*/ true> 345 m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { 346 return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1); 347 } 348 349 template <typename Op0_t, typename Op1_t> 350 using GEPLikeRecipe_match = 351 BinaryRecipe_match<Op0_t, Op1_t, Instruction::GetElementPtr, false, 352 VPWidenRecipe, VPReplicateRecipe, VPWidenGEPRecipe, 353 VPInstruction>; 354 355 template <typename Op0_t, typename Op1_t> 356 inline GEPLikeRecipe_match<Op0_t, Op1_t> m_GetElementPtr(const Op0_t &Op0, 357 const Op1_t &Op1) { 358 return GEPLikeRecipe_match<Op0_t, Op1_t>(Op0, Op1); 359 } 360 361 template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode> 362 using AllTernaryRecipe_match = 363 Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, Opcode, false, 364 VPReplicateRecipe, VPInstruction, VPWidenSelectRecipe>; 365 366 template <typename Op0_t, typename Op1_t, typename Op2_t> 367 inline AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select> 368 m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { 369 return AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>( 370 {Op0, Op1, Op2}); 371 } 372 373 template <typename Op0_t, typename Op1_t> 374 inline match_combine_or< 375 BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>, 376 AllTernaryRecipe_match<Op0_t, Op1_t, specific_intval<1>, 377 Instruction::Select>> 378 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) { 379 return m_CombineOr( 380 m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1), 381 m_Select(Op0, Op1, m_False())); 382 } 383 384 template <typename Op0_t, typename Op1_t> 385 inline AllTernaryRecipe_match<Op0_t, specific_intval<1>, Op1_t, 386 Instruction::Select> 387 m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) { 388 return m_Select(Op0, m_True(), Op1); 389 } 390 391 using VPCanonicalIVPHI_match = 392 Recipe_match<std::tuple<>, 0, false, VPCanonicalIVPHIRecipe>; 393 394 inline VPCanonicalIVPHI_match m_CanonicalIV() { 395 return VPCanonicalIVPHI_match(); 396 } 397 398 template <typename Op0_t, typename Op1_t> 399 using VPScalarIVSteps_match = 400 Recipe_match<std::tuple<Op0_t, Op1_t>, 0, false, VPScalarIVStepsRecipe>; 401 402 template <typename Op0_t, typename Op1_t> 403 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0, 404 const Op1_t &Op1) { 405 return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1); 406 } 407 408 template <typename Op0_t, typename Op1_t, typename Op2_t> 409 using VPDerivedIV_match = 410 Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, 0, false, VPDerivedIVRecipe>; 411 412 template <typename Op0_t, typename Op1_t, typename Op2_t> 413 inline VPDerivedIV_match<Op0_t, Op1_t, Op2_t> 414 m_DerivedIV(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { 415 return VPDerivedIV_match<Op0_t, Op1_t, Op2_t>({Op0, Op1, Op2}); 416 } 417 418 } // namespace VPlanPatternMatch 419 } // namespace llvm 420 421 #endif 422