xref: /llvm-project/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (revision df4a615c988f3ae56f7e68a7df86acb60f16493a)
1 //===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matches on the VPlan values and recipes, based on
11 // LLVM's IR pattern matchers.
12 //
13 // Currently it provides generic matchers for unary and binary VPInstructions,
14 // and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond,
15 // m_BranchOnCount to match specific VPInstructions.
16 // TODO: Add missing matchers for additional opcodes and recipes as needed.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
21 #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
22 
23 #include "VPlan.h"
24 
25 namespace llvm {
26 namespace VPlanPatternMatch {
27 
/// Entry point for pattern matching: returns true if \p V satisfies the
/// pattern \p P. The decision is delegated entirely to the pattern's own
/// match() member.
template <typename Val, typename Pattern>
bool match(Val *V, const Pattern &P) {
  const bool Matched = P.match(V);
  return Matched;
}
31 
32 template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
33   auto *R = dyn_cast<VPRecipeBase>(U);
34   return R && match(R, P);
35 }
36 
37 template <typename Class> struct class_match {
38   template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
39 };
40 
41 /// Match an arbitrary VPValue and ignore it.
42 inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); }
43 
44 template <typename Class> struct bind_ty {
45   Class *&VR;
46 
47   bind_ty(Class *&V) : VR(V) {}
48 
49   template <typename ITy> bool match(ITy *V) const {
50     if (auto *CV = dyn_cast<Class>(V)) {
51       VR = CV;
52       return true;
53     }
54     return false;
55   }
56 };
57 
58 /// Match a specified VPValue.
59 struct specificval_ty {
60   const VPValue *Val;
61 
62   specificval_ty(const VPValue *V) : Val(V) {}
63 
64   bool match(VPValue *VPV) const { return VPV == Val; }
65 };
66 
67 inline specificval_ty m_Specific(const VPValue *VPV) { return VPV; }
68 
69 /// Match a specified integer value or vector of all elements of that
70 /// value. \p BitWidth optionally specifies the bitwidth the matched constant
71 /// must have. If it is 0, the matched constant can have any bitwidth.
72 template <unsigned BitWidth = 0> struct specific_intval {
73   APInt Val;
74 
75   specific_intval(APInt V) : Val(std::move(V)) {}
76 
77   bool match(VPValue *VPV) const {
78     if (!VPV->isLiveIn())
79       return false;
80     Value *V = VPV->getLiveInIRValue();
81     if (!V)
82       return false;
83     const auto *CI = dyn_cast<ConstantInt>(V);
84     if (!CI && V->getType()->isVectorTy())
85       if (const auto *C = dyn_cast<Constant>(V))
86         CI = dyn_cast_or_null<ConstantInt>(
87             C->getSplatValue(/*AllowPoison=*/false));
88     if (!CI)
89       return false;
90 
91     if (BitWidth != 0 && CI->getBitWidth() != BitWidth)
92       return false;
93     return APInt::isSameValue(CI->getValue(), Val);
94   }
95 };
96 
97 inline specific_intval<0> m_SpecificInt(uint64_t V) {
98   return specific_intval<0>(APInt(64, V));
99 }
100 
101 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
102 
103 inline specific_intval<1> m_True() { return specific_intval<1>(APInt(64, 1)); }
104 
/// Matching combinator: succeeds if either of its two sub-patterns matches.
/// The left pattern is tried first; the right pattern is only tried when the
/// left one fails.
template <typename LTy, typename RTy> struct match_combine_or {
  LTy L;
  RTy R;

  match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}

  template <typename ITy> bool match(ITy *V) const {
    return L.match(V) || R.match(V);
  }
};
120 
121 template <typename LTy, typename RTy>
122 inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) {
123   return match_combine_or<LTy, RTy>(L, R);
124 }
125 
126 /// Match a VPValue, capturing it if we match.
127 inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
128 
129 namespace detail {
130 
131 /// A helper to match an opcode against multiple recipe types.
132 template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {};
133 
134 template <unsigned Opcode, typename RecipeTy>
135 struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
136   static bool match(const VPRecipeBase *R) {
137     auto *DefR = dyn_cast<RecipeTy>(R);
138     // Check for recipes that do not have opcodes.
139     if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value ||
140                   std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value ||
141                   std::is_same<RecipeTy, VPWidenSelectRecipe>::value ||
142                   std::is_same<RecipeTy, VPDerivedIVRecipe>::value ||
143                   std::is_same<RecipeTy, VPWidenGEPRecipe>::value)
144       return DefR;
145     else
146       return DefR && DefR->getOpcode() == Opcode;
147   }
148 };
149 
150 template <unsigned Opcode, typename RecipeTy, typename... RecipeTys>
151 struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> {
152   static bool match(const VPRecipeBase *R) {
153     return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) ||
154            MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R);
155   }
156 };
/// Invoke \p P on each element of \p Ops together with its index, in index
/// order, stopping at the first call that returns false. Returns true if no
/// call returned false (in particular for an empty index sequence).
template <typename TupleTy, typename Fn, std::size_t... Is>
bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
  return (... && P(std::get<Is>(Ops), Is));
}

