xref: /llvm-project/llvm/lib/CodeGen/SelectionDAG/MatchContext.h (revision 91b423d955ff1da6cfbe82436ffee280fa25cd02)
1 //===---------------- llvm/CodeGen/MatchContext.h  --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the EmptyMatchContext class and VPMatchContext class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
14 #define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
15 
16 #include "llvm/CodeGen/SelectionDAG.h"
17 #include "llvm/CodeGen/TargetLowering.h"
18 
19 namespace llvm {
20 
21 class EmptyMatchContext {
22   SelectionDAG &DAG;
23   const TargetLowering &TLI;
24   SDNode *Root;
25 
26 public:
27   EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
28       : DAG(DAG), TLI(TLI), Root(Root) {}
29 
30   unsigned getRootBaseOpcode() { return Root->getOpcode(); }
31   bool match(SDValue OpN, unsigned Opcode) const {
32     return Opcode == OpN->getOpcode();
33   }
34 
35   // Same as SelectionDAG::getNode().
36   template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
37     return DAG.getNode(std::forward<ArgT>(Args)...);
38   }
39 
40   bool isOperationLegal(unsigned Op, EVT VT) const {
41     return TLI.isOperationLegal(Op, VT);
42   }
43 
44   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
45                                 bool LegalOnly = false) const {
46     return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
47   }
48 
49   unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
50 };
51 
52 class VPMatchContext {
53   SelectionDAG &DAG;
54   const TargetLowering &TLI;
55   SDValue RootMaskOp;
56   SDValue RootVectorLenOp;
57   SDNode *Root;
58 
59 public:
60   VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root)
61       : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
62     Root = _Root;
63     assert(Root->isVPOpcode());
64     if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
65       RootMaskOp = Root->getOperand(*RootMaskPos);
66     else if (Root->getOpcode() == ISD::VP_SELECT)
67       RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
68                                           Root->getOperand(0).getValueType());
69 
70     if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
71       RootVectorLenOp = Root->getOperand(*RootVLenPos);
72   }
73 
74   unsigned getRootBaseOpcode() {
75     std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP(
76         Root->getOpcode(), !Root->getFlags().hasNoFPExcept());
77     assert(Opcode.has_value());
78     return *Opcode;
79   }
80 
81   /// whether \p OpVal is a node that is functionally compatible with the
82   /// NodeType \p Opc
83   bool match(SDValue OpVal, unsigned Opc) const {
84     if (!OpVal->isVPOpcode())
85       return OpVal->getOpcode() == Opc;
86 
87     auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
88                                            !OpVal->getFlags().hasNoFPExcept());
89     if (BaseOpc != Opc)
90       return false;
91 
92     // Make sure the mask of OpVal is true mask or is same as Root's.
93     unsigned VPOpcode = OpVal->getOpcode();
94     if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
95       SDValue MaskOp = OpVal.getOperand(*MaskPos);
96       if (RootMaskOp != MaskOp &&
97           !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
98         return false;
99     }
100 
101     // Make sure the EVL of OpVal is same as Root's.
102     if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
103       if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
104         return false;
105     return true;
106   }
107 
108   // Specialize based on number of operands.
109   // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
110   // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
111   // DAG.getNode(Opcode, DL, VT); }
112   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
113     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
114     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
115            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
116     return DAG.getNode(VPOpcode, DL, VT,
117                        {Operand, RootMaskOp, RootVectorLenOp});
118   }
119 
120   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
121                   SDValue N2) {
122     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
123     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
124            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
125     return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp});
126   }
127 
128   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
129                   SDValue N2, SDValue N3) {
130     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
131     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
132            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
133     return DAG.getNode(VPOpcode, DL, VT,
134                        {N1, N2, N3, RootMaskOp, RootVectorLenOp});
135   }
136 
137   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
138                   SDNodeFlags Flags) {
139     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
140     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
141            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
142     return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
143                        Flags);
144   }
145 
146   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
147                   SDValue N2, SDNodeFlags Flags) {
148     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
149     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
150            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
151     return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
152                        Flags);
153   }
154 
155   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
156                   SDValue N2, SDValue N3, SDNodeFlags Flags) {
157     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
158     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
159            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
160     return DAG.getNode(VPOpcode, DL, VT,
161                        {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
162   }
163 
164   bool isOperationLegal(unsigned Op, EVT VT) const {
165     unsigned VPOp = *ISD::getVPForBaseOpcode(Op);
166     return TLI.isOperationLegal(VPOp, VT);
167   }
168 
169   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
170                                 bool LegalOnly = false) const {
171     unsigned VPOp = *ISD::getVPForBaseOpcode(Op);
172     return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
173   }
174 
175   unsigned getNumOperands(SDValue N) const {
176     return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
177   }
178 };
179 
180 } // namespace llvm
181 
182 #endif // LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
183