1c4f7cc86SDavid Stuttard //===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===// 2c4f7cc86SDavid Stuttard // 3c4f7cc86SDavid Stuttard // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c4f7cc86SDavid Stuttard // See https://llvm.org/LICENSE.txt for license information. 5c4f7cc86SDavid Stuttard // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c4f7cc86SDavid Stuttard // 7c4f7cc86SDavid Stuttard //===----------------------------------------------------------------------===// 8c4f7cc86SDavid Stuttard 9c4f7cc86SDavid Stuttard #include "llvm/AsmParser/Parser.h" 10c4f7cc86SDavid Stuttard #include "llvm/IR/Module.h" 11c4f7cc86SDavid Stuttard #include "llvm/Passes/PassBuilder.h" 12c4f7cc86SDavid Stuttard #include "llvm/Support/SourceMgr.h" 13c4f7cc86SDavid Stuttard #include "llvm/Testing/Support/Error.h" 14e82fcda1STyler Nowicki #include "llvm/Transforms/Coroutines/ABI.h" 15c4f7cc86SDavid Stuttard #include "llvm/Transforms/Coroutines/CoroSplit.h" 16c4f7cc86SDavid Stuttard #include "gtest/gtest.h" 17c4f7cc86SDavid Stuttard 18c4f7cc86SDavid Stuttard using namespace llvm; 19c4f7cc86SDavid Stuttard 20c4f7cc86SDavid Stuttard namespace { 21c4f7cc86SDavid Stuttard 22c4f7cc86SDavid Stuttard struct ExtraRematTest : public testing::Test { 23c4f7cc86SDavid Stuttard LLVMContext Ctx; 24c4f7cc86SDavid Stuttard ModulePassManager MPM; 25c4f7cc86SDavid Stuttard PassBuilder PB; 26c4f7cc86SDavid Stuttard LoopAnalysisManager LAM; 27c4f7cc86SDavid Stuttard FunctionAnalysisManager FAM; 28c4f7cc86SDavid Stuttard CGSCCAnalysisManager CGAM; 29c4f7cc86SDavid Stuttard ModuleAnalysisManager MAM; 30c4f7cc86SDavid Stuttard LLVMContext Context; 31c4f7cc86SDavid Stuttard std::unique_ptr<Module> M; 32c4f7cc86SDavid Stuttard 33c4f7cc86SDavid Stuttard ExtraRematTest() { 34c4f7cc86SDavid Stuttard PB.registerModuleAnalyses(MAM); 35c4f7cc86SDavid Stuttard PB.registerCGSCCAnalyses(CGAM); 36c4f7cc86SDavid Stuttard PB.registerFunctionAnalyses(FAM); 37c4f7cc86SDavid Stuttard PB.registerLoopAnalyses(LAM); 38c4f7cc86SDavid Stuttard PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 39c4f7cc86SDavid Stuttard } 40c4f7cc86SDavid Stuttard 41c4f7cc86SDavid Stuttard BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const { 42c4f7cc86SDavid Stuttard for (BasicBlock &BB : *F) { 43c4f7cc86SDavid Stuttard if (BB.getName() == Name) 44c4f7cc86SDavid Stuttard return &BB; 45c4f7cc86SDavid Stuttard } 46c4f7cc86SDavid Stuttard return nullptr; 47c4f7cc86SDavid Stuttard } 48c4f7cc86SDavid Stuttard 49c4f7cc86SDavid Stuttard CallInst *getCallByName(BasicBlock *BB, StringRef Name) const { 50c4f7cc86SDavid Stuttard for (Instruction &I : *BB) { 51c4f7cc86SDavid Stuttard if (CallInst *CI = dyn_cast<CallInst>(&I)) 52c4f7cc86SDavid Stuttard if (CI->getCalledFunction()->getName() == Name) 53c4f7cc86SDavid Stuttard return CI; 54c4f7cc86SDavid Stuttard } 55c4f7cc86SDavid Stuttard return nullptr; 56c4f7cc86SDavid Stuttard } 57c4f7cc86SDavid Stuttard 58c4f7cc86SDavid Stuttard void ParseAssembly(const StringRef IR) { 59c4f7cc86SDavid Stuttard SMDiagnostic Error; 60c4f7cc86SDavid Stuttard M = parseAssemblyString(IR, Error, Context); 61c4f7cc86SDavid Stuttard std::string errMsg; 62c4f7cc86SDavid Stuttard raw_string_ostream os(errMsg); 63c4f7cc86SDavid Stuttard Error.print("", os); 64c4f7cc86SDavid Stuttard 65c4f7cc86SDavid Stuttard // A failure here means that the test itself is buggy. 66c4f7cc86SDavid Stuttard if (!M) 6752b48a70SJOE1994 report_fatal_error(errMsg.c_str()); 68c4f7cc86SDavid Stuttard } 69c4f7cc86SDavid Stuttard }; 70c4f7cc86SDavid Stuttard 71c4f7cc86SDavid Stuttard StringRef Text = R"( 72c4f7cc86SDavid Stuttard define ptr @f(i32 %n) presplitcoroutine { 73c4f7cc86SDavid Stuttard entry: 74c4f7cc86SDavid Stuttard %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) 75c4f7cc86SDavid Stuttard %size = call i32 @llvm.coro.size.i32() 76c4f7cc86SDavid Stuttard %alloc = call ptr @malloc(i32 %size) 77c4f7cc86SDavid Stuttard %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc) 78c4f7cc86SDavid Stuttard 79c4f7cc86SDavid Stuttard %inc1 = add i32 %n, 1 80c4f7cc86SDavid Stuttard %val2 = call i32 @should.remat(i32 %inc1) 81c4f7cc86SDavid Stuttard %sp1 = call i8 @llvm.coro.suspend(token none, i1 false) 82c4f7cc86SDavid Stuttard switch i8 %sp1, label %suspend [i8 0, label %resume1 83c4f7cc86SDavid Stuttard i8 1, label %cleanup] 84c4f7cc86SDavid Stuttard resume1: 85c4f7cc86SDavid Stuttard %inc2 = add i32 %val2, 1 86c4f7cc86SDavid Stuttard %sp2 = call i8 @llvm.coro.suspend(token none, i1 false) 87c4f7cc86SDavid Stuttard switch i8 %sp1, label %suspend [i8 0, label %resume2 88c4f7cc86SDavid Stuttard i8 1, label %cleanup] 89c4f7cc86SDavid Stuttard 90c4f7cc86SDavid Stuttard resume2: 91c4f7cc86SDavid Stuttard call void @print(i32 %val2) 92c4f7cc86SDavid Stuttard call void @print(i32 %inc2) 93c4f7cc86SDavid Stuttard br label %cleanup 94c4f7cc86SDavid Stuttard 95c4f7cc86SDavid Stuttard cleanup: 96c4f7cc86SDavid Stuttard %mem = call ptr @llvm.coro.free(token %id, ptr %hdl) 97c4f7cc86SDavid Stuttard call void @free(ptr %mem) 98c4f7cc86SDavid Stuttard br label %suspend 99c4f7cc86SDavid Stuttard suspend: 100c4f7cc86SDavid Stuttard call i1 @llvm.coro.end(ptr %hdl, i1 0) 101c4f7cc86SDavid Stuttard ret ptr %hdl 102c4f7cc86SDavid Stuttard } 103c4f7cc86SDavid Stuttard 104c4f7cc86SDavid Stuttard declare ptr @llvm.coro.free(token, ptr) 105c4f7cc86SDavid Stuttard declare i32 @llvm.coro.size.i32() 106c4f7cc86SDavid Stuttard declare i8 @llvm.coro.suspend(token, i1) 107c4f7cc86SDavid Stuttard declare void @llvm.coro.resume(ptr) 108c4f7cc86SDavid Stuttard declare void @llvm.coro.destroy(ptr) 109c4f7cc86SDavid Stuttard 110c4f7cc86SDavid Stuttard declare token @llvm.coro.id(i32, ptr, ptr, ptr) 111c4f7cc86SDavid Stuttard declare i1 @llvm.coro.alloc(token) 112c4f7cc86SDavid Stuttard declare ptr @llvm.coro.begin(token, ptr) 113c4f7cc86SDavid Stuttard declare i1 @llvm.coro.end(ptr, i1) 114c4f7cc86SDavid Stuttard 115c4f7cc86SDavid Stuttard declare i32 @should.remat(i32) 116c4f7cc86SDavid Stuttard 117c4f7cc86SDavid Stuttard declare noalias ptr @malloc(i32) 118c4f7cc86SDavid Stuttard declare void @print(i32) 119c4f7cc86SDavid Stuttard declare void @free(ptr) 120c4f7cc86SDavid Stuttard )"; 121c4f7cc86SDavid Stuttard 122c4f7cc86SDavid Stuttard // Materializable callback with extra rematerialization 123c4f7cc86SDavid Stuttard bool ExtraMaterializable(Instruction &I) { 124c4f7cc86SDavid Stuttard if (isa<CastInst>(&I) || isa<GetElementPtrInst>(&I) || 125c4f7cc86SDavid Stuttard isa<BinaryOperator>(&I) || isa<CmpInst>(&I) || isa<SelectInst>(&I)) 126c4f7cc86SDavid Stuttard return true; 127c4f7cc86SDavid Stuttard 128c4f7cc86SDavid Stuttard if (auto *CI = dyn_cast<CallInst>(&I)) { 129c4f7cc86SDavid Stuttard auto *CalledFunc = CI->getCalledFunction(); 1305c9d82deSKazu Hirata if (CalledFunc && CalledFunc->getName().starts_with("should.remat")) 131c4f7cc86SDavid Stuttard return true; 132c4f7cc86SDavid Stuttard } 133c4f7cc86SDavid Stuttard 134c4f7cc86SDavid Stuttard return false; 135c4f7cc86SDavid Stuttard } 136c4f7cc86SDavid Stuttard 137c4f7cc86SDavid Stuttard TEST_F(ExtraRematTest, TestCoroRematDefault) { 138c4f7cc86SDavid Stuttard ParseAssembly(Text); 139c4f7cc86SDavid Stuttard 140c4f7cc86SDavid Stuttard ASSERT_TRUE(M); 141c4f7cc86SDavid Stuttard 142c4f7cc86SDavid Stuttard CGSCCPassManager CGPM; 143c4f7cc86SDavid Stuttard CGPM.addPass(CoroSplitPass()); 144c4f7cc86SDavid Stuttard MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); 145c4f7cc86SDavid Stuttard MPM.run(*M, MAM); 146c4f7cc86SDavid Stuttard 147c4f7cc86SDavid Stuttard // Verify that extra rematerializable instruction has been rematerialized 148c4f7cc86SDavid Stuttard Function *F = M->getFunction("f.resume"); 149c4f7cc86SDavid Stuttard ASSERT_TRUE(F) << "could not find split function f.resume"; 150c4f7cc86SDavid Stuttard 151c4f7cc86SDavid Stuttard BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); 152c4f7cc86SDavid Stuttard ASSERT_TRUE(Resume1) 153c4f7cc86SDavid Stuttard << "could not find expected BB resume1 in split function"; 154c4f7cc86SDavid Stuttard 155c4f7cc86SDavid Stuttard // With default materialization the intrinsic should not have been 156c4f7cc86SDavid Stuttard // rematerialized 157c4f7cc86SDavid Stuttard CallInst *CI = getCallByName(Resume1, "should.remat"); 158c4f7cc86SDavid Stuttard ASSERT_FALSE(CI); 159c4f7cc86SDavid Stuttard } 160c4f7cc86SDavid Stuttard 161c4f7cc86SDavid Stuttard TEST_F(ExtraRematTest, TestCoroRematWithCallback) { 162c4f7cc86SDavid Stuttard ParseAssembly(Text); 163c4f7cc86SDavid Stuttard 164c4f7cc86SDavid Stuttard ASSERT_TRUE(M); 165c4f7cc86SDavid Stuttard 166c4f7cc86SDavid Stuttard CGSCCPassManager CGPM; 167c4f7cc86SDavid Stuttard CGPM.addPass( 168c4f7cc86SDavid Stuttard CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable))); 169c4f7cc86SDavid Stuttard MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); 170c4f7cc86SDavid Stuttard MPM.run(*M, MAM); 171c4f7cc86SDavid Stuttard 172c4f7cc86SDavid Stuttard // Verify that extra rematerializable instruction has been rematerialized 173c4f7cc86SDavid Stuttard Function *F = M->getFunction("f.resume"); 174c4f7cc86SDavid Stuttard ASSERT_TRUE(F) << "could not find split function f.resume"; 175c4f7cc86SDavid Stuttard 176c4f7cc86SDavid Stuttard BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); 177c4f7cc86SDavid Stuttard ASSERT_TRUE(Resume1) 178c4f7cc86SDavid Stuttard << "could not find expected BB resume1 in split function"; 179c4f7cc86SDavid Stuttard 180c4f7cc86SDavid Stuttard // With callback the extra rematerialization of the function should have 181c4f7cc86SDavid Stuttard // happened 182c4f7cc86SDavid Stuttard CallInst *CI = getCallByName(Resume1, "should.remat"); 183c4f7cc86SDavid Stuttard ASSERT_TRUE(CI); 184c4f7cc86SDavid Stuttard } 1853737a532STyler Nowicki 1863737a532STyler Nowicki StringRef TextCoroBeginCustomABI = R"( 1873737a532STyler Nowicki define ptr @f(i32 %n) presplitcoroutine { 1883737a532STyler Nowicki entry: 1893737a532STyler Nowicki %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) 1903737a532STyler Nowicki %size = call i32 @llvm.coro.size.i32() 1913737a532STyler Nowicki %alloc = call ptr @malloc(i32 %size) 1923737a532STyler Nowicki %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0) 1933737a532STyler Nowicki 1943737a532STyler Nowicki %inc1 = add i32 %n, 1 1953737a532STyler Nowicki %val2 = call i32 @should.remat(i32 %inc1) 1963737a532STyler Nowicki %sp1 = call i8 @llvm.coro.suspend(token none, i1 false) 1973737a532STyler Nowicki switch i8 %sp1, label %suspend [i8 0, label %resume1 1983737a532STyler Nowicki i8 1, label %cleanup] 1993737a532STyler Nowicki resume1: 2003737a532STyler Nowicki %inc2 = add i32 %val2, 1 2013737a532STyler Nowicki %sp2 = call i8 @llvm.coro.suspend(token none, i1 false) 2023737a532STyler Nowicki switch i8 %sp1, label %suspend [i8 0, label %resume2 2033737a532STyler Nowicki i8 1, label %cleanup] 2043737a532STyler Nowicki 2053737a532STyler Nowicki resume2: 2063737a532STyler Nowicki call void @print(i32 %val2) 2073737a532STyler Nowicki call void @print(i32 %inc2) 2083737a532STyler Nowicki br label %cleanup 2093737a532STyler Nowicki 2103737a532STyler Nowicki cleanup: 2113737a532STyler Nowicki %mem = call ptr @llvm.coro.free(token %id, ptr %hdl) 2123737a532STyler Nowicki call void @free(ptr %mem) 2133737a532STyler Nowicki br label %suspend 2143737a532STyler Nowicki suspend: 2153737a532STyler Nowicki call i1 @llvm.coro.end(ptr %hdl, i1 0) 2163737a532STyler Nowicki ret ptr %hdl 2173737a532STyler Nowicki } 2183737a532STyler Nowicki 2193737a532STyler Nowicki declare ptr @llvm.coro.free(token, ptr) 2203737a532STyler Nowicki declare i32 @llvm.coro.size.i32() 2213737a532STyler Nowicki declare i8 @llvm.coro.suspend(token, i1) 2223737a532STyler Nowicki declare void @llvm.coro.resume(ptr) 2233737a532STyler Nowicki declare void @llvm.coro.destroy(ptr) 2243737a532STyler Nowicki 2253737a532STyler Nowicki declare token @llvm.coro.id(i32, ptr, ptr, ptr) 2263737a532STyler Nowicki declare i1 @llvm.coro.alloc(token) 2273737a532STyler Nowicki declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32) 2283737a532STyler Nowicki declare i1 @llvm.coro.end(ptr, i1) 2293737a532STyler Nowicki 2303737a532STyler Nowicki declare i32 @should.remat(i32) 2313737a532STyler Nowicki 2323737a532STyler Nowicki declare noalias ptr @malloc(i32) 2333737a532STyler Nowicki declare void @print(i32) 2343737a532STyler Nowicki declare void @free(ptr) 2353737a532STyler Nowicki )"; 2363737a532STyler Nowicki 2373737a532STyler Nowicki // SwitchABI with overridden isMaterializable 2383737a532STyler Nowicki class ExtraCustomABI : public coro::SwitchABI { 2393737a532STyler Nowicki public: 2403737a532STyler Nowicki ExtraCustomABI(Function &F, coro::Shape &S) 2413737a532STyler Nowicki : coro::SwitchABI(F, S, ExtraMaterializable) {} 2423737a532STyler Nowicki }; 2433737a532STyler Nowicki 2443737a532STyler Nowicki TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) { 2453737a532STyler Nowicki ParseAssembly(TextCoroBeginCustomABI); 2463737a532STyler Nowicki 2473737a532STyler Nowicki ASSERT_TRUE(M); 2483737a532STyler Nowicki 2493737a532STyler Nowicki CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) { 250*12526231STyler Nowicki return std::make_unique<ExtraCustomABI>(F, S); 2513737a532STyler Nowicki }; 2523737a532STyler Nowicki 2533737a532STyler Nowicki CGSCCPassManager CGPM; 2543737a532STyler Nowicki CGPM.addPass(CoroSplitPass({GenCustomABI})); 2553737a532STyler Nowicki MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); 2563737a532STyler Nowicki MPM.run(*M, MAM); 2573737a532STyler Nowicki 2583737a532STyler Nowicki // Verify that extra rematerializable instruction has been rematerialized 2593737a532STyler Nowicki Function *F = M->getFunction("f.resume"); 2603737a532STyler Nowicki ASSERT_TRUE(F) << "could not find split function f.resume"; 2613737a532STyler Nowicki 2623737a532STyler Nowicki BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); 2633737a532STyler Nowicki ASSERT_TRUE(Resume1) 2643737a532STyler Nowicki << "could not find expected BB resume1 in split function"; 2653737a532STyler Nowicki 2663737a532STyler Nowicki // With callback the extra rematerialization of the function should have 2673737a532STyler Nowicki // happened 2683737a532STyler Nowicki CallInst *CI = getCallByName(Resume1, "should.remat"); 2693737a532STyler Nowicki ASSERT_TRUE(CI); 2703737a532STyler Nowicki } 2713737a532STyler Nowicki 272c4f7cc86SDavid Stuttard } // namespace 273