xref: /llvm-project/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp (revision 125262312f366bd776b668b24026dbbc8e6b4c75)
1 //===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/AsmParser/Parser.h"
10 #include "llvm/IR/Module.h"
11 #include "llvm/Passes/PassBuilder.h"
12 #include "llvm/Support/SourceMgr.h"
13 #include "llvm/Testing/Support/Error.h"
14 #include "llvm/Transforms/Coroutines/ABI.h"
15 #include "llvm/Transforms/Coroutines/CoroSplit.h"
16 #include "gtest/gtest.h"
17 
18 using namespace llvm;
19 
20 namespace {
21 
22 struct ExtraRematTest : public testing::Test {
23   LLVMContext Ctx;
24   ModulePassManager MPM;
25   PassBuilder PB;
26   LoopAnalysisManager LAM;
27   FunctionAnalysisManager FAM;
28   CGSCCAnalysisManager CGAM;
29   ModuleAnalysisManager MAM;
30   LLVMContext Context;
31   std::unique_ptr<Module> M;
32 
33   ExtraRematTest() {
34     PB.registerModuleAnalyses(MAM);
35     PB.registerCGSCCAnalyses(CGAM);
36     PB.registerFunctionAnalyses(FAM);
37     PB.registerLoopAnalyses(LAM);
38     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
39   }
40 
41   BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const {
42     for (BasicBlock &BB : *F) {
43       if (BB.getName() == Name)
44         return &BB;
45     }
46     return nullptr;
47   }
48 
49   CallInst *getCallByName(BasicBlock *BB, StringRef Name) const {
50     for (Instruction &I : *BB) {
51       if (CallInst *CI = dyn_cast<CallInst>(&I))
52         if (CI->getCalledFunction()->getName() == Name)
53           return CI;
54     }
55     return nullptr;
56   }
57 
58   void ParseAssembly(const StringRef IR) {
59     SMDiagnostic Error;
60     M = parseAssemblyString(IR, Error, Context);
61     std::string errMsg;
62     raw_string_ostream os(errMsg);
63     Error.print("", os);
64 
65     // A failure here means that the test itself is buggy.
66     if (!M)
67       report_fatal_error(errMsg.c_str());
68   }
69 };
70 
71 StringRef Text = R"(
72     define ptr @f(i32 %n) presplitcoroutine {
73     entry:
74       %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
75       %size = call i32 @llvm.coro.size.i32()
76       %alloc = call ptr @malloc(i32 %size)
77       %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
78 
79       %inc1 = add i32 %n, 1
80       %val2 = call i32 @should.remat(i32 %inc1)
81       %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
82       switch i8 %sp1, label %suspend [i8 0, label %resume1
83                                       i8 1, label %cleanup]
84     resume1:
85       %inc2 = add i32 %val2, 1
86       %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
87       switch i8 %sp1, label %suspend [i8 0, label %resume2
88                                       i8 1, label %cleanup]
89 
90     resume2:
91       call void @print(i32 %val2)
92       call void @print(i32 %inc2)
93       br label %cleanup
94 
95     cleanup:
96       %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
97       call void @free(ptr %mem)
98       br label %suspend
99     suspend:
100       call i1 @llvm.coro.end(ptr %hdl, i1 0)
101       ret ptr %hdl
102     }
103 
104     declare ptr @llvm.coro.free(token, ptr)
105     declare i32 @llvm.coro.size.i32()
106     declare i8  @llvm.coro.suspend(token, i1)
107     declare void @llvm.coro.resume(ptr)
108     declare void @llvm.coro.destroy(ptr)
109 
110     declare token @llvm.coro.id(i32, ptr, ptr, ptr)
111     declare i1 @llvm.coro.alloc(token)
112     declare ptr @llvm.coro.begin(token, ptr)
113     declare i1 @llvm.coro.end(ptr, i1)
114 
115     declare i32 @should.remat(i32)
116 
117     declare noalias ptr @malloc(i32)
118     declare void @print(i32)
119     declare void @free(ptr)
120   )";
121 
122 // Materializable callback with extra rematerialization
123 bool ExtraMaterializable(Instruction &I) {
124   if (isa<CastInst>(&I) || isa<GetElementPtrInst>(&I) ||
125       isa<BinaryOperator>(&I) || isa<CmpInst>(&I) || isa<SelectInst>(&I))
126     return true;
127 
128   if (auto *CI = dyn_cast<CallInst>(&I)) {
129     auto *CalledFunc = CI->getCalledFunction();
130     if (CalledFunc && CalledFunc->getName().starts_with("should.remat"))
131       return true;
132   }
133 
134   return false;
135 }
136 
137 TEST_F(ExtraRematTest, TestCoroRematDefault) {
138   ParseAssembly(Text);
139 
140   ASSERT_TRUE(M);
141 
142   CGSCCPassManager CGPM;
143   CGPM.addPass(CoroSplitPass());
144   MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
145   MPM.run(*M, MAM);
146 
147   // Verify that extra rematerializable instruction has been rematerialized
148   Function *F = M->getFunction("f.resume");
149   ASSERT_TRUE(F) << "could not find split function f.resume";
150 
151   BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
152   ASSERT_TRUE(Resume1)
153       << "could not find expected BB resume1 in split function";
154 
155   // With default materialization the intrinsic should not have been
156   // rematerialized
157   CallInst *CI = getCallByName(Resume1, "should.remat");
158   ASSERT_FALSE(CI);
159 }
160 
161 TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
162   ParseAssembly(Text);
163 
164   ASSERT_TRUE(M);
165 
166   CGSCCPassManager CGPM;
167   CGPM.addPass(
168       CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable)));
169   MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
170   MPM.run(*M, MAM);
171 
172   // Verify that extra rematerializable instruction has been rematerialized
173   Function *F = M->getFunction("f.resume");
174   ASSERT_TRUE(F) << "could not find split function f.resume";
175 
176   BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
177   ASSERT_TRUE(Resume1)
178       << "could not find expected BB resume1 in split function";
179 
180   // With callback the extra rematerialization of the function should have
181   // happened
182   CallInst *CI = getCallByName(Resume1, "should.remat");
183   ASSERT_TRUE(CI);
184 }
185 
186 StringRef TextCoroBeginCustomABI = R"(
187     define ptr @f(i32 %n) presplitcoroutine {
188     entry:
189       %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
190       %size = call i32 @llvm.coro.size.i32()
191       %alloc = call ptr @malloc(i32 %size)
192       %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0)
193 
194       %inc1 = add i32 %n, 1
195       %val2 = call i32 @should.remat(i32 %inc1)
196       %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
197       switch i8 %sp1, label %suspend [i8 0, label %resume1
198                                       i8 1, label %cleanup]
199     resume1:
200       %inc2 = add i32 %val2, 1
201       %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
202       switch i8 %sp1, label %suspend [i8 0, label %resume2
203                                       i8 1, label %cleanup]
204 
205     resume2:
206       call void @print(i32 %val2)
207       call void @print(i32 %inc2)
208       br label %cleanup
209 
210     cleanup:
211       %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
212       call void @free(ptr %mem)
213       br label %suspend
214     suspend:
215       call i1 @llvm.coro.end(ptr %hdl, i1 0)
216       ret ptr %hdl
217     }
218 
219     declare ptr @llvm.coro.free(token, ptr)
220     declare i32 @llvm.coro.size.i32()
221     declare i8  @llvm.coro.suspend(token, i1)
222     declare void @llvm.coro.resume(ptr)
223     declare void @llvm.coro.destroy(ptr)
224 
225     declare token @llvm.coro.id(i32, ptr, ptr, ptr)
226     declare i1 @llvm.coro.alloc(token)
227     declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32)
228     declare i1 @llvm.coro.end(ptr, i1)
229 
230     declare i32 @should.remat(i32)
231 
232     declare noalias ptr @malloc(i32)
233     declare void @print(i32)
234     declare void @free(ptr)
235   )";
236 
237 // SwitchABI with overridden isMaterializable
238 class ExtraCustomABI : public coro::SwitchABI {
239 public:
240   ExtraCustomABI(Function &F, coro::Shape &S)
241       : coro::SwitchABI(F, S, ExtraMaterializable) {}
242 };
243 
244 TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) {
245   ParseAssembly(TextCoroBeginCustomABI);
246 
247   ASSERT_TRUE(M);
248 
249   CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) {
250     return std::make_unique<ExtraCustomABI>(F, S);
251   };
252 
253   CGSCCPassManager CGPM;
254   CGPM.addPass(CoroSplitPass({GenCustomABI}));
255   MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
256   MPM.run(*M, MAM);
257 
258   // Verify that extra rematerializable instruction has been rematerialized
259   Function *F = M->getFunction("f.resume");
260   ASSERT_TRUE(F) << "could not find split function f.resume";
261 
262   BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
263   ASSERT_TRUE(Resume1)
264       << "could not find expected BB resume1 in split function";
265 
266   // With callback the extra rematerialization of the function should have
267   // happened
268   CallInst *CI = getCallByName(Resume1, "should.remat");
269   ASSERT_TRUE(CI);
270 }
271 
272 } // namespace
273