xref: /llvm-project/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (revision 99741ac28522f519713907d7bea4438ea5412e21)
1 //===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matches on the VPlan values and recipes, based on
11 // LLVM's IR pattern matchers.
12 //
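// Currently it provides generic matchers for unary and binary VPInstructions,
// matchers for opcodes shared by several recipe kinds (e.g. m_Trunc, m_ZExt,
// m_Mul, m_BinaryOr), specialized matchers like m_Not, m_ActiveLaneMask,
// m_BranchOnCond and m_BranchOnCount for specific VPInstructions, matchers for
// live-in integer constants (m_SpecificInt, m_False), and matchers for the
// canonical IV and scalar IV steps recipes.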
16 // TODO: Add missing matchers for additional opcodes and recipes as needed.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
21 #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
22 
23 #include "VPlan.h"
24 
25 namespace llvm {
26 namespace VPlanPatternMatch {
27 
/// Match the value \p V against pattern \p P.
///
/// Patterns are typically passed as temporaries, but their match() members
/// are non-const (binding patterns write through a captured reference), so
/// constness is stripped before dispatching.
template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
  Pattern &MutablePattern = const_cast<Pattern &>(P);
  return MutablePattern.match(V);
}
31 
32 template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
33   auto *R = dyn_cast<VPRecipeBase>(U);
34   return R && match(R, P);
35 }
36 
37 template <typename Class> struct class_match {
38   template <typename ITy> bool match(ITy *V) { return isa<Class>(V); }
39 };
40 
41 /// Match an arbitrary VPValue and ignore it.
42 inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); }
43 
44 template <typename Class> struct bind_ty {
45   Class *&VR;
46 
47   bind_ty(Class *&V) : VR(V) {}
48 
49   template <typename ITy> bool match(ITy *V) {
50     if (auto *CV = dyn_cast<Class>(V)) {
51       VR = CV;
52       return true;
53     }
54     return false;
55   }
56 };
57 
58 /// Match a specified integer value or vector of all elements of that
59 /// value. \p BitWidth optionally specifies the bitwidth the matched constant
60 /// must have. If it is 0, the matched constant can have any bitwidth.
61 template <unsigned BitWidth = 0> struct specific_intval {
62   APInt Val;
63 
64   specific_intval(APInt V) : Val(std::move(V)) {}
65 
66   bool match(VPValue *VPV) {
67     if (!VPV->isLiveIn())
68       return false;
69     Value *V = VPV->getLiveInIRValue();
70     const auto *CI = dyn_cast<ConstantInt>(V);
71     if (!CI && V->getType()->isVectorTy())
72       if (const auto *C = dyn_cast<Constant>(V))
73         CI = dyn_cast_or_null<ConstantInt>(
74             C->getSplatValue(/*AllowPoison=*/false));
75     if (!CI)
76       return false;
77 
78     assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) &&
79            "Trying the match constant with unexpected bitwidth.");
80     return APInt::isSameValue(CI->getValue(), Val);
81   }
82 };
83 
84 inline specific_intval<0> m_SpecificInt(uint64_t V) {
85   return specific_intval<0>(APInt(64, V));
86 }
87 
88 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
89 
/// Matching combinators

/// Pattern that succeeds if either sub-pattern matches. \p L is tried first;
/// \p R is only consulted when \p L fails.
template <typename LTy, typename RTy> struct match_combine_or {
  LTy L;
  RTy R;

  match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}

  template <typename ITy> bool match(ITy *V) {
    return L.match(V) || R.match(V);
  }
};

/// Combine two patterns so that matching either one succeeds.
template <typename LTy, typename RTy>
inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) {
  using CombinedTy = match_combine_or<LTy, RTy>;
  return CombinedTy(L, R);
}
110 
111 /// Match a VPValue, capturing it if we match.
112 inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
113 
114 namespace detail {
115 
116 /// A helper to match an opcode against multiple recipe types.
117 template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {};
118 
119 template <unsigned Opcode, typename RecipeTy>
120 struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
121   static bool match(const VPRecipeBase *R) {
122     auto *DefR = dyn_cast<RecipeTy>(R);
123     return DefR && DefR->getOpcode() == Opcode;
124   }
125 };
126 
127 template <unsigned Opcode, typename RecipeTy, typename... RecipeTys>
128 struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> {
129   static bool match(const VPRecipeBase *R) {
130     return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) ||
131            MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R);
132   }
133 };
134 } // namespace detail
135 
136 template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
137 struct UnaryRecipe_match {
138   Op0_t Op0;
139 
140   UnaryRecipe_match(Op0_t Op0) : Op0(Op0) {}
141 
142   bool match(const VPValue *V) {
143     auto *DefR = V->getDefiningRecipe();
144     return DefR && match(DefR);
145   }
146 
147   bool match(const VPRecipeBase *R) {
148     if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
149       return false;
150     assert(R->getNumOperands() == 1 &&
151            "recipe with matched opcode does not have 1 operands");
152     return Op0.match(R->getOperand(0));
153   }
154 };
155 
156 template <typename Op0_t, unsigned Opcode>
157 using UnaryVPInstruction_match =
158     UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;
159 
160 template <typename Op0_t, unsigned Opcode>
161 using AllUnaryRecipe_match =
162     UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
163                       VPWidenCastRecipe, VPInstruction>;
164 
165 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative,
166           typename... RecipeTys>
167 struct BinaryRecipe_match {
168   Op0_t Op0;
169   Op1_t Op1;
170 
171   BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
172 
173   bool match(const VPValue *V) {
174     auto *DefR = V->getDefiningRecipe();
175     return DefR && match(DefR);
176   }
177 
178   bool match(const VPSingleDefRecipe *R) {
179     return match(static_cast<const VPRecipeBase *>(R));
180   }
181 
182   bool match(const VPRecipeBase *R) {
183     if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
184       return false;
185     assert(R->getNumOperands() == 2 &&
186            "recipe with matched opcode does not have 2 operands");
187     if (Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1)))
188       return true;
189     return Commutative && Op0.match(R->getOperand(1)) &&
190            Op1.match(R->getOperand(0));
191   }
192 };
193 
194 template <typename Op0_t, typename Op1_t, unsigned Opcode>
195 using BinaryVPInstruction_match =
196     BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
197                        VPInstruction>;
198 
199 template <typename Op0_t, typename Op1_t, unsigned Opcode,
200           bool Commutative = false>
201 using AllBinaryRecipe_match =
202     BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe,
203                        VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>;
204 
205 template <unsigned Opcode, typename Op0_t>
206 inline UnaryVPInstruction_match<Op0_t, Opcode>
207 m_VPInstruction(const Op0_t &Op0) {
208   return UnaryVPInstruction_match<Op0_t, Opcode>(Op0);
209 }
210 
211 template <unsigned Opcode, typename Op0_t, typename Op1_t>
212 inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>
213 m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
214   return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
215 }
216 
217 template <typename Op0_t>
218 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
219 m_Not(const Op0_t &Op0) {
220   return m_VPInstruction<VPInstruction::Not>(Op0);
221 }
222 
223 template <typename Op0_t>
224 inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond>
225 m_BranchOnCond(const Op0_t &Op0) {
226   return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
227 }
228 
229 template <typename Op0_t, typename Op1_t>
230 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask>
231 m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) {
232   return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1);
233 }
234 
235 template <typename Op0_t, typename Op1_t>
236 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
237 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
238   return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
239 }
240 
241 template <unsigned Opcode, typename Op0_t>
242 inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) {
243   return AllUnaryRecipe_match<Op0_t, Opcode>(Op0);
244 }
245 
246 template <typename Op0_t>
247 inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc>
248 m_Trunc(const Op0_t &Op0) {
249   return m_Unary<Instruction::Trunc, Op0_t>(Op0);
250 }
251 
252 template <typename Op0_t>
253 inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) {
254   return m_Unary<Instruction::ZExt, Op0_t>(Op0);
255 }
256 
257 template <typename Op0_t>
258 inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) {
259   return m_Unary<Instruction::SExt, Op0_t>(Op0);
260 }
261 
262 template <typename Op0_t>
263 inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>,
264                         AllUnaryRecipe_match<Op0_t, Instruction::SExt>>
265 m_ZExtOrSExt(const Op0_t &Op0) {
266   return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
267 }
268 
269 template <unsigned Opcode, typename Op0_t, typename Op1_t,
270           bool Commutative = false>
271 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>
272 m_Binary(const Op0_t &Op0, const Op1_t &Op1) {
273   return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1);
274 }
275 
276 template <typename Op0_t, typename Op1_t>
277 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul>
278 m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
279   return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
280 }
281 
282 template <typename Op0_t, typename Op1_t>
283 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul,
284                              /* Commutative =*/true>
285 m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
286   return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1);
287 }
288 
289 /// Match a binary OR operation. Note that while conceptually the operands can
290 /// be matched commutatively, \p Commutative defaults to false in line with the
291 /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative
292 /// version of the matcher.
293 template <typename Op0_t, typename Op1_t, bool Commutative = false>
294 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative>
295 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
296   return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1);
297 }
298 
299 template <typename Op0_t, typename Op1_t>
300 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or,
301                              /*Commutative*/ true>
302 m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
303   return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
304 }
305 
306 template <typename Op0_t, typename Op1_t>
307 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>
308 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
309   return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
310 }
311 
312 struct VPCanonicalIVPHI_match {
313   bool match(const VPValue *V) {
314     auto *DefR = V->getDefiningRecipe();
315     return DefR && match(DefR);
316   }
317 
318   bool match(const VPRecipeBase *R) { return isa<VPCanonicalIVPHIRecipe>(R); }
319 };
320 
321 inline VPCanonicalIVPHI_match m_CanonicalIV() {
322   return VPCanonicalIVPHI_match();
323 }
324 
325 template <typename Op0_t, typename Op1_t> struct VPScalarIVSteps_match {
326   Op0_t Op0;
327   Op1_t Op1;
328 
329   VPScalarIVSteps_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
330 
331   bool match(const VPValue *V) {
332     auto *DefR = V->getDefiningRecipe();
333     return DefR && match(DefR);
334   }
335 
336   bool match(const VPRecipeBase *R) {
337     if (!isa<VPScalarIVStepsRecipe>(R))
338       return false;
339     assert(R->getNumOperands() == 2 &&
340            "VPScalarIVSteps must have exactly 2 operands");
341     return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1));
342   }
343 };
344 
345 template <typename Op0_t, typename Op1_t>
346 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
347                                                            const Op1_t &Op1) {
348   return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
349 }
350 
351 } // namespace VPlanPatternMatch
352 } // namespace llvm
353 
354 #endif
355