xref: /llvm-project/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp (revision 125262312f366bd776b668b24026dbbc8e6b4c75)
1c4f7cc86SDavid Stuttard //===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
2c4f7cc86SDavid Stuttard //
3c4f7cc86SDavid Stuttard // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c4f7cc86SDavid Stuttard // See https://llvm.org/LICENSE.txt for license information.
5c4f7cc86SDavid Stuttard // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c4f7cc86SDavid Stuttard //
7c4f7cc86SDavid Stuttard //===----------------------------------------------------------------------===//
8c4f7cc86SDavid Stuttard 
9c4f7cc86SDavid Stuttard #include "llvm/AsmParser/Parser.h"
10c4f7cc86SDavid Stuttard #include "llvm/IR/Module.h"
11c4f7cc86SDavid Stuttard #include "llvm/Passes/PassBuilder.h"
12c4f7cc86SDavid Stuttard #include "llvm/Support/SourceMgr.h"
13c4f7cc86SDavid Stuttard #include "llvm/Testing/Support/Error.h"
14e82fcda1STyler Nowicki #include "llvm/Transforms/Coroutines/ABI.h"
15c4f7cc86SDavid Stuttard #include "llvm/Transforms/Coroutines/CoroSplit.h"
16c4f7cc86SDavid Stuttard #include "gtest/gtest.h"
17c4f7cc86SDavid Stuttard 
18c4f7cc86SDavid Stuttard using namespace llvm;
19c4f7cc86SDavid Stuttard 
20c4f7cc86SDavid Stuttard namespace {
21c4f7cc86SDavid Stuttard 
22c4f7cc86SDavid Stuttard struct ExtraRematTest : public testing::Test {
23c4f7cc86SDavid Stuttard   LLVMContext Ctx;
24c4f7cc86SDavid Stuttard   ModulePassManager MPM;
25c4f7cc86SDavid Stuttard   PassBuilder PB;
26c4f7cc86SDavid Stuttard   LoopAnalysisManager LAM;
27c4f7cc86SDavid Stuttard   FunctionAnalysisManager FAM;
28c4f7cc86SDavid Stuttard   CGSCCAnalysisManager CGAM;
29c4f7cc86SDavid Stuttard   ModuleAnalysisManager MAM;
30c4f7cc86SDavid Stuttard   LLVMContext Context;
31c4f7cc86SDavid Stuttard   std::unique_ptr<Module> M;
32c4f7cc86SDavid Stuttard 
33c4f7cc86SDavid Stuttard   ExtraRematTest() {
34c4f7cc86SDavid Stuttard     PB.registerModuleAnalyses(MAM);
35c4f7cc86SDavid Stuttard     PB.registerCGSCCAnalyses(CGAM);
36c4f7cc86SDavid Stuttard     PB.registerFunctionAnalyses(FAM);
37c4f7cc86SDavid Stuttard     PB.registerLoopAnalyses(LAM);
38c4f7cc86SDavid Stuttard     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
39c4f7cc86SDavid Stuttard   }
40c4f7cc86SDavid Stuttard 
41c4f7cc86SDavid Stuttard   BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const {
42c4f7cc86SDavid Stuttard     for (BasicBlock &BB : *F) {
43c4f7cc86SDavid Stuttard       if (BB.getName() == Name)
44c4f7cc86SDavid Stuttard         return &BB;
45c4f7cc86SDavid Stuttard     }
46c4f7cc86SDavid Stuttard     return nullptr;
47c4f7cc86SDavid Stuttard   }
48c4f7cc86SDavid Stuttard 
49c4f7cc86SDavid Stuttard   CallInst *getCallByName(BasicBlock *BB, StringRef Name) const {
50c4f7cc86SDavid Stuttard     for (Instruction &I : *BB) {
51c4f7cc86SDavid Stuttard       if (CallInst *CI = dyn_cast<CallInst>(&I))
52c4f7cc86SDavid Stuttard         if (CI->getCalledFunction()->getName() == Name)
53c4f7cc86SDavid Stuttard           return CI;
54c4f7cc86SDavid Stuttard     }
55c4f7cc86SDavid Stuttard     return nullptr;
56c4f7cc86SDavid Stuttard   }
57c4f7cc86SDavid Stuttard 
58c4f7cc86SDavid Stuttard   void ParseAssembly(const StringRef IR) {
59c4f7cc86SDavid Stuttard     SMDiagnostic Error;
60c4f7cc86SDavid Stuttard     M = parseAssemblyString(IR, Error, Context);
61c4f7cc86SDavid Stuttard     std::string errMsg;
62c4f7cc86SDavid Stuttard     raw_string_ostream os(errMsg);
63c4f7cc86SDavid Stuttard     Error.print("", os);
64c4f7cc86SDavid Stuttard 
65c4f7cc86SDavid Stuttard     // A failure here means that the test itself is buggy.
66c4f7cc86SDavid Stuttard     if (!M)
6752b48a70SJOE1994       report_fatal_error(errMsg.c_str());
68c4f7cc86SDavid Stuttard   }
69c4f7cc86SDavid Stuttard };
70c4f7cc86SDavid Stuttard 
71c4f7cc86SDavid Stuttard StringRef Text = R"(
72c4f7cc86SDavid Stuttard     define ptr @f(i32 %n) presplitcoroutine {
73c4f7cc86SDavid Stuttard     entry:
74c4f7cc86SDavid Stuttard       %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
75c4f7cc86SDavid Stuttard       %size = call i32 @llvm.coro.size.i32()
76c4f7cc86SDavid Stuttard       %alloc = call ptr @malloc(i32 %size)
77c4f7cc86SDavid Stuttard       %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
78c4f7cc86SDavid Stuttard 
79c4f7cc86SDavid Stuttard       %inc1 = add i32 %n, 1
80c4f7cc86SDavid Stuttard       %val2 = call i32 @should.remat(i32 %inc1)
81c4f7cc86SDavid Stuttard       %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
82c4f7cc86SDavid Stuttard       switch i8 %sp1, label %suspend [i8 0, label %resume1
83c4f7cc86SDavid Stuttard                                       i8 1, label %cleanup]
84c4f7cc86SDavid Stuttard     resume1:
85c4f7cc86SDavid Stuttard       %inc2 = add i32 %val2, 1
86c4f7cc86SDavid Stuttard       %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
87c4f7cc86SDavid Stuttard       switch i8 %sp1, label %suspend [i8 0, label %resume2
88c4f7cc86SDavid Stuttard                                       i8 1, label %cleanup]
89c4f7cc86SDavid Stuttard 
90c4f7cc86SDavid Stuttard     resume2:
91c4f7cc86SDavid Stuttard       call void @print(i32 %val2)
92c4f7cc86SDavid Stuttard       call void @print(i32 %inc2)
93c4f7cc86SDavid Stuttard       br label %cleanup
94c4f7cc86SDavid Stuttard 
95c4f7cc86SDavid Stuttard     cleanup:
96c4f7cc86SDavid Stuttard       %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
97c4f7cc86SDavid Stuttard       call void @free(ptr %mem)
98c4f7cc86SDavid Stuttard       br label %suspend
99c4f7cc86SDavid Stuttard     suspend:
100c4f7cc86SDavid Stuttard       call i1 @llvm.coro.end(ptr %hdl, i1 0)
101c4f7cc86SDavid Stuttard       ret ptr %hdl
102c4f7cc86SDavid Stuttard     }
103c4f7cc86SDavid Stuttard 
104c4f7cc86SDavid Stuttard     declare ptr @llvm.coro.free(token, ptr)
105c4f7cc86SDavid Stuttard     declare i32 @llvm.coro.size.i32()
106c4f7cc86SDavid Stuttard     declare i8  @llvm.coro.suspend(token, i1)
107c4f7cc86SDavid Stuttard     declare void @llvm.coro.resume(ptr)
108c4f7cc86SDavid Stuttard     declare void @llvm.coro.destroy(ptr)
109c4f7cc86SDavid Stuttard 
110c4f7cc86SDavid Stuttard     declare token @llvm.coro.id(i32, ptr, ptr, ptr)
111c4f7cc86SDavid Stuttard     declare i1 @llvm.coro.alloc(token)
112c4f7cc86SDavid Stuttard     declare ptr @llvm.coro.begin(token, ptr)
113c4f7cc86SDavid Stuttard     declare i1 @llvm.coro.end(ptr, i1)
114c4f7cc86SDavid Stuttard 
115c4f7cc86SDavid Stuttard     declare i32 @should.remat(i32)
116c4f7cc86SDavid Stuttard 
117c4f7cc86SDavid Stuttard     declare noalias ptr @malloc(i32)
118c4f7cc86SDavid Stuttard     declare void @print(i32)
119c4f7cc86SDavid Stuttard     declare void @free(ptr)
120c4f7cc86SDavid Stuttard   )";
121c4f7cc86SDavid Stuttard 
122c4f7cc86SDavid Stuttard // Materializable callback with extra rematerialization
123c4f7cc86SDavid Stuttard bool ExtraMaterializable(Instruction &I) {
124c4f7cc86SDavid Stuttard   if (isa<CastInst>(&I) || isa<GetElementPtrInst>(&I) ||
125c4f7cc86SDavid Stuttard       isa<BinaryOperator>(&I) || isa<CmpInst>(&I) || isa<SelectInst>(&I))
126c4f7cc86SDavid Stuttard     return true;
127c4f7cc86SDavid Stuttard 
128c4f7cc86SDavid Stuttard   if (auto *CI = dyn_cast<CallInst>(&I)) {
129c4f7cc86SDavid Stuttard     auto *CalledFunc = CI->getCalledFunction();
1305c9d82deSKazu Hirata     if (CalledFunc && CalledFunc->getName().starts_with("should.remat"))
131c4f7cc86SDavid Stuttard       return true;
132c4f7cc86SDavid Stuttard   }
133c4f7cc86SDavid Stuttard 
134c4f7cc86SDavid Stuttard   return false;
135c4f7cc86SDavid Stuttard }
136c4f7cc86SDavid Stuttard 
137c4f7cc86SDavid Stuttard TEST_F(ExtraRematTest, TestCoroRematDefault) {
138c4f7cc86SDavid Stuttard   ParseAssembly(Text);
139c4f7cc86SDavid Stuttard 
140c4f7cc86SDavid Stuttard   ASSERT_TRUE(M);
141c4f7cc86SDavid Stuttard 
142c4f7cc86SDavid Stuttard   CGSCCPassManager CGPM;
143c4f7cc86SDavid Stuttard   CGPM.addPass(CoroSplitPass());
144c4f7cc86SDavid Stuttard   MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
145c4f7cc86SDavid Stuttard   MPM.run(*M, MAM);
146c4f7cc86SDavid Stuttard 
147c4f7cc86SDavid Stuttard   // Verify that extra rematerializable instruction has been rematerialized
148c4f7cc86SDavid Stuttard   Function *F = M->getFunction("f.resume");
149c4f7cc86SDavid Stuttard   ASSERT_TRUE(F) << "could not find split function f.resume";
150c4f7cc86SDavid Stuttard 
151c4f7cc86SDavid Stuttard   BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
152c4f7cc86SDavid Stuttard   ASSERT_TRUE(Resume1)
153c4f7cc86SDavid Stuttard       << "could not find expected BB resume1 in split function";
154c4f7cc86SDavid Stuttard 
155c4f7cc86SDavid Stuttard   // With default materialization the intrinsic should not have been
156c4f7cc86SDavid Stuttard   // rematerialized
157c4f7cc86SDavid Stuttard   CallInst *CI = getCallByName(Resume1, "should.remat");
158c4f7cc86SDavid Stuttard   ASSERT_FALSE(CI);
159c4f7cc86SDavid Stuttard }
160c4f7cc86SDavid Stuttard 
161c4f7cc86SDavid Stuttard TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
162c4f7cc86SDavid Stuttard   ParseAssembly(Text);
163c4f7cc86SDavid Stuttard 
164c4f7cc86SDavid Stuttard   ASSERT_TRUE(M);
165c4f7cc86SDavid Stuttard 
166c4f7cc86SDavid Stuttard   CGSCCPassManager CGPM;
167c4f7cc86SDavid Stuttard   CGPM.addPass(
168c4f7cc86SDavid Stuttard       CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable)));
169c4f7cc86SDavid Stuttard   MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
170c4f7cc86SDavid Stuttard   MPM.run(*M, MAM);
171c4f7cc86SDavid Stuttard 
172c4f7cc86SDavid Stuttard   // Verify that extra rematerializable instruction has been rematerialized
173c4f7cc86SDavid Stuttard   Function *F = M->getFunction("f.resume");
174c4f7cc86SDavid Stuttard   ASSERT_TRUE(F) << "could not find split function f.resume";
175c4f7cc86SDavid Stuttard 
176c4f7cc86SDavid Stuttard   BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
177c4f7cc86SDavid Stuttard   ASSERT_TRUE(Resume1)
178c4f7cc86SDavid Stuttard       << "could not find expected BB resume1 in split function";
179c4f7cc86SDavid Stuttard 
180c4f7cc86SDavid Stuttard   // With callback the extra rematerialization of the function should have
181c4f7cc86SDavid Stuttard   // happened
182c4f7cc86SDavid Stuttard   CallInst *CI = getCallByName(Resume1, "should.remat");
183c4f7cc86SDavid Stuttard   ASSERT_TRUE(CI);
184c4f7cc86SDavid Stuttard }
1853737a532STyler Nowicki 
1863737a532STyler Nowicki StringRef TextCoroBeginCustomABI = R"(
1873737a532STyler Nowicki     define ptr @f(i32 %n) presplitcoroutine {
1883737a532STyler Nowicki     entry:
1893737a532STyler Nowicki       %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
1903737a532STyler Nowicki       %size = call i32 @llvm.coro.size.i32()
1913737a532STyler Nowicki       %alloc = call ptr @malloc(i32 %size)
1923737a532STyler Nowicki       %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0)
1933737a532STyler Nowicki 
1943737a532STyler Nowicki       %inc1 = add i32 %n, 1
1953737a532STyler Nowicki       %val2 = call i32 @should.remat(i32 %inc1)
1963737a532STyler Nowicki       %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
1973737a532STyler Nowicki       switch i8 %sp1, label %suspend [i8 0, label %resume1
1983737a532STyler Nowicki                                       i8 1, label %cleanup]
1993737a532STyler Nowicki     resume1:
2003737a532STyler Nowicki       %inc2 = add i32 %val2, 1
2013737a532STyler Nowicki       %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
2023737a532STyler Nowicki       switch i8 %sp1, label %suspend [i8 0, label %resume2
2033737a532STyler Nowicki                                       i8 1, label %cleanup]
2043737a532STyler Nowicki 
2053737a532STyler Nowicki     resume2:
2063737a532STyler Nowicki       call void @print(i32 %val2)
2073737a532STyler Nowicki       call void @print(i32 %inc2)
2083737a532STyler Nowicki       br label %cleanup
2093737a532STyler Nowicki 
2103737a532STyler Nowicki     cleanup:
2113737a532STyler Nowicki       %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
2123737a532STyler Nowicki       call void @free(ptr %mem)
2133737a532STyler Nowicki       br label %suspend
2143737a532STyler Nowicki     suspend:
2153737a532STyler Nowicki       call i1 @llvm.coro.end(ptr %hdl, i1 0)
2163737a532STyler Nowicki       ret ptr %hdl
2173737a532STyler Nowicki     }
2183737a532STyler Nowicki 
2193737a532STyler Nowicki     declare ptr @llvm.coro.free(token, ptr)
2203737a532STyler Nowicki     declare i32 @llvm.coro.size.i32()
2213737a532STyler Nowicki     declare i8  @llvm.coro.suspend(token, i1)
2223737a532STyler Nowicki     declare void @llvm.coro.resume(ptr)
2233737a532STyler Nowicki     declare void @llvm.coro.destroy(ptr)
2243737a532STyler Nowicki 
2253737a532STyler Nowicki     declare token @llvm.coro.id(i32, ptr, ptr, ptr)
2263737a532STyler Nowicki     declare i1 @llvm.coro.alloc(token)
2273737a532STyler Nowicki     declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32)
2283737a532STyler Nowicki     declare i1 @llvm.coro.end(ptr, i1)
2293737a532STyler Nowicki 
2303737a532STyler Nowicki     declare i32 @should.remat(i32)
2313737a532STyler Nowicki 
2323737a532STyler Nowicki     declare noalias ptr @malloc(i32)
2333737a532STyler Nowicki     declare void @print(i32)
2343737a532STyler Nowicki     declare void @free(ptr)
2353737a532STyler Nowicki   )";
2363737a532STyler Nowicki 
2373737a532STyler Nowicki // SwitchABI with overridden isMaterializable
2383737a532STyler Nowicki class ExtraCustomABI : public coro::SwitchABI {
2393737a532STyler Nowicki public:
2403737a532STyler Nowicki   ExtraCustomABI(Function &F, coro::Shape &S)
2413737a532STyler Nowicki       : coro::SwitchABI(F, S, ExtraMaterializable) {}
2423737a532STyler Nowicki };
2433737a532STyler Nowicki 
2443737a532STyler Nowicki TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) {
2453737a532STyler Nowicki   ParseAssembly(TextCoroBeginCustomABI);
2463737a532STyler Nowicki 
2473737a532STyler Nowicki   ASSERT_TRUE(M);
2483737a532STyler Nowicki 
2493737a532STyler Nowicki   CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) {
250*12526231STyler Nowicki     return std::make_unique<ExtraCustomABI>(F, S);
2513737a532STyler Nowicki   };
2523737a532STyler Nowicki 
2533737a532STyler Nowicki   CGSCCPassManager CGPM;
2543737a532STyler Nowicki   CGPM.addPass(CoroSplitPass({GenCustomABI}));
2553737a532STyler Nowicki   MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
2563737a532STyler Nowicki   MPM.run(*M, MAM);
2573737a532STyler Nowicki 
2583737a532STyler Nowicki   // Verify that extra rematerializable instruction has been rematerialized
2593737a532STyler Nowicki   Function *F = M->getFunction("f.resume");
2603737a532STyler Nowicki   ASSERT_TRUE(F) << "could not find split function f.resume";
2613737a532STyler Nowicki 
2623737a532STyler Nowicki   BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
2633737a532STyler Nowicki   ASSERT_TRUE(Resume1)
2643737a532STyler Nowicki       << "could not find expected BB resume1 in split function";
2653737a532STyler Nowicki 
2663737a532STyler Nowicki   // With callback the extra rematerialization of the function should have
2673737a532STyler Nowicki   // happened
2683737a532STyler Nowicki   CallInst *CI = getCallByName(Resume1, "should.remat");
2693737a532STyler Nowicki   ASSERT_TRUE(CI);
2703737a532STyler Nowicki }
2713737a532STyler Nowicki 
272c4f7cc86SDavid Stuttard } // namespace
273