1 //===---------------- llvm/CodeGen/MatchContext.h --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares the EmptyMatchContext class and VPMatchContext class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H 14 #define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H 15 16 #include "llvm/CodeGen/SelectionDAG.h" 17 #include "llvm/CodeGen/TargetLowering.h" 18 19 namespace llvm { 20 21 class EmptyMatchContext { 22 SelectionDAG &DAG; 23 const TargetLowering &TLI; 24 SDNode *Root; 25 26 public: 27 EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) 28 : DAG(DAG), TLI(TLI), Root(Root) {} 29 30 unsigned getRootBaseOpcode() { return Root->getOpcode(); } 31 bool match(SDValue OpN, unsigned Opcode) const { 32 return Opcode == OpN->getOpcode(); 33 } 34 35 // Same as SelectionDAG::getNode(). 36 template <typename... ArgT> SDValue getNode(ArgT &&...Args) { 37 return DAG.getNode(std::forward<ArgT>(Args)...); 38 } 39 40 bool isOperationLegal(unsigned Op, EVT VT) const { 41 return TLI.isOperationLegal(Op, VT); 42 } 43 44 bool isOperationLegalOrCustom(unsigned Op, EVT VT, 45 bool LegalOnly = false) const { 46 return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly); 47 } 48 49 unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); } 50 }; 51 52 class VPMatchContext { 53 SelectionDAG &DAG; 54 const TargetLowering &TLI; 55 SDValue RootMaskOp; 56 SDValue RootVectorLenOp; 57 SDNode *Root; 58 59 public: 60 VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root) 61 : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() { 62 Root = _Root; 63 assert(Root->isVPOpcode()); 64 if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode())) 65 RootMaskOp = Root->getOperand(*RootMaskPos); 66 else if (Root->getOpcode() == ISD::VP_SELECT) 67 RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root), 68 Root->getOperand(0).getValueType()); 69 70 if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Root->getOpcode())) 71 RootVectorLenOp = Root->getOperand(*RootVLenPos); 72 } 73 74 unsigned getRootBaseOpcode() { 75 std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP( 76 Root->getOpcode(), !Root->getFlags().hasNoFPExcept()); 77 assert(Opcode.has_value()); 78 return *Opcode; 79 } 80 81 /// whether \p OpVal is a node that is functionally compatible with the 82 /// NodeType \p Opc 83 bool match(SDValue OpVal, unsigned Opc) const { 84 if (!OpVal->isVPOpcode()) 85 return OpVal->getOpcode() == Opc; 86 87 auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), 88 !OpVal->getFlags().hasNoFPExcept()); 89 if (BaseOpc != Opc) 90 return false; 91 92 // Make sure the mask of OpVal is true mask or is same as Root's. 93 unsigned VPOpcode = OpVal->getOpcode(); 94 if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) { 95 SDValue MaskOp = OpVal.getOperand(*MaskPos); 96 if (RootMaskOp != MaskOp && 97 !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode())) 98 return false; 99 } 100 101 // Make sure the EVL of OpVal is same as Root's. 102 if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode)) 103 if (RootVectorLenOp != OpVal.getOperand(*VLenPos)) 104 return false; 105 return true; 106 } 107 108 // Specialize based on number of operands. 109 // TODO emit VP intrinsics where MaskOp/VectorLenOp != null 110 // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return 111 // DAG.getNode(Opcode, DL, VT); } 112 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) { 113 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); 114 assert(ISD::getVPMaskIdx(VPOpcode) == 1 && 115 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); 116 return DAG.getNode(VPOpcode, DL, VT, 117 {Operand, RootMaskOp, RootVectorLenOp}); 118 } 119 120 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, 121 SDValue N2) { 122 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); 123 assert(ISD::getVPMaskIdx(VPOpcode) == 2 && 124 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); 125 return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}); 126 } 127 128 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, 129 SDValue N2, SDValue N3) { 130 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); 131 assert(ISD::getVPMaskIdx(VPOpcode) == 3 && 132 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); 133 return DAG.getNode(VPOpcode, DL, VT, 134 {N1, N2, N3, RootMaskOp, RootVectorLenOp}); 135 } 136 137 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, 138 SDNodeFlags Flags) { 139 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); 140 assert(ISD::getVPMaskIdx(VPOpcode) == 1 && 141 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); 142 return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, 143 Flags); 144 } 145 146 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, 147 SDValue N2, SDNodeFlags Flags) { 148 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); 149 assert(ISD::getVPMaskIdx(VPOpcode) == 2 && 150 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); 151 return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, 152 Flags); 153 } 154 155 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, 156 SDValue N2, SDValue N3, SDNodeFlags Flags) { 157 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); 158 assert(ISD::getVPMaskIdx(VPOpcode) == 3 && 159 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); 160 return DAG.getNode(VPOpcode, DL, VT, 161 {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); 162 } 163 164 bool isOperationLegal(unsigned Op, EVT VT) const { 165 unsigned VPOp = *ISD::getVPForBaseOpcode(Op); 166 return TLI.isOperationLegal(VPOp, VT); 167 } 168 169 bool isOperationLegalOrCustom(unsigned Op, EVT VT, 170 bool LegalOnly = false) const { 171 unsigned VPOp = *ISD::getVPForBaseOpcode(Op); 172 return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly); 173 } 174 175 unsigned getNumOperands(SDValue N) const { 176 return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands(); 177 } 178 }; 179 180 } // namespace llvm 181 182 #endif // LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H 183