xref: /llvm-project/llvm/unittests/IR/PassManagerTest.cpp (revision 6b9816477b6bbf08f74e1188bc44bbb2942c3503)
1 //===- llvm/unittest/IR/PassManager.cpp - PassManager tests ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"
17 
18 using namespace llvm;
19 
20 namespace {
21 
22 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
23 public:
24   struct Result {
25     Result(int Count) : InstructionCount(Count) {}
26     int InstructionCount;
27   };
28 
29   TestFunctionAnalysis(int &Runs) : Runs(Runs) {}
30 
31   /// \brief Run the analysis pass over the function and return a result.
32   Result run(Function &F, FunctionAnalysisManager &AM) {
33     ++Runs;
34     int Count = 0;
35     for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI)
36       for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE;
37            ++II)
38         ++Count;
39     return Result(Count);
40   }
41 
42 private:
43   friend AnalysisInfoMixin<TestFunctionAnalysis>;
44   static AnalysisKey Key;
45 
46   int &Runs;
47 };
48 
49 AnalysisKey TestFunctionAnalysis::Key;
50 
51 class TestModuleAnalysis : public AnalysisInfoMixin<TestModuleAnalysis> {
52 public:
53   struct Result {
54     Result(int Count) : FunctionCount(Count) {}
55     int FunctionCount;
56   };
57 
58   TestModuleAnalysis(int &Runs) : Runs(Runs) {}
59 
60   Result run(Module &M, ModuleAnalysisManager &AM) {
61     ++Runs;
62     int Count = 0;
63     for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
64       ++Count;
65     return Result(Count);
66   }
67 
68 private:
69   friend AnalysisInfoMixin<TestModuleAnalysis>;
70   static AnalysisKey Key;
71 
72   int &Runs;
73 };
74 
75 AnalysisKey TestModuleAnalysis::Key;
76 
77 struct TestModulePass : PassInfoMixin<TestModulePass> {
78   TestModulePass(int &RunCount) : RunCount(RunCount) {}
79 
80   PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
81     ++RunCount;
82     return PreservedAnalyses::none();
83   }
84 
85   int &RunCount;
86 };
87 
88 struct TestPreservingModulePass : PassInfoMixin<TestPreservingModulePass> {
89   PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
90     return PreservedAnalyses::all();
91   }
92 };
93 
94 struct TestFunctionPass : PassInfoMixin<TestFunctionPass> {
95   TestFunctionPass(int &RunCount, int &AnalyzedInstrCount,
96                    int &AnalyzedFunctionCount,
97                    bool OnlyUseCachedResults = false)
98       : RunCount(RunCount), AnalyzedInstrCount(AnalyzedInstrCount),
99         AnalyzedFunctionCount(AnalyzedFunctionCount),
100         OnlyUseCachedResults(OnlyUseCachedResults) {}
101 
102   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
103     ++RunCount;
104 
105     const ModuleAnalysisManager &MAM =
106         AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager();
107     if (TestModuleAnalysis::Result *TMA =
108             MAM.getCachedResult<TestModuleAnalysis>(*F.getParent()))
109       AnalyzedFunctionCount += TMA->FunctionCount;
110 
111     if (OnlyUseCachedResults) {
112       // Hack to force the use of the cached interface.
113       if (TestFunctionAnalysis::Result *AR =
114               AM.getCachedResult<TestFunctionAnalysis>(F))
115         AnalyzedInstrCount += AR->InstructionCount;
116     } else {
117       // Typical path just runs the analysis as needed.
118       TestFunctionAnalysis::Result &AR = AM.getResult<TestFunctionAnalysis>(F);
119       AnalyzedInstrCount += AR.InstructionCount;
120     }
121 
122     return PreservedAnalyses::all();
123   }
124 
125   int &RunCount;
126   int &AnalyzedInstrCount;
127   int &AnalyzedFunctionCount;
128   bool OnlyUseCachedResults;
129 };
130 
131 // A test function pass that invalidates all function analyses for a function
132 // with a specific name.
133 struct TestInvalidationFunctionPass
134     : PassInfoMixin<TestInvalidationFunctionPass> {
135   TestInvalidationFunctionPass(StringRef FunctionName) : Name(FunctionName) {}
136 
137   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
138     return F.getName() == Name ? PreservedAnalyses::none()
139                                : PreservedAnalyses::all();
140   }
141 
142   StringRef Name;
143 };
144 
145 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
146   SMDiagnostic Err;
147   return parseAssemblyString(IR, Err, Context);
148 }
149 
150 class PassManagerTest : public ::testing::Test {
151 protected:
152   LLVMContext Context;
153   std::unique_ptr<Module> M;
154 
155 public:
156   PassManagerTest()
157       : M(parseIR(Context, "define void @f() {\n"
158                            "entry:\n"
159                            "  call void @g()\n"
160                            "  call void @h()\n"
161                            "  ret void\n"
162                            "}\n"
163                            "define void @g() {\n"
164                            "  ret void\n"
165                            "}\n"
166                            "define void @h() {\n"
167                            "  ret void\n"
168                            "}\n")) {}
169 };
170 
// Exercise PreservedAnalyses construction, copying, moving, per-analysis
// preservation, and intersection.
TEST_F(PassManagerTest, BasicPreservedAnalyses) {
  // A default-constructed set and none() preserve nothing; all() preserves
  // everything.
  PreservedAnalyses PA1 = PreservedAnalyses();
  EXPECT_FALSE(PA1.preserved<TestFunctionAnalysis>());
  EXPECT_FALSE(PA1.preserved<TestModuleAnalysis>());
  PreservedAnalyses PA2 = PreservedAnalyses::none();
  EXPECT_FALSE(PA2.preserved<TestFunctionAnalysis>());
  EXPECT_FALSE(PA2.preserved<TestModuleAnalysis>());
  PreservedAnalyses PA3 = PreservedAnalyses::all();
  EXPECT_TRUE(PA3.preserved<TestFunctionAnalysis>());
  EXPECT_TRUE(PA3.preserved<TestModuleAnalysis>());
  // Copy construction and copy assignment carry the preserved set over.
  PreservedAnalyses PA4 = PA1;
  EXPECT_FALSE(PA4.preserved<TestFunctionAnalysis>());
  EXPECT_FALSE(PA4.preserved<TestModuleAnalysis>());
  PA4 = PA3;
  EXPECT_TRUE(PA4.preserved<TestFunctionAnalysis>());
  EXPECT_TRUE(PA4.preserved<TestModuleAnalysis>());
  // Move assignment replaces the previous "all" state with PA2's "none".
  PA4 = std::move(PA2);
  EXPECT_FALSE(PA4.preserved<TestFunctionAnalysis>());
  EXPECT_FALSE(PA4.preserved<TestModuleAnalysis>());
  // preserve<T>() marks only that one analysis as preserved.
  PA4.preserve<TestFunctionAnalysis>();
  EXPECT_TRUE(PA4.preserved<TestFunctionAnalysis>());
  EXPECT_FALSE(PA4.preserved<TestModuleAnalysis>());
  PA1.preserve<TestModuleAnalysis>();
  EXPECT_FALSE(PA1.preserved<TestFunctionAnalysis>());
  EXPECT_TRUE(PA1.preserved<TestModuleAnalysis>());
  PA1.preserve<TestFunctionAnalysis>();
  EXPECT_TRUE(PA1.preserved<TestFunctionAnalysis>());
  EXPECT_TRUE(PA1.preserved<TestModuleAnalysis>());
  // intersect() keeps only what both sets preserve. PA4 preserves just the
  // function analysis, so the module analysis drops out of PA1.
  PA1.intersect(PA4);
  EXPECT_TRUE(PA1.preserved<TestFunctionAnalysis>());
  EXPECT_FALSE(PA1.preserved<TestModuleAnalysis>());
}
203 
// Drive a module pipeline containing five function pass managers (each
// wrapped in a module-to-function adaptor) interleaved with module passes,
// and check how many times each pass and analysis actually runs.
TEST_F(PassManagerTest, Basic) {
  FunctionAnalysisManager FAM(/*DebugLogging*/ true);
  int FunctionAnalysisRuns = 0;
  FAM.registerPass([&] { return TestFunctionAnalysis(FunctionAnalysisRuns); });

  ModuleAnalysisManager MAM(/*DebugLogging*/ true);
  int ModuleAnalysisRuns = 0;
  MAM.registerPass([&] { return TestModuleAnalysis(ModuleAnalysisRuns); });
  // Cross-register the proxies so each manager can reach the other.
  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });

  ModulePassManager MPM;

  // Count the runs over a Function.
  int FunctionPassRunCount1 = 0;
  int AnalyzedInstrCount1 = 0;
  int AnalyzedFunctionCount1 = 0;
  {
    // Pointless scope to test move assignment of a ModulePassManager.
    ModulePassManager NestedMPM(/*DebugLogging*/ true);
    FunctionPassManager FPM;
    {
      // Pointless scope to test move assignment of a FunctionPassManager.
      FunctionPassManager NestedFPM(/*DebugLogging*/ true);
      NestedFPM.addPass(TestFunctionPass(
          FunctionPassRunCount1, AnalyzedInstrCount1, AnalyzedFunctionCount1));
      FPM = std::move(NestedFPM);
    }
    NestedMPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
    MPM = std::move(NestedMPM);
  }

  // Count the runs over a module. This pass preserves nothing, so the
  // function analyses cached by the first manager are invalidated.
  int ModulePassRunCount = 0;
  MPM.addPass(TestModulePass(ModulePassRunCount));

  // Count the runs over a Function in a separate manager.
  int FunctionPassRunCount2 = 0;
  int AnalyzedInstrCount2 = 0;
  int AnalyzedFunctionCount2 = 0;
  {
    FunctionPassManager FPM(/*DebugLogging*/ true);
    FPM.addPass(TestFunctionPass(FunctionPassRunCount2, AnalyzedInstrCount2,
                                 AnalyzedFunctionCount2));
    MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
  }

  // A third function pass manager, preceded only by a preserving module pass.
  // Its second pass invalidates all function analyses of exactly one
  // function (@f).
  MPM.addPass(TestPreservingModulePass());
  int FunctionPassRunCount3 = 0;
  int AnalyzedInstrCount3 = 0;
  int AnalyzedFunctionCount3 = 0;
  {
    FunctionPassManager FPM(/*DebugLogging*/ true);
    FPM.addPass(TestFunctionPass(FunctionPassRunCount3, AnalyzedInstrCount3,
                                 AnalyzedFunctionCount3));
    FPM.addPass(TestInvalidationFunctionPass("f"));
    MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
  }

  // A fourth function pass manager, again with only preserving intervening
  // passes. RequireAnalysisPass computes the module analysis first, so
  // TestFunctionPass can now find it in the cache.
  MPM.addPass(RequireAnalysisPass<TestModuleAnalysis, Module>());
  int FunctionPassRunCount4 = 0;
  int AnalyzedInstrCount4 = 0;
  int AnalyzedFunctionCount4 = 0;
  {
    FunctionPassManager FPM;
    FPM.addPass(TestFunctionPass(FunctionPassRunCount4, AnalyzedInstrCount4,
                                 AnalyzedFunctionCount4));
    MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
  }

  // A fifth function pass manager which invalidates one function first but
  // uses only cached results.
  int FunctionPassRunCount5 = 0;
  int AnalyzedInstrCount5 = 0;
  int AnalyzedFunctionCount5 = 0;
  {
    FunctionPassManager FPM(/*DebugLogging*/ true);
    FPM.addPass(TestInvalidationFunctionPass("f"));
    FPM.addPass(TestFunctionPass(FunctionPassRunCount5, AnalyzedInstrCount5,
                                 AnalyzedFunctionCount5,
                                 /*OnlyUseCachedResults=*/true));
    MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
  }

  MPM.run(*M, MAM);

  // Validate module pass counters.
  EXPECT_EQ(1, ModulePassRunCount);

  // Validate the function pass counters. Every manager visits all three
  // functions. @f has three instructions and @g and @h one each, so a full
  // instruction count is 5. The module analysis (3 functions) is only
  // visible once RequireAnalysisPass has cached it, hence 3 * 3 = 9 for the
  // fourth and fifth managers and 0 before that.
  EXPECT_EQ(3, FunctionPassRunCount1);
  EXPECT_EQ(5, AnalyzedInstrCount1);
  EXPECT_EQ(0, AnalyzedFunctionCount1);
  EXPECT_EQ(3, FunctionPassRunCount2);
  EXPECT_EQ(5, AnalyzedInstrCount2);
  EXPECT_EQ(0, AnalyzedFunctionCount2);
  EXPECT_EQ(3, FunctionPassRunCount3);
  EXPECT_EQ(5, AnalyzedInstrCount3);
  EXPECT_EQ(0, AnalyzedFunctionCount3);
  EXPECT_EQ(3, FunctionPassRunCount4);
  EXPECT_EQ(5, AnalyzedInstrCount4);
  EXPECT_EQ(9, AnalyzedFunctionCount4);
  EXPECT_EQ(3, FunctionPassRunCount5);
  EXPECT_EQ(2, AnalyzedInstrCount5); // Only 'g' and 'h' were cached.
  EXPECT_EQ(9, AnalyzedFunctionCount5);

  // Validate the analysis counters:
  //   first run over 3 functions, then module pass invalidates
  //   second run over 3 functions, nothing invalidates
  //   third run over 0 functions, but 1 function invalidated
  //   fourth run over 1 function
  //   fifth run invalidates 1 function first, but runs over 0 functions
  EXPECT_EQ(7, FunctionAnalysisRuns);

  // Only RequireAnalysisPass computes the module analysis; TestFunctionPass
  // only ever reads it from the cache.
  EXPECT_EQ(1, ModuleAnalysisRuns);
}
324 
325 // A customized pass manager that passes extra arguments through the
326 // infrastructure.
327 typedef AnalysisManager<Function, int> CustomizedAnalysisManager;
328 typedef PassManager<Function, CustomizedAnalysisManager, int, int &>
329     CustomizedPassManager;
330 
331 class CustomizedAnalysis : public AnalysisInfoMixin<CustomizedAnalysis> {
332 public:
333   struct Result {
334     Result(int I) : I(I) {}
335     int I;
336   };
337 
338   Result run(Function &F, CustomizedAnalysisManager &AM, int I) {
339     return Result(I);
340   }
341 
342 private:
343   friend AnalysisInfoMixin<CustomizedAnalysis>;
344   static AnalysisKey Key;
345 };
346 
347 AnalysisKey CustomizedAnalysis::Key;
348 
349 struct CustomizedPass : PassInfoMixin<CustomizedPass> {
350   std::function<void(CustomizedAnalysis::Result &, int &)> Callback;
351 
352   template <typename CallbackT>
353   CustomizedPass(CallbackT Callback) : Callback(Callback) {}
354 
355   PreservedAnalyses run(Function &F, CustomizedAnalysisManager &AM, int I,
356                         int &O) {
357     Callback(AM.getResult<CustomizedAnalysis>(F, I), O);
358     return PreservedAnalyses::none();
359   }
360 };
361 
362 TEST_F(PassManagerTest, CustomizedPassManagerArgs) {
363   CustomizedAnalysisManager AM;
364   AM.registerPass([&] { return CustomizedAnalysis(); });
365 
366   CustomizedPassManager PM;
367 
368   // Add an instance of the customized pass that just accumulates the input
369   // after it is round-tripped through the analysis.
370   int Result = 0;
371   PM.addPass(
372       CustomizedPass([](CustomizedAnalysis::Result &R, int &O) { O += R.I; }));
373 
374   // Run this over every function with the input of 42.
375   for (Function &F : *M)
376     PM.run(F, AM, 42, Result);
377 
378   // And ensure that we accumulated the correct result.
379   EXPECT_EQ(42 * (int)M->size(), Result);
380 }
381 
382 /// A test analysis pass which caches in its result another analysis pass and
383 /// uses it to serve queries. This requires the result to invalidate itself
384 /// when its dependency is invalidated.
385 struct TestIndirectFunctionAnalysis
386     : public AnalysisInfoMixin<TestIndirectFunctionAnalysis> {
387   struct Result {
388     Result(TestFunctionAnalysis::Result &Dep) : Dep(Dep) {}
389     TestFunctionAnalysis::Result &Dep;
390 
391     bool invalidate(Function &F, const PreservedAnalyses &PA,
392                     FunctionAnalysisManager::Invalidator &Inv) {
393       return !PA.preserved<TestIndirectFunctionAnalysis>() ||
394              Inv.invalidate<TestFunctionAnalysis>(F, PA);
395     }
396   };
397 
398   TestIndirectFunctionAnalysis(int &Runs) : Runs(Runs) {}
399 
400   /// Run the analysis pass over the function and return a result.
401   Result run(Function &F, FunctionAnalysisManager &AM) {
402     ++Runs;
403     return Result(AM.getResult<TestFunctionAnalysis>(F));
404   }
405 
406 private:
407   friend AnalysisInfoMixin<TestIndirectFunctionAnalysis>;
408   static AnalysisKey Key;
409 
410   int &Runs;
411 };
412 
413 AnalysisKey TestIndirectFunctionAnalysis::Key;
414 
415 struct LambdaPass : public PassInfoMixin<LambdaPass> {
416   using FuncT = std::function<PreservedAnalyses(Function &, FunctionAnalysisManager &)>;
417 
418   LambdaPass(FuncT Func) : Func(std::move(Func)) {}
419 
420   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
421     return Func(F, AM);
422   }
423 
424   FuncT Func;
425 };
426 
// Check that an analysis holding a reference to another analysis's result
// (TestIndirectFunctionAnalysis -> TestFunctionAnalysis) is invalidated when
// either it or its dependency is not preserved.
TEST_F(PassManagerTest, IndirectAnalysisInvalidation) {
  FunctionAnalysisManager FAM(/*DebugLogging*/ true);
  int AnalysisRuns = 0, IndirectAnalysisRuns = 0;
  FAM.registerPass([&] { return TestFunctionAnalysis(AnalysisRuns); });
  FAM.registerPass(
      [&] { return TestIndirectFunctionAnalysis(IndirectAnalysisRuns); });

  ModuleAnalysisManager MAM(/*DebugLogging*/ true);
  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });

  int InstrCount = 0;
  ModulePassManager MPM(/*DebugLogging*/ true);
  FunctionPassManager FPM(/*DebugLogging*/ true);
  // First just use the analysis to get the instruction count, and preserve
  // everything.
  FPM.addPass(LambdaPass([&](Function &F, FunctionAnalysisManager &AM) {
    InstrCount +=
        AM.getResult<TestIndirectFunctionAnalysis>(F).Dep.InstructionCount;
    return PreservedAnalyses::all();
  }));
  // Next, invalidate
  //   - both analyses for "f",
  //   - just the indirect analysis (TestIndirectFunctionAnalysis) for "g",
  //     while its underlying TestFunctionAnalysis is preserved, and
  //   - just the underlying TestFunctionAnalysis for "h". The indirect
  //     analysis is then also dropped, because its invalidate() hook reports
  //     that its dependency was invalidated.
  FPM.addPass(LambdaPass([&](Function &F, FunctionAnalysisManager &AM) {
    InstrCount +=
        AM.getResult<TestIndirectFunctionAnalysis>(F).Dep.InstructionCount;
    auto PA = PreservedAnalyses::none();
    if (F.getName() == "g")
      PA.preserve<TestFunctionAnalysis>();
    else if (F.getName() == "h")
      PA.preserve<TestIndirectFunctionAnalysis>();
    return PA;
  }));
  // Finally, use the analysis again on each function, forcing re-computation
  // for all of them.
  FPM.addPass(LambdaPass([&](Function &F, FunctionAnalysisManager &AM) {
    InstrCount +=
        AM.getResult<TestIndirectFunctionAnalysis>(F).Dep.InstructionCount;
    return PreservedAnalyses::all();
  }));
  MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
  MPM.run(*M, MAM);

  // There are generally two possible runs for each of the three functions. But
  // for one function ("g"), we only invalidate the indirect analysis, so the
  // base one only gets run five times.
  EXPECT_EQ(5, AnalysisRuns);
  // The indirect analysis is invalidated for each function (either directly or
  // indirectly) and run twice for each.
  EXPECT_EQ(6, IndirectAnalysisRuns);

  // There are five instructions in the module and we add the count three
  // times.
  EXPECT_EQ(5 * 3, InstrCount);
}
484 }
485