xref: /llvm-project/llvm/unittests/CodeGen/GlobalISel/GISelMITest.h (revision bb3f5e1fed7c6ba733b7f273e93f5d3930976185)
1 //===- GISelMITest.h --------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
9 #define LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
10 
11 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
12 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
13 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
14 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
15 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
16 #include "llvm/CodeGen/GlobalISel/Utils.h"
17 #include "llvm/CodeGen/MIRParser/MIRParser.h"
18 #include "llvm/CodeGen/MachineFunction.h"
19 #include "llvm/CodeGen/MachineModuleInfo.h"
20 #include "llvm/CodeGen/TargetFrameLowering.h"
21 #include "llvm/CodeGen/TargetInstrInfo.h"
22 #include "llvm/CodeGen/TargetLowering.h"
23 #include "llvm/CodeGen/TargetSubtargetInfo.h"
24 #include "llvm/FileCheck/FileCheck.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/MC/TargetRegistry.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include "llvm/Support/TargetSelect.h"
29 #include "llvm/Target/TargetMachine.h"
30 #include "llvm/Target/TargetOptions.h"
31 #include "gtest/gtest.h"
32 
33 using namespace llvm;
34 using namespace MIPatternMatch;
35 
36 static inline void initLLVM() {
37   InitializeAllTargets();
38   InitializeAllTargetMCs();
39   InitializeAllAsmPrinters();
40   InitializeAllAsmParsers();
41 
42   PassRegistry *Registry = PassRegistry::getPassRegistry();
43   initializeCore(*Registry);
44   initializeCodeGen(*Registry);
45 }
46 
47 // Define a printers to help debugging when things go wrong.
48 namespace llvm {
49 std::ostream &
50 operator<<(std::ostream &OS, const LLT Ty);
51 
52 std::ostream &
53 operator<<(std::ostream &OS, const MachineFunction &MF);
54 }
55 
56 static std::unique_ptr<Module>
57 parseMIR(LLVMContext &Context, std::unique_ptr<MIRParser> &MIR,
58          const TargetMachine &TM, StringRef MIRCode, MachineModuleInfo &MMI) {
59   SMDiagnostic Diagnostic;
60   std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
61   MIR = createMIRParser(std::move(MBuffer), Context);
62   if (!MIR)
63     return nullptr;
64 
65   std::unique_ptr<Module> M = MIR->parseIRModule();
66   if (!M)
67     return nullptr;
68 
69   M->setDataLayout(TM.createDataLayout());
70 
71   if (MIR->parseMachineFunctions(*M, MMI))
72     return nullptr;
73 
74   return M;
75 }
76 static std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
77 createDummyModule(LLVMContext &Context, const TargetMachine &TM,
78                   StringRef MIRString, const char *FuncName) {
79   std::unique_ptr<MIRParser> MIR;
80   auto MMI = std::make_unique<MachineModuleInfo>(&TM);
81   std::unique_ptr<Module> M = parseMIR(Context, MIR, TM, MIRString, *MMI);
82   return make_pair(std::move(M), std::move(MMI));
83 }
84 
85 static MachineFunction *getMFFromMMI(const Module *M,
86                                      const MachineModuleInfo *MMI) {
87   Function *F = M->getFunction("func");
88   auto *MF = MMI->getMachineFunction(*F);
89   return MF;
90 }
91 
92 static void collectCopies(SmallVectorImpl<Register> &Copies,
93                           MachineFunction *MF) {
94   for (auto &MBB : *MF)
95     for (MachineInstr &MI : MBB) {
96       if (MI.getOpcode() == TargetOpcode::COPY)
97         Copies.push_back(MI.getOperand(0).getReg());
98     }
99 }
100 
101 class GISelMITest : public ::testing::Test {
102 protected:
103   GISelMITest() : ::testing::Test() {}
104 
105   /// Prepare a target specific TargetMachine.
106   virtual std::unique_ptr<TargetMachine> createTargetMachine() const = 0;
107 
108   /// Get the stub sample MIR test function.
109   virtual void getTargetTestModuleString(SmallString<512> &S,
110                                          StringRef MIRFunc) const = 0;
111 
112   void setUp(StringRef ExtraAssembly = "") {
113     TM = createTargetMachine();
114     if (!TM)
115       return;
116 
117     SmallString<512> MIRString;
118     getTargetTestModuleString(MIRString, ExtraAssembly);
119 
120     ModuleMMIPair = createDummyModule(Context, *TM, MIRString, "func");
121     MF = getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
122     collectCopies(Copies, MF);
123     EntryMBB = &*MF->begin();
124     B.setMF(*MF);
125     MRI = &MF->getRegInfo();
126     B.setInsertPt(*EntryMBB, EntryMBB->end());
127   }
128 
129   LLVMContext Context;
130   std::unique_ptr<TargetMachine> TM;
131   MachineFunction *MF;
132   std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
133       ModuleMMIPair;
134   SmallVector<Register, 4> Copies;
135   MachineBasicBlock *EntryMBB;
136   MachineIRBuilder B;
137   MachineRegisterInfo *MRI;
138 };
139 
140 class AArch64GISelMITest : public GISelMITest {
141   std::unique_ptr<TargetMachine> createTargetMachine() const override;
142   void getTargetTestModuleString(SmallString<512> &S,
143                                  StringRef MIRFunc) const override;
144 };
145 
146 class AMDGPUGISelMITest : public GISelMITest {
147   std::unique_ptr<TargetMachine> createTargetMachine() const override;
148   void getTargetTestModuleString(SmallString<512> &S,
149                                  StringRef MIRFunc) const override;
150 };
151 
// DefineLegalizerInfo(Name, { ...actions... })
//
// Defines a LegalizerInfo subclass called Name##Info. Its constructor:
//   1. takes the subtarget,
//   2. brings TargetOpcode and the scalar types s8/s16/s32/s64/s128 into
//      scope,
//   3. runs the braced actions block inside do { ... } while (0),
//   4. computes the legacy tables,
//   5. verifies the rules against the subtarget's instruction info.
//
// The block is taken as variadic arguments and re-joined with __VA_ARGS__.
// The preprocessor does not treat braces as grouping, so a top-level comma
// inside the block would otherwise split it into several macro arguments.
// The constructor is explicit so a subtarget never converts implicitly into
// a legalizer info.
// Comments are kept outside the macro: a // comment ending in a line
// continuation would swallow the next macro line.
#define DefineLegalizerInfo(Name, ...)                                         \
  class Name##Info : public LegalizerInfo {                                    \
  public:                                                                      \
    explicit Name##Info(const TargetSubtargetInfo &ST) {                       \
      using namespace TargetOpcode;                                            \
      const LLT s8 = LLT::scalar(8);                                           \
      (void)s8;                                                                \
      const LLT s16 = LLT::scalar(16);                                         \
      (void)s16;                                                               \
      const LLT s32 = LLT::scalar(32);                                         \
      (void)s32;                                                               \
      const LLT s64 = LLT::scalar(64);                                         \
      (void)s64;                                                               \
      const LLT s128 = LLT::scalar(128);                                       \
      (void)s128;                                                              \
      do                                                                       \
        __VA_ARGS__ while (0);                                                 \
      getLegacyLegalizerInfo().computeTables();                                \
      verify(*ST.getInstrInfo());                                              \
    }                                                                          \
  };
173 
174 static inline bool CheckMachineFunction(const MachineFunction &MF,
175                                         StringRef CheckStr) {
176   SmallString<512> Msg;
177   raw_svector_ostream OS(Msg);
178   MF.print(OS);
179   auto OutputBuf = MemoryBuffer::getMemBuffer(Msg, "Output", false);
180   auto CheckBuf = MemoryBuffer::getMemBuffer(CheckStr, "");
181   SmallString<4096> CheckFileBuffer;
182   FileCheckRequest Req;
183   FileCheck FC(Req);
184   StringRef CheckFileText = FC.CanonicalizeFile(*CheckBuf, CheckFileBuffer);
185   SourceMgr SM;
186   SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(CheckFileText, "CheckFile"),
187                         SMLoc());
188   if (FC.readCheckFile(SM, CheckFileText))
189     return false;
190 
191   auto OutBuffer = OutputBuf->getBuffer();
192   SM.AddNewSourceBuffer(std::move(OutputBuf), SMLoc());
193   return FC.checkInput(SM, OutBuffer);
194 }
195 #endif
196