xref: /llvm-project/llvm/lib/CodeGen/SelectionDAG/MatchContext.h (revision 91b423d955ff1da6cfbe82436ffee280fa25cd02)
1850dde06SYeting Kuo //===---------------- llvm/CodeGen/MatchContext.h  --------------*- C++ -*-===//
2850dde06SYeting Kuo //
3850dde06SYeting Kuo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4850dde06SYeting Kuo // See https://llvm.org/LICENSE.txt for license information.
5850dde06SYeting Kuo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6850dde06SYeting Kuo //
7850dde06SYeting Kuo //===----------------------------------------------------------------------===//
8850dde06SYeting Kuo //
9850dde06SYeting Kuo // This file declares the EmptyMatchContext class and VPMatchContext class.
10850dde06SYeting Kuo //
11850dde06SYeting Kuo //===----------------------------------------------------------------------===//
12850dde06SYeting Kuo 
13850dde06SYeting Kuo #ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
14850dde06SYeting Kuo #define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
15850dde06SYeting Kuo 
16850dde06SYeting Kuo #include "llvm/CodeGen/SelectionDAG.h"
17850dde06SYeting Kuo #include "llvm/CodeGen/TargetLowering.h"
18850dde06SYeting Kuo 
191753008bSRahul Joshi namespace llvm {
20850dde06SYeting Kuo 
21850dde06SYeting Kuo class EmptyMatchContext {
22850dde06SYeting Kuo   SelectionDAG &DAG;
23850dde06SYeting Kuo   const TargetLowering &TLI;
24850dde06SYeting Kuo   SDNode *Root;
25850dde06SYeting Kuo 
26850dde06SYeting Kuo public:
27850dde06SYeting Kuo   EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
28850dde06SYeting Kuo       : DAG(DAG), TLI(TLI), Root(Root) {}
29850dde06SYeting Kuo 
30850dde06SYeting Kuo   unsigned getRootBaseOpcode() { return Root->getOpcode(); }
31850dde06SYeting Kuo   bool match(SDValue OpN, unsigned Opcode) const {
32850dde06SYeting Kuo     return Opcode == OpN->getOpcode();
33850dde06SYeting Kuo   }
34850dde06SYeting Kuo 
35850dde06SYeting Kuo   // Same as SelectionDAG::getNode().
36850dde06SYeting Kuo   template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
37850dde06SYeting Kuo     return DAG.getNode(std::forward<ArgT>(Args)...);
38850dde06SYeting Kuo   }
39850dde06SYeting Kuo 
40850dde06SYeting Kuo   bool isOperationLegal(unsigned Op, EVT VT) const {
41850dde06SYeting Kuo     return TLI.isOperationLegal(Op, VT);
42850dde06SYeting Kuo   }
43850dde06SYeting Kuo 
44850dde06SYeting Kuo   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
45850dde06SYeting Kuo                                 bool LegalOnly = false) const {
46850dde06SYeting Kuo     return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
47850dde06SYeting Kuo   }
48fc1b0196Sv01dXYZ 
49fc1b0196Sv01dXYZ   unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
50850dde06SYeting Kuo };
51850dde06SYeting Kuo 
52850dde06SYeting Kuo class VPMatchContext {
53850dde06SYeting Kuo   SelectionDAG &DAG;
54850dde06SYeting Kuo   const TargetLowering &TLI;
55850dde06SYeting Kuo   SDValue RootMaskOp;
56850dde06SYeting Kuo   SDValue RootVectorLenOp;
57850dde06SYeting Kuo   SDNode *Root;
58850dde06SYeting Kuo 
59850dde06SYeting Kuo public:
60850dde06SYeting Kuo   VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root)
61850dde06SYeting Kuo       : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
62850dde06SYeting Kuo     Root = _Root;
63850dde06SYeting Kuo     assert(Root->isVPOpcode());
64850dde06SYeting Kuo     if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
65850dde06SYeting Kuo       RootMaskOp = Root->getOperand(*RootMaskPos);
66850dde06SYeting Kuo     else if (Root->getOpcode() == ISD::VP_SELECT)
67850dde06SYeting Kuo       RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
68850dde06SYeting Kuo                                           Root->getOperand(0).getValueType());
69850dde06SYeting Kuo 
70850dde06SYeting Kuo     if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
71850dde06SYeting Kuo       RootVectorLenOp = Root->getOperand(*RootVLenPos);
72850dde06SYeting Kuo   }
73850dde06SYeting Kuo 
74850dde06SYeting Kuo   unsigned getRootBaseOpcode() {
75850dde06SYeting Kuo     std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP(
76850dde06SYeting Kuo         Root->getOpcode(), !Root->getFlags().hasNoFPExcept());
77850dde06SYeting Kuo     assert(Opcode.has_value());
78850dde06SYeting Kuo     return *Opcode;
79850dde06SYeting Kuo   }
80850dde06SYeting Kuo 
81850dde06SYeting Kuo   /// whether \p OpVal is a node that is functionally compatible with the
82850dde06SYeting Kuo   /// NodeType \p Opc
83850dde06SYeting Kuo   bool match(SDValue OpVal, unsigned Opc) const {
84850dde06SYeting Kuo     if (!OpVal->isVPOpcode())
85850dde06SYeting Kuo       return OpVal->getOpcode() == Opc;
86850dde06SYeting Kuo 
87850dde06SYeting Kuo     auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
88850dde06SYeting Kuo                                            !OpVal->getFlags().hasNoFPExcept());
89850dde06SYeting Kuo     if (BaseOpc != Opc)
90850dde06SYeting Kuo       return false;
91850dde06SYeting Kuo 
92850dde06SYeting Kuo     // Make sure the mask of OpVal is true mask or is same as Root's.
93850dde06SYeting Kuo     unsigned VPOpcode = OpVal->getOpcode();
94850dde06SYeting Kuo     if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
95850dde06SYeting Kuo       SDValue MaskOp = OpVal.getOperand(*MaskPos);
96850dde06SYeting Kuo       if (RootMaskOp != MaskOp &&
97850dde06SYeting Kuo           !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
98850dde06SYeting Kuo         return false;
99850dde06SYeting Kuo     }
100850dde06SYeting Kuo 
101850dde06SYeting Kuo     // Make sure the EVL of OpVal is same as Root's.
102850dde06SYeting Kuo     if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
103850dde06SYeting Kuo       if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
104850dde06SYeting Kuo         return false;
105850dde06SYeting Kuo     return true;
106850dde06SYeting Kuo   }
107850dde06SYeting Kuo 
108850dde06SYeting Kuo   // Specialize based on number of operands.
109850dde06SYeting Kuo   // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
110850dde06SYeting Kuo   // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
111850dde06SYeting Kuo   // DAG.getNode(Opcode, DL, VT); }
112850dde06SYeting Kuo   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
113*91b423d9SPhilip Reames     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
114850dde06SYeting Kuo     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
115850dde06SYeting Kuo            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
116850dde06SYeting Kuo     return DAG.getNode(VPOpcode, DL, VT,
117850dde06SYeting Kuo                        {Operand, RootMaskOp, RootVectorLenOp});
118850dde06SYeting Kuo   }
119850dde06SYeting Kuo 
120850dde06SYeting Kuo   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
121850dde06SYeting Kuo                   SDValue N2) {
122*91b423d9SPhilip Reames     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
123850dde06SYeting Kuo     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
124850dde06SYeting Kuo            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
125850dde06SYeting Kuo     return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp});
126850dde06SYeting Kuo   }
127850dde06SYeting Kuo 
128850dde06SYeting Kuo   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
129850dde06SYeting Kuo                   SDValue N2, SDValue N3) {
130*91b423d9SPhilip Reames     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
131850dde06SYeting Kuo     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
132850dde06SYeting Kuo            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
133850dde06SYeting Kuo     return DAG.getNode(VPOpcode, DL, VT,
134850dde06SYeting Kuo                        {N1, N2, N3, RootMaskOp, RootVectorLenOp});
135850dde06SYeting Kuo   }
136850dde06SYeting Kuo 
137850dde06SYeting Kuo   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
138850dde06SYeting Kuo                   SDNodeFlags Flags) {
139*91b423d9SPhilip Reames     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
140850dde06SYeting Kuo     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
141850dde06SYeting Kuo            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
142850dde06SYeting Kuo     return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
143850dde06SYeting Kuo                        Flags);
144850dde06SYeting Kuo   }
145850dde06SYeting Kuo 
146850dde06SYeting Kuo   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
147850dde06SYeting Kuo                   SDValue N2, SDNodeFlags Flags) {
148*91b423d9SPhilip Reames     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
149850dde06SYeting Kuo     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
150850dde06SYeting Kuo            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
151850dde06SYeting Kuo     return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
152850dde06SYeting Kuo                        Flags);
153850dde06SYeting Kuo   }
154850dde06SYeting Kuo 
155850dde06SYeting Kuo   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
156850dde06SYeting Kuo                   SDValue N2, SDValue N3, SDNodeFlags Flags) {
157*91b423d9SPhilip Reames     unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
158850dde06SYeting Kuo     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
159850dde06SYeting Kuo            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
160850dde06SYeting Kuo     return DAG.getNode(VPOpcode, DL, VT,
161850dde06SYeting Kuo                        {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
162850dde06SYeting Kuo   }
163850dde06SYeting Kuo 
164850dde06SYeting Kuo   bool isOperationLegal(unsigned Op, EVT VT) const {
165*91b423d9SPhilip Reames     unsigned VPOp = *ISD::getVPForBaseOpcode(Op);
166850dde06SYeting Kuo     return TLI.isOperationLegal(VPOp, VT);
167850dde06SYeting Kuo   }
168850dde06SYeting Kuo 
169850dde06SYeting Kuo   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
170850dde06SYeting Kuo                                 bool LegalOnly = false) const {
171*91b423d9SPhilip Reames     unsigned VPOp = *ISD::getVPForBaseOpcode(Op);
172850dde06SYeting Kuo     return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
173850dde06SYeting Kuo   }
174fc1b0196Sv01dXYZ 
175fc1b0196Sv01dXYZ   unsigned getNumOperands(SDValue N) const {
176fc1b0196Sv01dXYZ     return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
177fc1b0196Sv01dXYZ   }
178850dde06SYeting Kuo };
1791753008bSRahul Joshi 
1801753008bSRahul Joshi } // namespace llvm
1811753008bSRahul Joshi 
1821753008bSRahul Joshi #endif // LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
183