1 //===- GISelMITest.h --------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H 9 #define LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H 10 11 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h" 12 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 13 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" 14 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 15 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 16 #include "llvm/CodeGen/GlobalISel/Utils.h" 17 #include "llvm/CodeGen/MIRParser/MIRParser.h" 18 #include "llvm/CodeGen/MachineFunction.h" 19 #include "llvm/CodeGen/MachineModuleInfo.h" 20 #include "llvm/CodeGen/TargetFrameLowering.h" 21 #include "llvm/CodeGen/TargetInstrInfo.h" 22 #include "llvm/CodeGen/TargetLowering.h" 23 #include "llvm/CodeGen/TargetSubtargetInfo.h" 24 #include "llvm/FileCheck/FileCheck.h" 25 #include "llvm/InitializePasses.h" 26 #include "llvm/MC/TargetRegistry.h" 27 #include "llvm/Support/SourceMgr.h" 28 #include "llvm/Support/TargetSelect.h" 29 #include "llvm/Target/TargetMachine.h" 30 #include "llvm/Target/TargetOptions.h" 31 #include "gtest/gtest.h" 32 33 using namespace llvm; 34 using namespace MIPatternMatch; 35 36 static inline void initLLVM() { 37 InitializeAllTargets(); 38 InitializeAllTargetMCs(); 39 InitializeAllAsmPrinters(); 40 InitializeAllAsmParsers(); 41 42 PassRegistry *Registry = PassRegistry::getPassRegistry(); 43 initializeCore(*Registry); 44 initializeCodeGen(*Registry); 45 } 46 47 // Define a printers to help debugging when things go wrong. 48 namespace llvm { 49 std::ostream & 50 operator<<(std::ostream &OS, const LLT Ty); 51 52 std::ostream & 53 operator<<(std::ostream &OS, const MachineFunction &MF); 54 } 55 56 static std::unique_ptr<Module> 57 parseMIR(LLVMContext &Context, std::unique_ptr<MIRParser> &MIR, 58 const TargetMachine &TM, StringRef MIRCode, MachineModuleInfo &MMI) { 59 SMDiagnostic Diagnostic; 60 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode); 61 MIR = createMIRParser(std::move(MBuffer), Context); 62 if (!MIR) 63 return nullptr; 64 65 std::unique_ptr<Module> M = MIR->parseIRModule(); 66 if (!M) 67 return nullptr; 68 69 M->setDataLayout(TM.createDataLayout()); 70 71 if (MIR->parseMachineFunctions(*M, MMI)) 72 return nullptr; 73 74 return M; 75 } 76 static std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>> 77 createDummyModule(LLVMContext &Context, const TargetMachine &TM, 78 StringRef MIRString, const char *FuncName) { 79 std::unique_ptr<MIRParser> MIR; 80 auto MMI = std::make_unique<MachineModuleInfo>(&TM); 81 std::unique_ptr<Module> M = parseMIR(Context, MIR, TM, MIRString, *MMI); 82 return make_pair(std::move(M), std::move(MMI)); 83 } 84 85 static MachineFunction *getMFFromMMI(const Module *M, 86 const MachineModuleInfo *MMI) { 87 Function *F = M->getFunction("func"); 88 auto *MF = MMI->getMachineFunction(*F); 89 return MF; 90 } 91 92 static void collectCopies(SmallVectorImpl<Register> &Copies, 93 MachineFunction *MF) { 94 for (auto &MBB : *MF) 95 for (MachineInstr &MI : MBB) { 96 if (MI.getOpcode() == TargetOpcode::COPY) 97 Copies.push_back(MI.getOperand(0).getReg()); 98 } 99 } 100 101 class GISelMITest : public ::testing::Test { 102 protected: 103 GISelMITest() : ::testing::Test() {} 104 105 /// Prepare a target specific TargetMachine. 106 virtual std::unique_ptr<TargetMachine> createTargetMachine() const = 0; 107 108 /// Get the stub sample MIR test function. 109 virtual void getTargetTestModuleString(SmallString<512> &S, 110 StringRef MIRFunc) const = 0; 111 112 void setUp(StringRef ExtraAssembly = "") { 113 TM = createTargetMachine(); 114 if (!TM) 115 return; 116 117 SmallString<512> MIRString; 118 getTargetTestModuleString(MIRString, ExtraAssembly); 119 120 ModuleMMIPair = createDummyModule(Context, *TM, MIRString, "func"); 121 MF = getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 122 collectCopies(Copies, MF); 123 EntryMBB = &*MF->begin(); 124 B.setMF(*MF); 125 MRI = &MF->getRegInfo(); 126 B.setInsertPt(*EntryMBB, EntryMBB->end()); 127 } 128 129 LLVMContext Context; 130 std::unique_ptr<TargetMachine> TM; 131 MachineFunction *MF; 132 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>> 133 ModuleMMIPair; 134 SmallVector<Register, 4> Copies; 135 MachineBasicBlock *EntryMBB; 136 MachineIRBuilder B; 137 MachineRegisterInfo *MRI; 138 }; 139 140 class AArch64GISelMITest : public GISelMITest { 141 std::unique_ptr<TargetMachine> createTargetMachine() const override; 142 void getTargetTestModuleString(SmallString<512> &S, 143 StringRef MIRFunc) const override; 144 }; 145 146 class AMDGPUGISelMITest : public GISelMITest { 147 std::unique_ptr<TargetMachine> createTargetMachine() const override; 148 void getTargetTestModuleString(SmallString<512> &S, 149 StringRef MIRFunc) const override; 150 }; 151 152 #define DefineLegalizerInfo(Name, SettingUpActionsBlock) \ 153 class Name##Info : public LegalizerInfo { \ 154 public: \ 155 Name##Info(const TargetSubtargetInfo &ST) { \ 156 using namespace TargetOpcode; \ 157 const LLT s8 = LLT::scalar(8); \ 158 (void)s8; \ 159 const LLT s16 = LLT::scalar(16); \ 160 (void)s16; \ 161 const LLT s32 = LLT::scalar(32); \ 162 (void)s32; \ 163 const LLT s64 = LLT::scalar(64); \ 164 (void)s64; \ 165 const LLT s128 = LLT::scalar(128); \ 166 (void)s128; \ 167 do \ 168 SettingUpActionsBlock while (0); \ 169 getLegacyLegalizerInfo().computeTables(); \ 170 verify(*ST.getInstrInfo()); \ 171 } \ 172 }; 173 174 static inline bool CheckMachineFunction(const MachineFunction &MF, 175 StringRef CheckStr) { 176 SmallString<512> Msg; 177 raw_svector_ostream OS(Msg); 178 MF.print(OS); 179 auto OutputBuf = MemoryBuffer::getMemBuffer(Msg, "Output", false); 180 auto CheckBuf = MemoryBuffer::getMemBuffer(CheckStr, ""); 181 SmallString<4096> CheckFileBuffer; 182 FileCheckRequest Req; 183 FileCheck FC(Req); 184 StringRef CheckFileText = FC.CanonicalizeFile(*CheckBuf, CheckFileBuffer); 185 SourceMgr SM; 186 SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(CheckFileText, "CheckFile"), 187 SMLoc()); 188 if (FC.readCheckFile(SM, CheckFileText)) 189 return false; 190 191 auto OutBuffer = OutputBuf->getBuffer(); 192 SM.AddNewSourceBuffer(std::move(OutputBuf), SMLoc()); 193 return FC.checkInput(SM, OutBuffer); 194 } 195 #endif 196