1 //===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/AsmParser/Parser.h" 10 #include "llvm/IR/Module.h" 11 #include "llvm/Passes/PassBuilder.h" 12 #include "llvm/Support/SourceMgr.h" 13 #include "llvm/Testing/Support/Error.h" 14 #include "llvm/Transforms/Coroutines/ABI.h" 15 #include "llvm/Transforms/Coroutines/CoroSplit.h" 16 #include "gtest/gtest.h" 17 18 using namespace llvm; 19 20 namespace { 21 22 struct ExtraRematTest : public testing::Test { 23 LLVMContext Ctx; 24 ModulePassManager MPM; 25 PassBuilder PB; 26 LoopAnalysisManager LAM; 27 FunctionAnalysisManager FAM; 28 CGSCCAnalysisManager CGAM; 29 ModuleAnalysisManager MAM; 30 LLVMContext Context; 31 std::unique_ptr<Module> M; 32 33 ExtraRematTest() { 34 PB.registerModuleAnalyses(MAM); 35 PB.registerCGSCCAnalyses(CGAM); 36 PB.registerFunctionAnalyses(FAM); 37 PB.registerLoopAnalyses(LAM); 38 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 39 } 40 41 BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const { 42 for (BasicBlock &BB : *F) { 43 if (BB.getName() == Name) 44 return &BB; 45 } 46 return nullptr; 47 } 48 49 CallInst *getCallByName(BasicBlock *BB, StringRef Name) const { 50 for (Instruction &I : *BB) { 51 if (CallInst *CI = dyn_cast<CallInst>(&I)) 52 if (CI->getCalledFunction()->getName() == Name) 53 return CI; 54 } 55 return nullptr; 56 } 57 58 void ParseAssembly(const StringRef IR) { 59 SMDiagnostic Error; 60 M = parseAssemblyString(IR, Error, Context); 61 std::string errMsg; 62 raw_string_ostream os(errMsg); 63 Error.print("", os); 64 65 // A failure here means that the test itself is buggy. 66 if (!M) 67 report_fatal_error(errMsg.c_str()); 68 } 69 }; 70 71 StringRef Text = R"( 72 define ptr @f(i32 %n) presplitcoroutine { 73 entry: 74 %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) 75 %size = call i32 @llvm.coro.size.i32() 76 %alloc = call ptr @malloc(i32 %size) 77 %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc) 78 79 %inc1 = add i32 %n, 1 80 %val2 = call i32 @should.remat(i32 %inc1) 81 %sp1 = call i8 @llvm.coro.suspend(token none, i1 false) 82 switch i8 %sp1, label %suspend [i8 0, label %resume1 83 i8 1, label %cleanup] 84 resume1: 85 %inc2 = add i32 %val2, 1 86 %sp2 = call i8 @llvm.coro.suspend(token none, i1 false) 87 switch i8 %sp1, label %suspend [i8 0, label %resume2 88 i8 1, label %cleanup] 89 90 resume2: 91 call void @print(i32 %val2) 92 call void @print(i32 %inc2) 93 br label %cleanup 94 95 cleanup: 96 %mem = call ptr @llvm.coro.free(token %id, ptr %hdl) 97 call void @free(ptr %mem) 98 br label %suspend 99 suspend: 100 call i1 @llvm.coro.end(ptr %hdl, i1 0) 101 ret ptr %hdl 102 } 103 104 declare ptr @llvm.coro.free(token, ptr) 105 declare i32 @llvm.coro.size.i32() 106 declare i8 @llvm.coro.suspend(token, i1) 107 declare void @llvm.coro.resume(ptr) 108 declare void @llvm.coro.destroy(ptr) 109 110 declare token @llvm.coro.id(i32, ptr, ptr, ptr) 111 declare i1 @llvm.coro.alloc(token) 112 declare ptr @llvm.coro.begin(token, ptr) 113 declare i1 @llvm.coro.end(ptr, i1) 114 115 declare i32 @should.remat(i32) 116 117 declare noalias ptr @malloc(i32) 118 declare void @print(i32) 119 declare void @free(ptr) 120 )"; 121 122 // Materializable callback with extra rematerialization 123 bool ExtraMaterializable(Instruction &I) { 124 if (isa<CastInst>(&I) || isa<GetElementPtrInst>(&I) || 125 isa<BinaryOperator>(&I) || isa<CmpInst>(&I) || isa<SelectInst>(&I)) 126 return true; 127 128 if (auto *CI = dyn_cast<CallInst>(&I)) { 129 auto *CalledFunc = CI->getCalledFunction(); 130 if (CalledFunc && CalledFunc->getName().starts_with("should.remat")) 131 return true; 132 } 133 134 return false; 135 } 136 137 TEST_F(ExtraRematTest, TestCoroRematDefault) { 138 ParseAssembly(Text); 139 140 ASSERT_TRUE(M); 141 142 CGSCCPassManager CGPM; 143 CGPM.addPass(CoroSplitPass()); 144 MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); 145 MPM.run(*M, MAM); 146 147 // Verify that extra rematerializable instruction has been rematerialized 148 Function *F = M->getFunction("f.resume"); 149 ASSERT_TRUE(F) << "could not find split function f.resume"; 150 151 BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); 152 ASSERT_TRUE(Resume1) 153 << "could not find expected BB resume1 in split function"; 154 155 // With default materialization the intrinsic should not have been 156 // rematerialized 157 CallInst *CI = getCallByName(Resume1, "should.remat"); 158 ASSERT_FALSE(CI); 159 } 160 161 TEST_F(ExtraRematTest, TestCoroRematWithCallback) { 162 ParseAssembly(Text); 163 164 ASSERT_TRUE(M); 165 166 CGSCCPassManager CGPM; 167 CGPM.addPass( 168 CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable))); 169 MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); 170 MPM.run(*M, MAM); 171 172 // Verify that extra rematerializable instruction has been rematerialized 173 Function *F = M->getFunction("f.resume"); 174 ASSERT_TRUE(F) << "could not find split function f.resume"; 175 176 BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); 177 ASSERT_TRUE(Resume1) 178 << "could not find expected BB resume1 in split function"; 179 180 // With callback the extra rematerialization of the function should have 181 // happened 182 CallInst *CI = getCallByName(Resume1, "should.remat"); 183 ASSERT_TRUE(CI); 184 } 185 186 StringRef TextCoroBeginCustomABI = R"( 187 define ptr @f(i32 %n) presplitcoroutine { 188 entry: 189 %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) 190 %size = call i32 @llvm.coro.size.i32() 191 %alloc = call ptr @malloc(i32 %size) 192 %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0) 193 194 %inc1 = add i32 %n, 1 195 %val2 = call i32 @should.remat(i32 %inc1) 196 %sp1 = call i8 @llvm.coro.suspend(token none, i1 false) 197 switch i8 %sp1, label %suspend [i8 0, label %resume1 198 i8 1, label %cleanup] 199 resume1: 200 %inc2 = add i32 %val2, 1 201 %sp2 = call i8 @llvm.coro.suspend(token none, i1 false) 202 switch i8 %sp1, label %suspend [i8 0, label %resume2 203 i8 1, label %cleanup] 204 205 resume2: 206 call void @print(i32 %val2) 207 call void @print(i32 %inc2) 208 br label %cleanup 209 210 cleanup: 211 %mem = call ptr @llvm.coro.free(token %id, ptr %hdl) 212 call void @free(ptr %mem) 213 br label %suspend 214 suspend: 215 call i1 @llvm.coro.end(ptr %hdl, i1 0) 216 ret ptr %hdl 217 } 218 219 declare ptr @llvm.coro.free(token, ptr) 220 declare i32 @llvm.coro.size.i32() 221 declare i8 @llvm.coro.suspend(token, i1) 222 declare void @llvm.coro.resume(ptr) 223 declare void @llvm.coro.destroy(ptr) 224 225 declare token @llvm.coro.id(i32, ptr, ptr, ptr) 226 declare i1 @llvm.coro.alloc(token) 227 declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32) 228 declare i1 @llvm.coro.end(ptr, i1) 229 230 declare i32 @should.remat(i32) 231 232 declare noalias ptr @malloc(i32) 233 declare void @print(i32) 234 declare void @free(ptr) 235 )"; 236 237 // SwitchABI with overridden isMaterializable 238 class ExtraCustomABI : public coro::SwitchABI { 239 public: 240 ExtraCustomABI(Function &F, coro::Shape &S) 241 : coro::SwitchABI(F, S, ExtraMaterializable) {} 242 }; 243 244 TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) { 245 ParseAssembly(TextCoroBeginCustomABI); 246 247 ASSERT_TRUE(M); 248 249 CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) { 250 return std::make_unique<ExtraCustomABI>(F, S); 251 }; 252 253 CGSCCPassManager CGPM; 254 CGPM.addPass(CoroSplitPass({GenCustomABI})); 255 MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); 256 MPM.run(*M, MAM); 257 258 // Verify that extra rematerializable instruction has been rematerialized 259 Function *F = M->getFunction("f.resume"); 260 ASSERT_TRUE(F) << "could not find split function f.resume"; 261 262 BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); 263 ASSERT_TRUE(Resume1) 264 << "could not find expected BB resume1 in split function"; 265 266 // With callback the extra rematerialization of the function should have 267 // happened 268 CallInst *CI = getCallByName(Resume1, "should.remat"); 269 ASSERT_TRUE(CI); 270 } 271 272 } // namespace 273