/// Helper to check if predicate \p P holds on all tuple elements in \p Ops.
/// \p P is called with the element and its index within the tuple.
template <typename TupleTy, typename Fn>
bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
  constexpr std::size_t NumOps = std::tuple_size<TupleTy>::value;
  using Indices = std::make_index_sequence<NumOps>;
  return CheckTupleElements(Ops, P, Indices{});
}
168 } // namespace detail
169 
170 template <typename Ops_t, unsigned Opcode, bool Commutative,
171           typename... RecipeTys>
172 struct Recipe_match {
173   Ops_t Ops;
174 
175   Recipe_match() : Ops() {
176     static_assert(std::tuple_size<Ops_t>::value == 0 &&
177                   "constructor can only be used with zero operands");
178   }
179   Recipe_match(Ops_t Ops) : Ops(Ops) {}
180   template <typename A_t, typename B_t>
181   Recipe_match(A_t A, B_t B) : Ops({A, B}) {
182     static_assert(std::tuple_size<Ops_t>::value == 2 &&
183                   "constructor can only be used for binary matcher");
184   }
185 
186   bool match(const VPValue *V) const {
187     auto *DefR = V->getDefiningRecipe();
188     return DefR && match(DefR);
189   }
190 
191   bool match(const VPSingleDefRecipe *R) const {
192     return match(static_cast<const VPRecipeBase *>(R));
193   }
194 
195   bool match(const VPRecipeBase *R) const {
196     if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
197       return false;
198     assert(R->getNumOperands() == std::tuple_size<Ops_t>::value &&
199            "recipe with matched opcode the expected number of operands");
200 
201     if (detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) {
202           return Op.match(R->getOperand(Idx));
203         }))
204       return true;
205 
206     return Commutative &&
207            detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) {
208              return Op.match(R->getOperand(R->getNumOperands() - Idx - 1));
209            });
210   }
211 };
212 
213 template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
214 using UnaryRecipe_match =
215     Recipe_match<std::tuple<Op0_t>, Opcode, false, RecipeTys...>;
216 
217 template <typename Op0_t, unsigned Opcode>
218 using UnaryVPInstruction_match =
219     UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;
220 
221 template <typename Op0_t, unsigned Opcode>
222 using AllUnaryRecipe_match =
223     UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
224                       VPWidenCastRecipe, VPInstruction>;
225 
226 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative,
227           typename... RecipeTys>
228 using BinaryRecipe_match =
229     Recipe_match<std::tuple<Op0_t, Op1_t>, Opcode, Commutative, RecipeTys...>;
230 
231 template <typename Op0_t, typename Op1_t, unsigned Opcode>
232 using BinaryVPInstruction_match =
233     BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
234                        VPInstruction>;
235 
236 template <typename Op0_t, typename Op1_t, unsigned Opcode,
237           bool Commutative = false>
238 using AllBinaryRecipe_match =
239     BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe,
240                        VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>;
241 
242 template <unsigned Opcode, typename Op0_t>
243 inline UnaryVPInstruction_match<Op0_t, Opcode>
244 m_VPInstruction(const Op0_t &Op0) {
245   return UnaryVPInstruction_match<Op0_t, Opcode>(Op0);
246 }
247 
248 template <unsigned Opcode, typename Op0_t, typename Op1_t>
249 inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>
250 m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
251   return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
252 }
253 
254 template <typename Op0_t>
255 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
256 m_Not(const Op0_t &Op0) {
257   return m_VPInstruction<VPInstruction::Not>(Op0);
258 }
259 
260 template <typename Op0_t>
261 inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond>
262 m_BranchOnCond(const Op0_t &Op0) {
263   return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
264 }
265 
266 template <typename Op0_t, typename Op1_t>
267 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask>
268 m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) {
269   return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1);
270 }
271 
272 template <typename Op0_t, typename Op1_t>
273 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
274 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
275   return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
276 }
277 
278 template <unsigned Opcode, typename Op0_t>
279 inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) {
280   return AllUnaryRecipe_match<Op0_t, Opcode>(Op0);
281 }
282 
283 template <typename Op0_t>
284 inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc>
285 m_Trunc(const Op0_t &Op0) {
286   return m_Unary<Instruction::Trunc, Op0_t>(Op0);
287 }
288 
289 template <typename Op0_t>
290 inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) {
291   return m_Unary<Instruction::ZExt, Op0_t>(Op0);
292 }
293 
294 template <typename Op0_t>
295 inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) {
296   return m_Unary<Instruction::SExt, Op0_t>(Op0);
297 }
298 
299 template <typename Op0_t>
300 inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>,
301                         AllUnaryRecipe_match<Op0_t, Instruction::SExt>>
302 m_ZExtOrSExt(const Op0_t &Op0) {
303   return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
304 }
305 
306 template <unsigned Opcode, typename Op0_t, typename Op1_t,
307           bool Commutative = false>
308 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>
309 m_Binary(const Op0_t &Op0, const Op1_t &Op1) {
310   return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1);
311 }
312 
313 template <unsigned Opcode, typename Op0_t, typename Op1_t>
314 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, true>
315 m_c_Binary(const Op0_t &Op0, const Op1_t &Op1) {
316   return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, true>(Op0, Op1);
317 }
318 
319 template <typename Op0_t, typename Op1_t>
320 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul>
321 m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
322   return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
323 }
324 
325 template <typename Op0_t, typename Op1_t>
326 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul,
327                              /* Commutative =*/true>
328 m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
329   return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1);
330 }
331 
332 /// Match a binary OR operation. Note that while conceptually the operands can
333 /// be matched commutatively, \p Commutative defaults to false in line with the
334 /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative
335 /// version of the matcher.
336 template <typename Op0_t, typename Op1_t, bool Commutative = false>
337 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative>
338 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
339   return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1);
340 }
341 
342 template <typename Op0_t, typename Op1_t>
343 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or,
344                              /*Commutative*/ true>
345 m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
346   return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
347 }
348 
349 template <typename Op0_t, typename Op1_t>
350 using GEPLikeRecipe_match =
351     BinaryRecipe_match<Op0_t, Op1_t, Instruction::GetElementPtr, false,
352                        VPWidenRecipe, VPReplicateRecipe, VPWidenGEPRecipe,
353                        VPInstruction>;
354 
355 template <typename Op0_t, typename Op1_t>
356 inline GEPLikeRecipe_match<Op0_t, Op1_t> m_GetElementPtr(const Op0_t &Op0,
357                                                          const Op1_t &Op1) {
358   return GEPLikeRecipe_match<Op0_t, Op1_t>(Op0, Op1);
359 }
360 
361 template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
362 using AllTernaryRecipe_match =
363     Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, Opcode, false,
364                  VPReplicateRecipe, VPInstruction, VPWidenSelectRecipe>;
365 
366 template <typename Op0_t, typename Op1_t, typename Op2_t>
367 inline AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>
368 m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
369   return AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>(
370       {Op0, Op1, Op2});
371 }
372 
373 template <typename Op0_t, typename Op1_t>
374 inline match_combine_or<
375     BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>,
376     AllTernaryRecipe_match<Op0_t, Op1_t, specific_intval<1>,
377                            Instruction::Select>>
378 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
379   return m_CombineOr(
380       m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1),
381       m_Select(Op0, Op1, m_False()));
382 }
383 
384 template <typename Op0_t, typename Op1_t>
385 inline AllTernaryRecipe_match<Op0_t, specific_intval<1>, Op1_t,
386                               Instruction::Select>
387 m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
388   return m_Select(Op0, m_True(), Op1);
389 }
390 
391 using VPCanonicalIVPHI_match =
392     Recipe_match<std::tuple<>, 0, false, VPCanonicalIVPHIRecipe>;
393 
394 inline VPCanonicalIVPHI_match m_CanonicalIV() {
395   return VPCanonicalIVPHI_match();
396 }
397 
398 template <typename Op0_t, typename Op1_t>
399 using VPScalarIVSteps_match =
400     Recipe_match<std::tuple<Op0_t, Op1_t>, 0, false, VPScalarIVStepsRecipe>;
401 
402 template <typename Op0_t, typename Op1_t>
403 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
404                                                            const Op1_t &Op1) {
405   return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
406 }
407 
408 template <typename Op0_t, typename Op1_t, typename Op2_t>
409 using VPDerivedIV_match =
410     Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, 0, false, VPDerivedIVRecipe>;
411 
412 template <typename Op0_t, typename Op1_t, typename Op2_t>
413 inline VPDerivedIV_match<Op0_t, Op1_t, Op2_t>
414 m_DerivedIV(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
415   return VPDerivedIV_match<Op0_t, Op1_t, Op2_t>({Op0, Op1, Op2});
416 }
417 
418 } // namespace VPlanPatternMatch
419 } // namespace llvm
420 
421 #endif
422