xref: /llvm-project/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (revision 1d9b3222f3de7bad4ef27b7e4d7798f840097380)
1 //===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matches on the VPlan values and recipes, based on
11 // LLVM's IR pattern matchers.
12 //
13 // Currently it provides generic matchers for unary and binary VPInstructions,
14 // and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond,
15 // m_BranchOnCount to match specific VPInstructions.
16 // TODO: Add missing matchers for additional opcodes and recipes as needed.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
21 #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
22 
23 #include "VPlan.h"
24 
25 namespace llvm {
26 namespace VPlanPatternMatch {
27 
/// Entry point of the matcher framework: returns true if pattern \p P accepts
/// \p V. All work is delegated to the pattern's own match() member.
template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
  const bool Matched = P.match(V);
  return Matched;
}
31 
32 template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
33   auto *R = dyn_cast<VPRecipeBase>(U);
34   return R && match(R, P);
35 }
36 
37 template <typename Class> struct class_match {
38   template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
39 };
40 
41 /// Match an arbitrary VPValue and ignore it.
42 inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); }
43 
44 template <typename Class> struct bind_ty {
45   Class *&VR;
46 
47   bind_ty(Class *&V) : VR(V) {}
48 
49   template <typename ITy> bool match(ITy *V) const {
50     if (auto *CV = dyn_cast<Class>(V)) {
51       VR = CV;
52       return true;
53     }
54     return false;
55   }
56 };
57 
58 /// Match a specified integer value or vector of all elements of that
59 /// value. \p BitWidth optionally specifies the bitwidth the matched constant
60 /// must have. If it is 0, the matched constant can have any bitwidth.
61 template <unsigned BitWidth = 0> struct specific_intval {
62   APInt Val;
63 
64   specific_intval(APInt V) : Val(std::move(V)) {}
65 
66   bool match(VPValue *VPV) const {
67     if (!VPV->isLiveIn())
68       return false;
69     Value *V = VPV->getLiveInIRValue();
70     const auto *CI = dyn_cast<ConstantInt>(V);
71     if (!CI && V->getType()->isVectorTy())
72       if (const auto *C = dyn_cast<Constant>(V))
73         CI = dyn_cast_or_null<ConstantInt>(
74             C->getSplatValue(/*AllowPoison=*/false));
75     if (!CI)
76       return false;
77 
78     if (BitWidth != 0 && CI->getBitWidth() != BitWidth)
79       return false;
80     return APInt::isSameValue(CI->getValue(), Val);
81   }
82 };
83 
84 inline specific_intval<0> m_SpecificInt(uint64_t V) {
85   return specific_intval<0>(APInt(64, V));
86 }
87 
88 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
89 
90 inline specific_intval<1> m_True() { return specific_intval<1>(APInt(64, 1)); }
91 
/// Matching combinators
///
/// Pattern that accepts a value if either of two sub-patterns accepts it. The
/// left pattern is tried first; the right one only runs if the left fails.
template <typename LTy, typename RTy> struct match_combine_or {
  LTy L;
  RTy R;

  match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}

  template <typename ITy> bool match(ITy *V) const {
    return L.match(V) || R.match(V);
  }
};
107 
108 template <typename LTy, typename RTy>
109 inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) {
110   return match_combine_or<LTy, RTy>(L, R);
111 }
112 
113 /// Match a VPValue, capturing it if we match.
114 inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
115 
116 namespace detail {
117 
118 /// A helper to match an opcode against multiple recipe types.
119 template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {};
120 
121 template <unsigned Opcode, typename RecipeTy>
122 struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
123   static bool match(const VPRecipeBase *R) {
124     auto *DefR = dyn_cast<RecipeTy>(R);
125     // Check for recipes that do not have opcodes.
126     if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value ||
127                   std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value ||
128                   std::is_same<RecipeTy, VPWidenSelectRecipe>::value)
129       return DefR;
130     else
131       return DefR && DefR->getOpcode() == Opcode;
132   }
133 };
134 
135 template <unsigned Opcode, typename RecipeTy, typename... RecipeTys>
136 struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> {
137   static bool match(const VPRecipeBase *R) {
138     return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) ||
139            MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R);
140   }
141 };
/// Invoke \p P as P(element, index) for the tuple elements selected by \p Is,
/// in order, stopping at the first call that yields false. Returns true if no
/// call yielded false (including when there are no elements).
template <typename TupleTy, typename Fn, std::size_t... Is>
bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
  bool AllHold = true;
  // The && short-circuits, so P is no longer called once AllHold is false.
  ((AllHold = AllHold && P(std::get<Is>(Ops), Is)), ...);
  return AllHold;
}

