1 //===---- llvm/unittest/CodeGen/SelectionDAGPatternMatchTest.cpp ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/CodeGen/MachineModuleInfo.h" 12 #include "llvm/CodeGen/SDPatternMatch.h" 13 #include "llvm/CodeGen/TargetLowering.h" 14 #include "llvm/MC/TargetRegistry.h" 15 #include "llvm/Support/SourceMgr.h" 16 #include "llvm/Support/TargetSelect.h" 17 #include "llvm/Target/TargetMachine.h" 18 #include "gtest/gtest.h" 19 20 using namespace llvm; 21 22 class SelectionDAGPatternMatchTest : public testing::Test { 23 protected: 24 static void SetUpTestCase() { 25 InitializeAllTargets(); 26 InitializeAllTargetMCs(); 27 } 28 29 void SetUp() override { 30 StringRef Assembly = "@g = global i32 0\n" 31 "@g_alias = alias i32, i32* @g\n" 32 "define i32 @f() {\n" 33 " %1 = load i32, i32* @g\n" 34 " ret i32 %1\n" 35 "}"; 36 37 Triple TargetTriple("riscv64--"); 38 std::string Error; 39 const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); 40 // FIXME: These tests do not depend on RISCV specifically, but we have to 41 // initialize a target. A skeleton Target for unittests would allow us to 42 // always run these tests. 43 if (!T) 44 GTEST_SKIP(); 45 46 TargetOptions Options; 47 TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>( 48 T->createTargetMachine("riscv64", "", "+m,+f,+d,+v", Options, 49 std::nullopt, std::nullopt, 50 CodeGenOptLevel::Aggressive))); 51 if (!TM) 52 GTEST_SKIP(); 53 54 SMDiagnostic SMError; 55 M = parseAssemblyString(Assembly, SMError, Context); 56 if (!M) 57 report_fatal_error(SMError.getMessage()); 58 M->setDataLayout(TM->createDataLayout()); 59 60 F = M->getFunction("f"); 61 if (!F) 62 report_fatal_error("F?"); 63 G = M->getGlobalVariable("g"); 64 if (!G) 65 report_fatal_error("G?"); 66 AliasedG = M->getNamedAlias("g_alias"); 67 if (!AliasedG) 68 report_fatal_error("AliasedG?"); 69 70 MachineModuleInfo MMI(TM.get()); 71 72 MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F), 73 0, MMI); 74 75 DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::None); 76 if (!DAG) 77 report_fatal_error("DAG?"); 78 OptimizationRemarkEmitter ORE(F); 79 DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); 80 } 81 82 TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) { 83 return DAG->getTargetLoweringInfo().getTypeAction(Context, VT); 84 } 85 86 EVT getTypeToTransformTo(EVT VT) { 87 return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT); 88 } 89 90 LLVMContext Context; 91 std::unique_ptr<LLVMTargetMachine> TM; 92 std::unique_ptr<Module> M; 93 Function *F; 94 GlobalVariable *G; 95 GlobalAlias *AliasedG; 96 std::unique_ptr<MachineFunction> MF; 97 std::unique_ptr<SelectionDAG> DAG; 98 }; 99 100 TEST_F(SelectionDAGPatternMatchTest, matchValueType) { 101 SDLoc DL; 102 auto Int32VT = EVT::getIntegerVT(Context, 32); 103 auto Float32VT = EVT::getFloatingPointVT(32); 104 auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); 105 106 SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); 107 SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Float32VT); 108 SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT); 109 110 using namespace SDPatternMatch; 111 EXPECT_TRUE(sd_match(Op0, m_SpecificVT(Int32VT))); 112 EVT BindVT; 113 EXPECT_TRUE(sd_match(Op1, m_VT(BindVT))); 114 EXPECT_EQ(BindVT, Float32VT); 115 EXPECT_TRUE(sd_match(Op0, m_IntegerVT())); 116 EXPECT_TRUE(sd_match(Op1, m_FloatingPointVT())); 117 EXPECT_TRUE(sd_match(Op2, m_VectorVT())); 118 EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT())); 119 } 120 121 TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { 122 SDLoc DL; 123 auto Int32VT = EVT::getIntegerVT(Context, 32); 124 auto Float32VT = EVT::getFloatingPointVT(32); 125 126 SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); 127 SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); 128 SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT); 129 130 SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1); 131 SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0); 132 SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub); 133 SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1); 134 SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0); 135 SDValue Or = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1); 136 SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1); 137 SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0); 138 SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1); 139 SDValue UMin = DAG->getNode(ISD::UMIN, DL, Int32VT, Op1, Op0); 140 141 SDValue SFAdd = DAG->getNode(ISD::STRICT_FADD, DL, {Float32VT, MVT::Other}, 142 {DAG->getEntryNode(), Op2, Op2}); 143 144 using namespace SDPatternMatch; 145 EXPECT_TRUE(sd_match(Sub, m_BinOp(ISD::SUB, m_Value(), m_Value()))); 146 EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value()))); 147 EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value()))); 148 EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value()))); 149 EXPECT_TRUE(sd_match( 150 Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add))))); 151 EXPECT_TRUE( 152 sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT), 153 m_SpecificVT(Float32VT)))); 154 155 EXPECT_TRUE(sd_match(And, m_c_BinOp(ISD::AND, m_Value(), m_Value()))); 156 EXPECT_TRUE(sd_match(And, m_And(m_Value(), m_Value()))); 157 EXPECT_TRUE(sd_match(Xor, m_c_BinOp(ISD::XOR, m_Value(), m_Value()))); 158 EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value()))); 159 EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value()))); 160 EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value()))); 161 162 EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value()))); 163 EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value()))); 164 EXPECT_TRUE(sd_match(SMin, m_c_BinOp(ISD::SMIN, m_Value(), m_Value()))); 165 EXPECT_TRUE(sd_match(SMin, m_SMin(m_Value(), m_Value()))); 166 EXPECT_TRUE(sd_match(UMax, m_c_BinOp(ISD::UMAX, m_Value(), m_Value()))); 167 EXPECT_TRUE(sd_match(UMax, m_UMax(m_Value(), m_Value()))); 168 EXPECT_TRUE(sd_match(UMin, m_c_BinOp(ISD::UMIN, m_Value(), m_Value()))); 169 EXPECT_TRUE(sd_match(UMin, m_UMin(m_Value(), m_Value()))); 170 171 SDValue BindVal; 172 EXPECT_TRUE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_Value(BindVal), 173 m_Deferred(BindVal)))); 174 EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(), 175 m_SpecificVT(Float32VT)))); 176 } 177 178 TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) { 179 SDLoc DL; 180 auto Int32VT = EVT::getIntegerVT(Context, 32); 181 auto Int64VT = EVT::getIntegerVT(Context, 64); 182 183 SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); 184 SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT); 185 186 SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0); 187 SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0); 188 SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1); 189 190 SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Trunc, Op0); 191 SDValue Neg = DAG->getNegative(Op0, DL, Int32VT); 192 SDValue Not = DAG->getNOT(DL, Op0, Int32VT); 193 194 using namespace SDPatternMatch; 195 EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value()))); 196 EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value()))); 197 EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1)))); 198 199 EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value()))); 200 EXPECT_TRUE(sd_match(Not, m_Not(m_Value()))); 201 EXPECT_FALSE(sd_match(ZExt, m_Neg(m_Value()))); 202 EXPECT_FALSE(sd_match(Sub, m_Neg(m_Value()))); 203 EXPECT_FALSE(sd_match(Neg, m_Not(m_Value()))); 204 } 205 206 TEST_F(SelectionDAGPatternMatchTest, matchConstants) { 207 SDLoc DL; 208 auto Int32VT = EVT::getIntegerVT(Context, 32); 209 auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); 210 211 SDValue Arg0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); 212 213 SDValue Const3 = DAG->getConstant(3, DL, Int32VT); 214 SDValue Const87 = DAG->getConstant(87, DL, Int32VT); 215 SDValue Splat = DAG->getSplat(VInt32VT, DL, Arg0); 216 SDValue ConstSplat = DAG->getSplat(VInt32VT, DL, Const3); 217 SDValue Zero = DAG->getConstant(0, DL, Int32VT); 218 SDValue One = DAG->getConstant(1, DL, Int32VT); 219 SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT); 220 221 using namespace SDPatternMatch; 222 EXPECT_TRUE(sd_match(Const87, m_ConstInt())); 223 EXPECT_FALSE(sd_match(Arg0, m_ConstInt())); 224 APInt ConstVal; 225 EXPECT_TRUE(sd_match(ConstSplat, m_ConstInt(ConstVal))); 226 EXPECT_EQ(ConstVal, 3); 227 EXPECT_FALSE(sd_match(Splat, m_ConstInt())); 228 229 EXPECT_TRUE(sd_match(Const87, m_SpecificInt(87))); 230 EXPECT_TRUE(sd_match(Const3, m_SpecificInt(ConstVal))); 231 EXPECT_TRUE(sd_match(AllOnes, m_AllOnes())); 232 233 EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False())); 234 EXPECT_TRUE(sd_match(One, DAG.get(), m_True())); 235 EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True())); 236 } 237 238 TEST_F(SelectionDAGPatternMatchTest, patternCombinators) { 239 SDLoc DL; 240 auto Int32VT = EVT::getIntegerVT(Context, 32); 241 242 SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); 243 SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); 244 245 SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1); 246 SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0); 247 248 using namespace SDPatternMatch; 249 EXPECT_TRUE(sd_match( 250 Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL)))); 251 EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse()))); 252 } 253 254 TEST_F(SelectionDAGPatternMatchTest, matchNode) { 255 SDLoc DL; 256 auto Int32VT = EVT::getIntegerVT(Context, 32); 257 258 SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); 259 SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); 260 261 SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1); 262 263 using namespace SDPatternMatch; 264 EXPECT_TRUE(sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value()))); 265 EXPECT_FALSE(sd_match(Add, m_Node(ISD::SUB, m_Value(), m_Value()))); 266 EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_Value()))); 267 EXPECT_FALSE( 268 sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value(), m_Value()))); 269 EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value()))); 270 } 271 272 namespace { 273 struct VPMatchContext : public SDPatternMatch::BasicMatchContext { 274 using SDPatternMatch::BasicMatchContext::BasicMatchContext; 275 276 bool match(SDValue OpVal, unsigned Opc) const { 277 if (!OpVal->isVPOpcode()) 278 return OpVal->getOpcode() == Opc; 279 280 auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false); 281 return BaseOpc.has_value() && *BaseOpc == Opc; 282 } 283 }; 284 } // anonymous namespace 285 TEST_F(SelectionDAGPatternMatchTest, matchContext) { 286 SDLoc DL; 287 auto BoolVT = EVT::getIntegerVT(Context, 1); 288 auto Int32VT = EVT::getIntegerVT(Context, 32); 289 auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); 290 auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4); 291 292 SDValue Scalar0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); 293 SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT); 294 SDValue Mask0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, MaskVT); 295 296 SDValue VPAdd = DAG->getNode(ISD::VP_ADD, DL, VInt32VT, 297 {Vector0, Vector0, Mask0, Scalar0}); 298 SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT, 299 {Scalar0, VPAdd, Mask0, Scalar0}); 300 301 using namespace SDPatternMatch; 302 VPMatchContext VPCtx(DAG.get()); 303 EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD))); 304 // VP_REDUCE_ADD doesn't have a based opcode, so we use a normal 305 // sd_match before switching to VPMatchContext when checking VPAdd. 306 EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(), 307 m_Context(VPCtx, m_Opc(ISD::ADD)), 308 m_Value(), m_Value()))); 309 } 310 311 TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) { 312 SDLoc DL; 313 auto Int16VT = EVT::getIntegerVT(Context, 16); 314 auto Int64VT = EVT::getIntegerVT(Context, 64); 315 316 SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT); 317 SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int16VT); 318 319 SDValue Add = DAG->getNode(ISD::ADD, DL, Int64VT, Op0, Op0); 320 321 using namespace SDPatternMatch; 322 EXPECT_TRUE(sd_match(Op0, DAG.get(), m_LegalType(m_Value()))); 323 EXPECT_FALSE(sd_match(Op1, DAG.get(), m_LegalType(m_Value()))); 324 EXPECT_TRUE(sd_match(Add, DAG.get(), 325 m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value()))))); 326 } 327