xref: /llvm-project/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (revision dac0f7e83ebcaed451d5d0bcc3a4c7f949f0c26c)
1 //===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matches on the VPlan values and recipes, based on
11 // LLVM's IR pattern matchers.
12 //
// Currently it provides generic matchers for unary and binary VPInstructions
// and widen/replicate/cast recipes, as well as specialized matchers like m_Not,
// m_ActiveLaneMask, m_BranchOnCond, m_BranchOnCount, m_CanonicalIV and
// m_ScalarIVSteps to match specific VPInstructions and recipes.
16 // TODO: Add missing matchers for additional opcodes and recipes as needed.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
21 #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
22 
23 #include "VPlan.h"
24 
25 namespace llvm {
26 namespace VPlanPatternMatch {
27 
/// Top-level entry point: returns whether \p V matches the pattern \p P by
/// delegating to the pattern's own match() method.
template <typename Val, typename Pattern>
bool match(Val *V, const Pattern &P) {
  const bool Matched = P.match(V);
  return Matched;
}
31 
32 template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
33   auto *R = dyn_cast<VPRecipeBase>(U);
34   return R && match(R, P);
35 }
36 
37 template <typename Class> struct class_match {
38   template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
39 };
40 
41 /// Match an arbitrary VPValue and ignore it.
42 inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); }
43 
44 template <typename Class> struct bind_ty {
45   Class *&VR;
46 
47   bind_ty(Class *&V) : VR(V) {}
48 
49   template <typename ITy> bool match(ITy *V) const {
50     if (auto *CV = dyn_cast<Class>(V)) {
51       VR = CV;
52       return true;
53     }
54     return false;
55   }
56 };
57 
58 /// Match a specified integer value or vector of all elements of that
59 /// value. \p BitWidth optionally specifies the bitwidth the matched constant
60 /// must have. If it is 0, the matched constant can have any bitwidth.
61 template <unsigned BitWidth = 0> struct specific_intval {
62   APInt Val;
63 
64   specific_intval(APInt V) : Val(std::move(V)) {}
65 
66   bool match(VPValue *VPV) const {
67     if (!VPV->isLiveIn())
68       return false;
69     Value *V = VPV->getLiveInIRValue();
70     const auto *CI = dyn_cast<ConstantInt>(V);
71     if (!CI && V->getType()->isVectorTy())
72       if (const auto *C = dyn_cast<Constant>(V))
73         CI = dyn_cast_or_null<ConstantInt>(
74             C->getSplatValue(/*AllowPoison=*/false));
75     if (!CI)
76       return false;
77 
78     assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) &&
79            "Trying the match constant with unexpected bitwidth.");
80     return APInt::isSameValue(CI->getValue(), Val);
81   }
82 };
83 
84 inline specific_intval<0> m_SpecificInt(uint64_t V) {
85   return specific_intval<0>(APInt(64, V));
86 }
87 
88 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
89 
/// Matching combinators
/// Succeeds if either sub-pattern matches. \p L is tried first; \p R is only
/// tried when \p L fails.
template <typename LTy, typename RTy> struct match_combine_or {
  LTy L;
  RTy R;

  match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}

  template <typename ITy> bool match(ITy *V) const {
    return L.match(V) || R.match(V);
  }
};

/// Builds a matcher that succeeds if either \p L or \p R matches.
template <typename LTy, typename RTy>
inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) {
  match_combine_or<LTy, RTy> Either(L, R);
  return Either;
}
110 
111 /// Match a VPValue, capturing it if we match.
112 inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
113 
114 namespace detail {
115 
116 /// A helper to match an opcode against multiple recipe types.
117 template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {};
118 
119 template <unsigned Opcode, typename RecipeTy>
120 struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
121   static bool match(const VPRecipeBase *R) {
122     auto *DefR = dyn_cast<RecipeTy>(R);
123     // Check for recipes that do not have opcodes.
124     if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value ||
125                   std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value)
126       return DefR;
127     else
128       return DefR && DefR->getOpcode() == Opcode;
129   }
130 };
131 
132 template <unsigned Opcode, typename RecipeTy, typename... RecipeTys>
133 struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> {
134   static bool match(const VPRecipeBase *R) {
135     return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) ||
136            MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R);
137   }
138 };
/// Invokes \p P(Element, Index) on the tuple elements selected by \p Is, in
/// order, stopping at the first element for which \p P returns false.
template <typename TupleTy, typename Fn, std::size_t... Is>
bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
  const bool AllSatisfied = (... && P(std::get<Is>(Ops), Is));
  return AllSatisfied;
}