/// Helper to check if predicate \p P holds on all tuple elements in \p Ops.
/// \p P receives each element together with its index in the tuple.
template <typename TupleTy, typename Fn>
bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
  constexpr std::size_t NumElts = std::tuple_size<TupleTy>::value;
  return CheckTupleElements(Ops, P, std::make_index_sequence<NumElts>{});
}
153 } // namespace detail
154 
155 template <typename Ops_t, unsigned Opcode, bool Commutative,
156           typename... RecipeTys>
157 struct Recipe_match {
158   Ops_t Ops;
159 
160   Recipe_match() : Ops() {
161     static_assert(std::tuple_size<Ops_t>::value == 0 &&
162                   "constructor can only be used with zero operands");
163   }
164   Recipe_match(Ops_t Ops) : Ops(Ops) {}
165   template <typename A_t, typename B_t>
166   Recipe_match(A_t A, B_t B) : Ops({A, B}) {
167     static_assert(std::tuple_size<Ops_t>::value == 2 &&
168                   "constructor can only be used for binary matcher");
169   }
170 
171   bool match(const VPValue *V) const {
172     auto *DefR = V->getDefiningRecipe();
173     return DefR && match(DefR);
174   }
175 
176   bool match(const VPSingleDefRecipe *R) const {
177     return match(static_cast<const VPRecipeBase *>(R));
178   }
179 
180   bool match(const VPRecipeBase *R) const {
181     if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
182       return false;
183     assert(R->getNumOperands() == std::tuple_size<Ops_t>::value &&
184            "recipe with matched opcode the expected number of operands");
185 
186     if (detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) {
187           return Op.match(R->getOperand(Idx));
188         }))
189       return true;
190 
191     return Commutative &&
192            detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) {
193              return Op.match(R->getOperand(R->getNumOperands() - Idx - 1));
194            });
195   }
196 };
197 
198 template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
199 using UnaryRecipe_match =
200     Recipe_match<std::tuple<Op0_t>, Opcode, false, RecipeTys...>;
201 
202 template <typename Op0_t, unsigned Opcode>
203 using UnaryVPInstruction_match =
204     UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;
205 
206 template <typename Op0_t, unsigned Opcode>
207 using AllUnaryRecipe_match =
208     UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
209                       VPWidenCastRecipe, VPInstruction>;
210 
211 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative,
212           typename... RecipeTys>
213 using BinaryRecipe_match =
214     Recipe_match<std::tuple<Op0_t, Op1_t>, Opcode, Commutative, RecipeTys...>;
215 
216 template <typename Op0_t, typename Op1_t, unsigned Opcode>
217 using BinaryVPInstruction_match =
218     BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
219                        VPInstruction>;
220 
221 template <typename Op0_t, typename Op1_t, unsigned Opcode,
222           bool Commutative = false>
223 using AllBinaryRecipe_match =
224     BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe,
225                        VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>;
226 
227 template <unsigned Opcode, typename Op0_t>
228 inline UnaryVPInstruction_match<Op0_t, Opcode>
229 m_VPInstruction(const Op0_t &Op0) {
230   return UnaryVPInstruction_match<Op0_t, Opcode>(Op0);
231 }
232 
233 template <unsigned Opcode, typename Op0_t, typename Op1_t>
234 inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>
235 m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
236   return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
237 }
238 
239 template <typename Op0_t>
240 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
241 m_Not(const Op0_t &Op0) {
242   return m_VPInstruction<VPInstruction::Not>(Op0);
243 }
244 
245 template <typename Op0_t>
246 inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond>
247 m_BranchOnCond(const Op0_t &Op0) {
248   return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
249 }
250 
251 template <typename Op0_t, typename Op1_t>
252 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask>
253 m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) {
254   return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1);
255 }
256 
257 template <typename Op0_t, typename Op1_t>
258 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
259 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
260   return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
261 }
262 
263 template <unsigned Opcode, typename Op0_t>
264 inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) {
265   return AllUnaryRecipe_match<Op0_t, Opcode>(Op0);
266 }
267 
268 template <typename Op0_t>
269 inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc>
270 m_Trunc(const Op0_t &Op0) {
271   return m_Unary<Instruction::Trunc, Op0_t>(Op0);
272 }
273 
274 template <typename Op0_t>
275 inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) {
276   return m_Unary<Instruction::ZExt, Op0_t>(Op0);
277 }
278 
279 template <typename Op0_t>
280 inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) {
281   return m_Unary<Instruction::SExt, Op0_t>(Op0);
282 }
283 
284 template <typename Op0_t>
285 inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>,
286                         AllUnaryRecipe_match<Op0_t, Instruction::SExt>>
287 m_ZExtOrSExt(const Op0_t &Op0) {
288   return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
289 }
290 
291 template <unsigned Opcode, typename Op0_t, typename Op1_t,
292           bool Commutative = false>
293 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>
294 m_Binary(const Op0_t &Op0, const Op1_t &Op1) {
295   return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1);
296 }
297 
298 template <typename Op0_t, typename Op1_t>
299 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul>
300 m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
301   return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
302 }
303 
304 template <typename Op0_t, typename Op1_t>
305 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul,
306                              /* Commutative =*/true>
307 m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
308   return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1);
309 }
310 
311 /// Match a binary OR operation. Note that while conceptually the operands can
312 /// be matched commutatively, \p Commutative defaults to false in line with the
313 /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative
314 /// version of the matcher.
315 template <typename Op0_t, typename Op1_t, bool Commutative = false>
316 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative>
317 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
318   return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1);
319 }
320 
321 template <typename Op0_t, typename Op1_t>
322 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or,
323                              /*Commutative*/ true>
324 m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
325   return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
326 }
327 
328 template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
329 using AllTernaryRecipe_match =
330     Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, Opcode, false,
331                  VPReplicateRecipe, VPInstruction, VPWidenSelectRecipe>;
332 
333 template <typename Op0_t, typename Op1_t, typename Op2_t>
334 inline AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>
335 m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
336   return AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>(
337       {Op0, Op1, Op2});
338 }
339 
340 template <typename Op0_t, typename Op1_t>
341 inline match_combine_or<
342     BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>,
343     AllTernaryRecipe_match<Op0_t, Op1_t, specific_intval<1>,
344                            Instruction::Select>>
345 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
346   return m_CombineOr(
347       m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1),
348       m_Select(Op0, Op1, m_False()));
349 }
350 
351 template <typename Op0_t, typename Op1_t>
352 inline AllTernaryRecipe_match<Op0_t, specific_intval<1>, Op1_t,
353                               Instruction::Select>
354 m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
355   return m_Select(Op0, m_True(), Op1);
356 }
357 
358 using VPCanonicalIVPHI_match =
359     Recipe_match<std::tuple<>, 0, false, VPCanonicalIVPHIRecipe>;
360 
361 inline VPCanonicalIVPHI_match m_CanonicalIV() {
362   return VPCanonicalIVPHI_match();
363 }
364 
365 template <typename Op0_t, typename Op1_t>
366 using VPScalarIVSteps_match =
367     Recipe_match<std::tuple<Op0_t, Op1_t>, 0, false, VPScalarIVStepsRecipe>;
368 
369 template <typename Op0_t, typename Op1_t>
370 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
371                                                            const Op1_t &Op1) {
372   return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
373 }
374 } // namespace VPlanPatternMatch
375 } // namespace llvm
376 
377 #endif
378