xref: /llvm-project/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp (revision 4f54b91842ea2ab9546459869df442f7e7fe59d6)
1 //===---- llvm/unittest/CodeGen/SelectionDAGPatternMatchTest.cpp  ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/CodeGen/MachineModuleInfo.h"
12 #include "llvm/CodeGen/SDPatternMatch.h"
13 #include "llvm/CodeGen/TargetLowering.h"
14 #include "llvm/MC/TargetRegistry.h"
15 #include "llvm/Support/SourceMgr.h"
16 #include "llvm/Support/TargetSelect.h"
17 #include "llvm/Target/TargetMachine.h"
18 #include "gtest/gtest.h"
19 
20 using namespace llvm;
21 
22 class SelectionDAGPatternMatchTest : public testing::Test {
23 protected:
24   static void SetUpTestCase() {
25     InitializeAllTargets();
26     InitializeAllTargetMCs();
27   }
28 
29   void SetUp() override {
30     StringRef Assembly = "@g = global i32 0\n"
31                          "@g_alias = alias i32, i32* @g\n"
32                          "define i32 @f() {\n"
33                          "  %1 = load i32, i32* @g\n"
34                          "  ret i32 %1\n"
35                          "}";
36 
37     Triple TargetTriple("riscv64--");
38     std::string Error;
39     const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
40     // FIXME: These tests do not depend on RISCV specifically, but we have to
41     // initialize a target. A skeleton Target for unittests would allow us to
42     // always run these tests.
43     if (!T)
44       GTEST_SKIP();
45 
46     TargetOptions Options;
47     TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
48         T->createTargetMachine("riscv64", "", "+m,+f,+d,+v", Options,
49                                std::nullopt, std::nullopt,
50                                CodeGenOptLevel::Aggressive)));
51     if (!TM)
52       GTEST_SKIP();
53 
54     SMDiagnostic SMError;
55     M = parseAssemblyString(Assembly, SMError, Context);
56     if (!M)
57       report_fatal_error(SMError.getMessage());
58     M->setDataLayout(TM->createDataLayout());
59 
60     F = M->getFunction("f");
61     if (!F)
62       report_fatal_error("F?");
63     G = M->getGlobalVariable("g");
64     if (!G)
65       report_fatal_error("G?");
66     AliasedG = M->getNamedAlias("g_alias");
67     if (!AliasedG)
68       report_fatal_error("AliasedG?");
69 
70     MachineModuleInfo MMI(TM.get());
71 
72     MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
73                                            0, MMI);
74 
75     DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::None);
76     if (!DAG)
77       report_fatal_error("DAG?");
78     OptimizationRemarkEmitter ORE(F);
79     DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
80   }
81 
82   TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) {
83     return DAG->getTargetLoweringInfo().getTypeAction(Context, VT);
84   }
85 
86   EVT getTypeToTransformTo(EVT VT) {
87     return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT);
88   }
89 
90   LLVMContext Context;
91   std::unique_ptr<LLVMTargetMachine> TM;
92   std::unique_ptr<Module> M;
93   Function *F;
94   GlobalVariable *G;
95   GlobalAlias *AliasedG;
96   std::unique_ptr<MachineFunction> MF;
97   std::unique_ptr<SelectionDAG> DAG;
98 };
99 
100 TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
101   SDLoc DL;
102   auto Int32VT = EVT::getIntegerVT(Context, 32);
103   auto Float32VT = EVT::getFloatingPointVT(32);
104   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
105 
106   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
107   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Float32VT);
108   SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
109 
110   using namespace SDPatternMatch;
111   EXPECT_TRUE(sd_match(Op0, m_SpecificVT(Int32VT)));
112   EVT BindVT;
113   EXPECT_TRUE(sd_match(Op1, m_VT(BindVT)));
114   EXPECT_EQ(BindVT, Float32VT);
115   EXPECT_TRUE(sd_match(Op0, m_IntegerVT()));
116   EXPECT_TRUE(sd_match(Op1, m_FloatingPointVT()));
117   EXPECT_TRUE(sd_match(Op2, m_VectorVT()));
118   EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
119 }
120 
121 TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
122   SDLoc DL;
123   auto Int32VT = EVT::getIntegerVT(Context, 32);
124   auto Float32VT = EVT::getFloatingPointVT(32);
125 
126   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
127   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
128   SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
129 
130   SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
131   SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
132   SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub);
133   SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
134   SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
135   SDValue Or  = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
136   SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1);
137   SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0);
138   SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1);
139   SDValue UMin = DAG->getNode(ISD::UMIN, DL, Int32VT, Op1, Op0);
140 
141   SDValue SFAdd = DAG->getNode(ISD::STRICT_FADD, DL, {Float32VT, MVT::Other},
142                                {DAG->getEntryNode(), Op2, Op2});
143 
144   using namespace SDPatternMatch;
145   EXPECT_TRUE(sd_match(Sub, m_BinOp(ISD::SUB, m_Value(), m_Value())));
146   EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
147   EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
148   EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
149   EXPECT_TRUE(sd_match(
150       Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
151   EXPECT_TRUE(
152       sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT),
153                                      m_SpecificVT(Float32VT))));
154 
155   EXPECT_TRUE(sd_match(And, m_c_BinOp(ISD::AND, m_Value(), m_Value())));
156   EXPECT_TRUE(sd_match(And, m_And(m_Value(), m_Value())));
157   EXPECT_TRUE(sd_match(Xor, m_c_BinOp(ISD::XOR, m_Value(), m_Value())));
158   EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
159   EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
160   EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
161 
162   EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
163   EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
164   EXPECT_TRUE(sd_match(SMin, m_c_BinOp(ISD::SMIN, m_Value(), m_Value())));
165   EXPECT_TRUE(sd_match(SMin, m_SMin(m_Value(), m_Value())));
166   EXPECT_TRUE(sd_match(UMax, m_c_BinOp(ISD::UMAX, m_Value(), m_Value())));
167   EXPECT_TRUE(sd_match(UMax, m_UMax(m_Value(), m_Value())));
168   EXPECT_TRUE(sd_match(UMin, m_c_BinOp(ISD::UMIN, m_Value(), m_Value())));
169   EXPECT_TRUE(sd_match(UMin, m_UMin(m_Value(), m_Value())));
170 
171   SDValue BindVal;
172   EXPECT_TRUE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_Value(BindVal),
173                                              m_Deferred(BindVal))));
174   EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(),
175                                               m_SpecificVT(Float32VT))));
176 }
177 
178 TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
179   SDLoc DL;
180   auto Int32VT = EVT::getIntegerVT(Context, 32);
181   auto Int64VT = EVT::getIntegerVT(Context, 64);
182 
183   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
184   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
185 
186   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
187   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
188   SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);
189 
190   SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Trunc, Op0);
191   SDValue Neg = DAG->getNegative(Op0, DL, Int32VT);
192   SDValue Not = DAG->getNOT(DL, Op0, Int32VT);
193 
194   using namespace SDPatternMatch;
195   EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
196   EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
197   EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
198 
199   EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));
200   EXPECT_TRUE(sd_match(Not, m_Not(m_Value())));
201   EXPECT_FALSE(sd_match(ZExt, m_Neg(m_Value())));
202   EXPECT_FALSE(sd_match(Sub, m_Neg(m_Value())));
203   EXPECT_FALSE(sd_match(Neg, m_Not(m_Value())));
204 }
205 
206 TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
207   SDLoc DL;
208   auto Int32VT = EVT::getIntegerVT(Context, 32);
209   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
210 
211   SDValue Arg0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
212 
213   SDValue Const3 = DAG->getConstant(3, DL, Int32VT);
214   SDValue Const87 = DAG->getConstant(87, DL, Int32VT);
215   SDValue Splat = DAG->getSplat(VInt32VT, DL, Arg0);
216   SDValue ConstSplat = DAG->getSplat(VInt32VT, DL, Const3);
217   SDValue Zero = DAG->getConstant(0, DL, Int32VT);
218   SDValue One = DAG->getConstant(1, DL, Int32VT);
219   SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT);
220   SDValue SetCC = DAG->getSetCC(DL, Int32VT, Arg0, Const3, ISD::SETULT);
221 
222   using namespace SDPatternMatch;
223   EXPECT_TRUE(sd_match(Const87, m_ConstInt()));
224   EXPECT_FALSE(sd_match(Arg0, m_ConstInt()));
225   APInt ConstVal;
226   EXPECT_TRUE(sd_match(ConstSplat, m_ConstInt(ConstVal)));
227   EXPECT_EQ(ConstVal, 3);
228   EXPECT_FALSE(sd_match(Splat, m_ConstInt()));
229 
230   EXPECT_TRUE(sd_match(Const87, m_SpecificInt(87)));
231   EXPECT_TRUE(sd_match(Const3, m_SpecificInt(ConstVal)));
232   EXPECT_TRUE(sd_match(AllOnes, m_AllOnes()));
233 
234   EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False()));
235   EXPECT_TRUE(sd_match(One, DAG.get(), m_True()));
236   EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True()));
237 
238   ISD::CondCode CC;
239   EXPECT_TRUE(sd_match(
240       SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(), m_CondCode(CC))));
241   EXPECT_EQ(CC, ISD::SETULT);
242   EXPECT_TRUE(sd_match(SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(),
243                                      m_SpecificCondCode(ISD::SETULT))));
244 }
245 
246 TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
247   SDLoc DL;
248   auto Int32VT = EVT::getIntegerVT(Context, 32);
249 
250   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
251   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
252 
253   SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
254   SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
255 
256   using namespace SDPatternMatch;
257   EXPECT_TRUE(sd_match(
258       Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
259   EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse())));
260   EXPECT_TRUE(sd_match(Add, m_NoneOf(m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
261 }
262 
263 TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
264   SDLoc DL;
265   auto Int32VT = EVT::getIntegerVT(Context, 32);
266   auto Int64VT = EVT::getIntegerVT(Context, 64);
267 
268   SDValue Op32 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
269   SDValue Op64 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
270   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op32);
271   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op32);
272   SDValue AExt = DAG->getNode(ISD::ANY_EXTEND, DL, Int64VT, Op32);
273   SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op64);
274 
275   using namespace SDPatternMatch;
276   SDValue A;
277   EXPECT_TRUE(sd_match(Op32, m_ZExtOrSelf(m_Value(A))));
278   EXPECT_TRUE(A == Op32);
279   EXPECT_TRUE(sd_match(ZExt, m_ZExtOrSelf(m_Value(A))));
280   EXPECT_TRUE(A == Op32);
281   EXPECT_TRUE(sd_match(Op64, m_SExtOrSelf(m_Value(A))));
282   EXPECT_TRUE(A == Op64);
283   EXPECT_TRUE(sd_match(SExt, m_SExtOrSelf(m_Value(A))));
284   EXPECT_TRUE(A == Op32);
285   EXPECT_TRUE(sd_match(Op32, m_AExtOrSelf(m_Value(A))));
286   EXPECT_TRUE(A == Op32);
287   EXPECT_TRUE(sd_match(AExt, m_AExtOrSelf(m_Value(A))));
288   EXPECT_TRUE(A == Op32);
289   EXPECT_TRUE(sd_match(Op64, m_TruncOrSelf(m_Value(A))));
290   EXPECT_TRUE(A == Op64);
291   EXPECT_TRUE(sd_match(Trunc, m_TruncOrSelf(m_Value(A))));
292   EXPECT_TRUE(A == Op64);
293 }
294 
295 TEST_F(SelectionDAGPatternMatchTest, matchNode) {
296   SDLoc DL;
297   auto Int32VT = EVT::getIntegerVT(Context, 32);
298 
299   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
300   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
301 
302   SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
303 
304   using namespace SDPatternMatch;
305   EXPECT_TRUE(sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value())));
306   EXPECT_FALSE(sd_match(Add, m_Node(ISD::SUB, m_Value(), m_Value())));
307   EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_Value())));
308   EXPECT_FALSE(
309       sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value(), m_Value())));
310   EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
311 }
312 
313 namespace {
314 struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
315   using SDPatternMatch::BasicMatchContext::BasicMatchContext;
316 
317   bool match(SDValue OpVal, unsigned Opc) const {
318     if (!OpVal->isVPOpcode())
319       return OpVal->getOpcode() == Opc;
320 
321     auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
322     return BaseOpc.has_value() && *BaseOpc == Opc;
323   }
324 };
325 } // anonymous namespace
326 TEST_F(SelectionDAGPatternMatchTest, matchContext) {
327   SDLoc DL;
328   auto BoolVT = EVT::getIntegerVT(Context, 1);
329   auto Int32VT = EVT::getIntegerVT(Context, 32);
330   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
331   auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
332 
333   SDValue Scalar0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
334   SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
335   SDValue Mask0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, MaskVT);
336 
337   SDValue VPAdd = DAG->getNode(ISD::VP_ADD, DL, VInt32VT,
338                                {Vector0, Vector0, Mask0, Scalar0});
339   SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
340                                      {Scalar0, VPAdd, Mask0, Scalar0});
341 
342   using namespace SDPatternMatch;
343   VPMatchContext VPCtx(DAG.get());
344   EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
345   // VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
346   // sd_match before switching to VPMatchContext when checking VPAdd.
347   EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
348                                            m_Context(VPCtx, m_Opc(ISD::ADD)),
349                                            m_Value(), m_Value())));
350 }
351 
352 TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
353   SDLoc DL;
354   auto Int16VT = EVT::getIntegerVT(Context, 16);
355   auto Int64VT = EVT::getIntegerVT(Context, 64);
356 
357   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
358   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int16VT);
359 
360   SDValue Add = DAG->getNode(ISD::ADD, DL, Int64VT, Op0, Op0);
361 
362   using namespace SDPatternMatch;
363   EXPECT_TRUE(sd_match(Op0, DAG.get(), m_LegalType(m_Value())));
364   EXPECT_FALSE(sd_match(Op1, DAG.get(), m_LegalType(m_Value())));
365   EXPECT_TRUE(sd_match(Add, DAG.get(),
366                        m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
367 }
368