/// Helper to check whether predicate \p P holds for every element of the
/// tuple \p Ops. \p P is called as P(Element, Index); an empty tuple yields
/// true.
template <typename TupleTy, typename Fn>
bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
  constexpr std::size_t NumElements = std::tuple_size<TupleTy>::value;
  using Indices = std::make_index_sequence<NumElements>;
  return CheckTupleElements(Ops, P, Indices{});
}
150 } // namespace detail
151 
152 template <typename Ops_t, unsigned Opcode, bool Commutative,
153           typename... RecipeTys>
154 struct Recipe_match {
155   Ops_t Ops;
156 
157   Recipe_match() : Ops() {
158     static_assert(std::tuple_size<Ops_t>::value == 0 &&
159                   "constructor can only be used with zero operands");
160   }
161   Recipe_match(Ops_t Ops) : Ops(Ops) {}
162   template <typename A_t, typename B_t>
163   Recipe_match(A_t A, B_t B) : Ops({A, B}) {
164     static_assert(std::tuple_size<Ops_t>::value == 2 &&
165                   "constructor can only be used for binary matcher");
166   }
167 
168   bool match(const VPValue *V) const {
169     auto *DefR = V->getDefiningRecipe();
170     return DefR && match(DefR);
171   }
172 
173   bool match(const VPSingleDefRecipe *R) const {
174     return match(static_cast<const VPRecipeBase *>(R));
175   }
176 
177   bool match(const VPRecipeBase *R) const {
178     if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
179       return false;
180     assert(R->getNumOperands() == std::tuple_size<Ops_t>::value &&
181            "recipe with matched opcode the expected number of operands");
182 
183     if (detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) {
184           return Op.match(R->getOperand(Idx));
185         }))
186       return true;
187 
188     return Commutative &&
189            detail::all_of_tuple_elements(Ops, [R](auto Op, unsigned Idx) {
190              return Op.match(R->getOperand(R->getNumOperands() - Idx - 1));
191            });
192   }
193 };
194 
195 template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
196 using UnaryRecipe_match =
197     Recipe_match<std::tuple<Op0_t>, Opcode, false, RecipeTys...>;
198 
199 template <typename Op0_t, unsigned Opcode>
200 using UnaryVPInstruction_match =
201     UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;
202 
203 template <typename Op0_t, unsigned Opcode>
204 using AllUnaryRecipe_match =
205     UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
206                       VPWidenCastRecipe, VPInstruction>;
207 
208 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative,
209           typename... RecipeTys>
210 using BinaryRecipe_match =
211     Recipe_match<std::tuple<Op0_t, Op1_t>, Opcode, Commutative, RecipeTys...>;
212 
213 template <typename Op0_t, typename Op1_t, unsigned Opcode>
214 using BinaryVPInstruction_match =
215     BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
216                        VPInstruction>;
217 
218 template <typename Op0_t, typename Op1_t, unsigned Opcode,
219           bool Commutative = false>
220 using AllBinaryRecipe_match =
221     BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe,
222                        VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>;
223 
224 template <unsigned Opcode, typename Op0_t>
225 inline UnaryVPInstruction_match<Op0_t, Opcode>
226 m_VPInstruction(const Op0_t &Op0) {
227   return UnaryVPInstruction_match<Op0_t, Opcode>(Op0);
228 }
229 
230 template <unsigned Opcode, typename Op0_t, typename Op1_t>
231 inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>
232 m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
233   return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
234 }
235 
236 template <typename Op0_t>
237 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
238 m_Not(const Op0_t &Op0) {
239   return m_VPInstruction<VPInstruction::Not>(Op0);
240 }
241 
242 template <typename Op0_t>
243 inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond>
244 m_BranchOnCond(const Op0_t &Op0) {
245   return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
246 }
247 
248 template <typename Op0_t, typename Op1_t>
249 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask>
250 m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) {
251   return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1);
252 }
253 
254 template <typename Op0_t, typename Op1_t>
255 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
256 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
257   return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
258 }
259 
260 template <unsigned Opcode, typename Op0_t>
261 inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) {
262   return AllUnaryRecipe_match<Op0_t, Opcode>(Op0);
263 }
264 
265 template <typename Op0_t>
266 inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc>
267 m_Trunc(const Op0_t &Op0) {
268   return m_Unary<Instruction::Trunc, Op0_t>(Op0);
269 }
270 
271 template <typename Op0_t>
272 inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) {
273   return m_Unary<Instruction::ZExt, Op0_t>(Op0);
274 }
275 
276 template <typename Op0_t>
277 inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) {
278   return m_Unary<Instruction::SExt, Op0_t>(Op0);
279 }
280 
281 template <typename Op0_t>
282 inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>,
283                         AllUnaryRecipe_match<Op0_t, Instruction::SExt>>
284 m_ZExtOrSExt(const Op0_t &Op0) {
285   return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
286 }
287 
288 template <unsigned Opcode, typename Op0_t, typename Op1_t,
289           bool Commutative = false>
290 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>
291 m_Binary(const Op0_t &Op0, const Op1_t &Op1) {
292   return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1);
293 }
294 
295 template <typename Op0_t, typename Op1_t>
296 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul>
297 m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
298   return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
299 }
300 
301 template <typename Op0_t, typename Op1_t>
302 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul,
303                              /* Commutative =*/true>
304 m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
305   return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1);
306 }
307 
308 /// Match a binary OR operation. Note that while conceptually the operands can
309 /// be matched commutatively, \p Commutative defaults to false in line with the
310 /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative
311 /// version of the matcher.
312 template <typename Op0_t, typename Op1_t, bool Commutative = false>
313 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative>
314 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
315   return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1);
316 }
317 
318 template <typename Op0_t, typename Op1_t>
319 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or,
320                              /*Commutative*/ true>
321 m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
322   return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
323 }
324 
325 template <typename Op0_t, typename Op1_t>
326 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>
327 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
328   return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
329 }
330 
331 using VPCanonicalIVPHI_match =
332     Recipe_match<std::tuple<>, 0, false, VPCanonicalIVPHIRecipe>;
333 
334 inline VPCanonicalIVPHI_match m_CanonicalIV() {
335   return VPCanonicalIVPHI_match();
336 }
337 
338 template <typename Op0_t, typename Op1_t>
339 using VPScalarIVSteps_match =
340     Recipe_match<std::tuple<Op0_t, Op1_t>, 0, false, VPScalarIVStepsRecipe>;
341 
342 template <typename Op0_t, typename Op1_t>
343 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
344                                                            const Op1_t &Op1) {
345   return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
346 }
347 
348 } // namespace VPlanPatternMatch
349 } // namespace llvm
350 
351 #endif
352