xref: /llvm-project/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp (revision 4169338e75cdce73d34063532db598c95ee82ae4)
1 //===---- llvm/unittest/CodeGen/SelectionDAGPatternMatchTest.cpp  ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/CodeGen/MachineModuleInfo.h"
12 #include "llvm/CodeGen/SDPatternMatch.h"
13 #include "llvm/CodeGen/TargetLowering.h"
14 #include "llvm/IR/Module.h"
15 #include "llvm/MC/TargetRegistry.h"
16 #include "llvm/Support/SourceMgr.h"
17 #include "llvm/Support/TargetSelect.h"
18 #include "llvm/Target/TargetMachine.h"
19 #include "gtest/gtest.h"
20 
21 using namespace llvm;
22 
23 class SelectionDAGPatternMatchTest : public testing::Test {
24 protected:
25   static void SetUpTestCase() {
26     InitializeAllTargets();
27     InitializeAllTargetMCs();
28   }
29 
30   void SetUp() override {
31     StringRef Assembly = "@g = global i32 0\n"
32                          "@g_alias = alias i32, i32* @g\n"
33                          "define i32 @f() {\n"
34                          "  %1 = load i32, i32* @g\n"
35                          "  ret i32 %1\n"
36                          "}";
37 
38     Triple TargetTriple("riscv64--");
39     std::string Error;
40     const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
41     // FIXME: These tests do not depend on RISCV specifically, but we have to
42     // initialize a target. A skeleton Target for unittests would allow us to
43     // always run these tests.
44     if (!T)
45       GTEST_SKIP();
46 
47     TargetOptions Options;
48     TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
49         T->createTargetMachine("riscv64", "", "+m,+f,+d,+v", Options,
50                                std::nullopt, std::nullopt,
51                                CodeGenOptLevel::Aggressive)));
52     if (!TM)
53       GTEST_SKIP();
54 
55     SMDiagnostic SMError;
56     M = parseAssemblyString(Assembly, SMError, Context);
57     if (!M)
58       report_fatal_error(SMError.getMessage());
59     M->setDataLayout(TM->createDataLayout());
60 
61     F = M->getFunction("f");
62     if (!F)
63       report_fatal_error("F?");
64     G = M->getGlobalVariable("g");
65     if (!G)
66       report_fatal_error("G?");
67     AliasedG = M->getNamedAlias("g_alias");
68     if (!AliasedG)
69       report_fatal_error("AliasedG?");
70 
71     MachineModuleInfo MMI(TM.get());
72 
73     MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
74                                            0, MMI);
75 
76     DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::None);
77     if (!DAG)
78       report_fatal_error("DAG?");
79     OptimizationRemarkEmitter ORE(F);
80     DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
81   }
82 
83   TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) {
84     return DAG->getTargetLoweringInfo().getTypeAction(Context, VT);
85   }
86 
87   EVT getTypeToTransformTo(EVT VT) {
88     return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT);
89   }
90 
91   LLVMContext Context;
92   std::unique_ptr<LLVMTargetMachine> TM;
93   std::unique_ptr<Module> M;
94   Function *F;
95   GlobalVariable *G;
96   GlobalAlias *AliasedG;
97   std::unique_ptr<MachineFunction> MF;
98   std::unique_ptr<SelectionDAG> DAG;
99 };
100 
101 TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
102   SDLoc DL;
103   auto Int32VT = EVT::getIntegerVT(Context, 32);
104   auto Float32VT = EVT::getFloatingPointVT(32);
105   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
106 
107   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
108   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Float32VT);
109   SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
110 
111   using namespace SDPatternMatch;
112   EXPECT_TRUE(sd_match(Op0, m_SpecificVT(Int32VT)));
113   EVT BindVT;
114   EXPECT_TRUE(sd_match(Op1, m_VT(BindVT)));
115   EXPECT_EQ(BindVT, Float32VT);
116   EXPECT_TRUE(sd_match(Op0, m_IntegerVT()));
117   EXPECT_TRUE(sd_match(Op1, m_FloatingPointVT()));
118   EXPECT_TRUE(sd_match(Op2, m_VectorVT()));
119   EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
120 }
121 
122 TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
123   SDLoc DL;
124   auto Int32VT = EVT::getIntegerVT(Context, 32);
125   auto Float32VT = EVT::getFloatingPointVT(32);
126 
127   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
128   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
129   SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
130 
131   SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
132   SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
133   SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub);
134   SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
135   SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
136   SDValue Or  = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
137   SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1);
138   SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0);
139   SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1);
140   SDValue UMin = DAG->getNode(ISD::UMIN, DL, Int32VT, Op1, Op0);
141 
142   SDValue SFAdd = DAG->getNode(ISD::STRICT_FADD, DL, {Float32VT, MVT::Other},
143                                {DAG->getEntryNode(), Op2, Op2});
144 
145   using namespace SDPatternMatch;
146   EXPECT_TRUE(sd_match(Sub, m_BinOp(ISD::SUB, m_Value(), m_Value())));
147   EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
148   EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
149   EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
150   EXPECT_TRUE(sd_match(
151       Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
152   EXPECT_TRUE(
153       sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT),
154                                      m_SpecificVT(Float32VT))));
155 
156   EXPECT_TRUE(sd_match(And, m_c_BinOp(ISD::AND, m_Value(), m_Value())));
157   EXPECT_TRUE(sd_match(And, m_And(m_Value(), m_Value())));
158   EXPECT_TRUE(sd_match(Xor, m_c_BinOp(ISD::XOR, m_Value(), m_Value())));
159   EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
160   EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
161   EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
162 
163   EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
164   EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
165   EXPECT_TRUE(sd_match(SMin, m_c_BinOp(ISD::SMIN, m_Value(), m_Value())));
166   EXPECT_TRUE(sd_match(SMin, m_SMin(m_Value(), m_Value())));
167   EXPECT_TRUE(sd_match(UMax, m_c_BinOp(ISD::UMAX, m_Value(), m_Value())));
168   EXPECT_TRUE(sd_match(UMax, m_UMax(m_Value(), m_Value())));
169   EXPECT_TRUE(sd_match(UMin, m_c_BinOp(ISD::UMIN, m_Value(), m_Value())));
170   EXPECT_TRUE(sd_match(UMin, m_UMin(m_Value(), m_Value())));
171 
172   SDValue BindVal;
173   EXPECT_TRUE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_Value(BindVal),
174                                              m_Deferred(BindVal))));
175   EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(),
176                                               m_SpecificVT(Float32VT))));
177 }
178 
179 TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
180   SDLoc DL;
181   auto Int32VT = EVT::getIntegerVT(Context, 32);
182   auto Int64VT = EVT::getIntegerVT(Context, 64);
183 
184   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
185   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
186 
187   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
188   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
189   SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);
190 
191   SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Trunc, Op0);
192   SDValue Neg = DAG->getNegative(Op0, DL, Int32VT);
193   SDValue Not = DAG->getNOT(DL, Op0, Int32VT);
194 
195   using namespace SDPatternMatch;
196   EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
197   EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
198   EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
199 
200   EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));
201   EXPECT_TRUE(sd_match(Not, m_Not(m_Value())));
202   EXPECT_FALSE(sd_match(ZExt, m_Neg(m_Value())));
203   EXPECT_FALSE(sd_match(Sub, m_Neg(m_Value())));
204   EXPECT_FALSE(sd_match(Neg, m_Not(m_Value())));
205 }
206 
207 TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
208   SDLoc DL;
209   auto Int32VT = EVT::getIntegerVT(Context, 32);
210   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
211 
212   SDValue Arg0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
213 
214   SDValue Const3 = DAG->getConstant(3, DL, Int32VT);
215   SDValue Const87 = DAG->getConstant(87, DL, Int32VT);
216   SDValue Splat = DAG->getSplat(VInt32VT, DL, Arg0);
217   SDValue ConstSplat = DAG->getSplat(VInt32VT, DL, Const3);
218   SDValue Zero = DAG->getConstant(0, DL, Int32VT);
219   SDValue One = DAG->getConstant(1, DL, Int32VT);
220   SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT);
221   SDValue SetCC = DAG->getSetCC(DL, Int32VT, Arg0, Const3, ISD::SETULT);
222 
223   using namespace SDPatternMatch;
224   EXPECT_TRUE(sd_match(Const87, m_ConstInt()));
225   EXPECT_FALSE(sd_match(Arg0, m_ConstInt()));
226   APInt ConstVal;
227   EXPECT_TRUE(sd_match(ConstSplat, m_ConstInt(ConstVal)));
228   EXPECT_EQ(ConstVal, 3);
229   EXPECT_FALSE(sd_match(Splat, m_ConstInt()));
230 
231   EXPECT_TRUE(sd_match(Const87, m_SpecificInt(87)));
232   EXPECT_TRUE(sd_match(Const3, m_SpecificInt(ConstVal)));
233   EXPECT_TRUE(sd_match(AllOnes, m_AllOnes()));
234 
235   EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False()));
236   EXPECT_TRUE(sd_match(One, DAG.get(), m_True()));
237   EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True()));
238 
239   ISD::CondCode CC;
240   EXPECT_TRUE(sd_match(
241       SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(), m_CondCode(CC))));
242   EXPECT_EQ(CC, ISD::SETULT);
243   EXPECT_TRUE(sd_match(SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(),
244                                      m_SpecificCondCode(ISD::SETULT))));
245 }
246 
247 TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
248   SDLoc DL;
249   auto Int32VT = EVT::getIntegerVT(Context, 32);
250 
251   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
252   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
253 
254   SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
255   SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
256 
257   using namespace SDPatternMatch;
258   EXPECT_TRUE(sd_match(
259       Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
260   EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse())));
261   EXPECT_TRUE(sd_match(Add, m_NoneOf(m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
262 }
263 
264 TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
265   SDLoc DL;
266   auto Int32VT = EVT::getIntegerVT(Context, 32);
267   auto Int64VT = EVT::getIntegerVT(Context, 64);
268 
269   SDValue Op32 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
270   SDValue Op64 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
271   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op32);
272   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op32);
273   SDValue AExt = DAG->getNode(ISD::ANY_EXTEND, DL, Int64VT, Op32);
274   SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op64);
275 
276   using namespace SDPatternMatch;
277   SDValue A;
278   EXPECT_TRUE(sd_match(Op32, m_ZExtOrSelf(m_Value(A))));
279   EXPECT_TRUE(A == Op32);
280   EXPECT_TRUE(sd_match(ZExt, m_ZExtOrSelf(m_Value(A))));
281   EXPECT_TRUE(A == Op32);
282   EXPECT_TRUE(sd_match(Op64, m_SExtOrSelf(m_Value(A))));
283   EXPECT_TRUE(A == Op64);
284   EXPECT_TRUE(sd_match(SExt, m_SExtOrSelf(m_Value(A))));
285   EXPECT_TRUE(A == Op32);
286   EXPECT_TRUE(sd_match(Op32, m_AExtOrSelf(m_Value(A))));
287   EXPECT_TRUE(A == Op32);
288   EXPECT_TRUE(sd_match(AExt, m_AExtOrSelf(m_Value(A))));
289   EXPECT_TRUE(A == Op32);
290   EXPECT_TRUE(sd_match(Op64, m_TruncOrSelf(m_Value(A))));
291   EXPECT_TRUE(A == Op64);
292   EXPECT_TRUE(sd_match(Trunc, m_TruncOrSelf(m_Value(A))));
293   EXPECT_TRUE(A == Op64);
294 }
295 
296 TEST_F(SelectionDAGPatternMatchTest, matchNode) {
297   SDLoc DL;
298   auto Int32VT = EVT::getIntegerVT(Context, 32);
299 
300   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
301   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
302 
303   SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
304 
305   using namespace SDPatternMatch;
306   EXPECT_TRUE(sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value())));
307   EXPECT_FALSE(sd_match(Add, m_Node(ISD::SUB, m_Value(), m_Value())));
308   EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_Value())));
309   EXPECT_FALSE(
310       sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value(), m_Value())));
311   EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
312 }
313 
314 namespace {
315 struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
316   using SDPatternMatch::BasicMatchContext::BasicMatchContext;
317 
318   bool match(SDValue OpVal, unsigned Opc) const {
319     if (!OpVal->isVPOpcode())
320       return OpVal->getOpcode() == Opc;
321 
322     auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
323     return BaseOpc.has_value() && *BaseOpc == Opc;
324   }
325 };
326 } // anonymous namespace
327 TEST_F(SelectionDAGPatternMatchTest, matchContext) {
328   SDLoc DL;
329   auto BoolVT = EVT::getIntegerVT(Context, 1);
330   auto Int32VT = EVT::getIntegerVT(Context, 32);
331   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
332   auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
333 
334   SDValue Scalar0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
335   SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
336   SDValue Mask0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, MaskVT);
337 
338   SDValue VPAdd = DAG->getNode(ISD::VP_ADD, DL, VInt32VT,
339                                {Vector0, Vector0, Mask0, Scalar0});
340   SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
341                                      {Scalar0, VPAdd, Mask0, Scalar0});
342 
343   using namespace SDPatternMatch;
344   VPMatchContext VPCtx(DAG.get());
345   EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
346   // VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
347   // sd_match before switching to VPMatchContext when checking VPAdd.
348   EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
349                                            m_Context(VPCtx, m_Opc(ISD::ADD)),
350                                            m_Value(), m_Value())));
351 }
352 
353 TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
354   SDLoc DL;
355   auto Int16VT = EVT::getIntegerVT(Context, 16);
356   auto Int64VT = EVT::getIntegerVT(Context, 64);
357 
358   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
359   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int16VT);
360 
361   SDValue Add = DAG->getNode(ISD::ADD, DL, Int64VT, Op0, Op0);
362 
363   using namespace SDPatternMatch;
364   EXPECT_TRUE(sd_match(Op0, DAG.get(), m_LegalType(m_Value())));
365   EXPECT_FALSE(sd_match(Op1, DAG.get(), m_LegalType(m_Value())));
366   EXPECT_TRUE(sd_match(Add, DAG.get(),
367                        m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
368 }
369