xref: /llvm-project/llvm/unittests/Analysis/SparsePropagation.cpp (revision 056a3f4673a4f88d89e9bf00614355f671014ca5)
1 //===- SparsePropagation.cpp - Unit tests for the generic solver ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/SparsePropagation.h"
10 #include "llvm/ADT/PointerIntPair.h"
11 #include "llvm/IR/IRBuilder.h"
12 #include "llvm/IR/Module.h"
13 #include "gtest/gtest.h"
14 using namespace llvm;
15 
16 namespace {
/// Groupings that LLVM values are assigned to so the test analysis can work
/// interprocedurally. An LLVM Value can technically belong to more than one
/// grouping; keeping the groupings distinct lets us, for example, track a
/// global variable separately from the value stored at its location.
enum class IPOGrouping {
  Register, ///< SSA registers.
  Return,   ///< Return values of functions.
  Memory    ///< In-memory values.
};
24 
25 /// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings.
26 /// The PointerIntPair header provides a DenseMapInfo specialization, so using
27 /// these as LatticeKeys is fine.
28 using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
29 } // namespace
30 
31 namespace llvm {
32 /// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver
33 /// must translate between LatticeKeys and LLVM Values when adding Values to
34 /// its work list and inspecting the state of control-flow related values.
35 template <> struct LatticeKeyInfo<TestLatticeKey> {
36   static inline Value *getValueFromLatticeKey(TestLatticeKey Key) {
37     return Key.getPointer();
38   }
39   static inline TestLatticeKey getLatticeKeyFromValue(Value *V) {
40     return TestLatticeKey(V, IPOGrouping::Register);
41   }
42 };
43 } // namespace llvm
44 
45 namespace {
46 /// This class defines a simple test lattice value that could be used for
47 /// solving problems similar to constant propagation. The value is maintained
48 /// as a PointerIntPair.
49 class TestLatticeVal {
50 public:
51   /// The states of the lattices value. Only the ConstantVal state is
52   /// interesting; the rest are special states used by the generic solver. The
53   /// UntrackedVal state differs from the other three in that the generic
54   /// solver uses it to avoid doing unnecessary work. In particular, when a
55   /// value moves to the UntrackedVal state, it's users are not notified.
56   enum TestLatticeStateTy {
57     UndefinedVal,
58     ConstantVal,
59     OverdefinedVal,
60     UntrackedVal
61   };
62 
63   TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {}
64   TestLatticeVal(Constant *C, TestLatticeStateTy State)
65       : LatticeVal(C, State) {}
66 
67   /// Return true if this lattice value is in the Constant state. This is used
68   /// for checking the solver results.
69   bool isConstant() const { return LatticeVal.getInt() == ConstantVal; }
70 
71   /// Return true if this lattice value is in the Overdefined state. This is
72   /// used for checking the solver results.
73   bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; }
74 
75   bool operator==(const TestLatticeVal &RHS) const {
76     return LatticeVal == RHS.LatticeVal;
77   }
78 
79   bool operator!=(const TestLatticeVal &RHS) const {
80     return LatticeVal != RHS.LatticeVal;
81   }
82 
83 private:
84   /// A simple lattice value type for problems similar to constant propagation.
85   /// It holds the constant value and the lattice state.
86   PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal;
87 };
88 
89 /// This class defines a simple test lattice function that could be used for
90 /// solving problems similar to constant propagation. The test lattice differs
91 /// from a "real" lattice in a few ways. First, it initializes all return
92 /// values, values stored in global variables, and arguments in the undefined
93 /// state. This means that there are no limitations on what we can track
94 /// interprocedurally. For simplicity, all global values in the tests will be
95 /// given internal linkage, since this is not something this lattice function
96 /// tracks. Second, it only handles the few instructions necessary for the
97 /// tests.
98 class TestLatticeFunc
99     : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> {
100 public:
101   /// Construct a new test lattice function with special values for the
102   /// Undefined, Overdefined, and Untracked states.
103   TestLatticeFunc()
104       : AbstractLatticeFunction(
105             TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal),
106             TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal),
107             TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {}
108 
109   /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the
110   /// test analysis, a LatticeKey will begin in the undefined state, unless it
111   /// represents an LLVM Constant in the register grouping.
112   TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override {
113     if (Key.getInt() == IPOGrouping::Register)
114       if (auto *C = dyn_cast<Constant>(Key.getPointer()))
115         return TestLatticeVal(C, TestLatticeVal::ConstantVal);
116     return getUndefVal();
117   }
118 
119   /// Merge the two given lattice values. This merge should be equivalent to
120   /// what is done for constant propagation. That is, the resulting lattice
121   /// value is constant only if the two given lattice values are constant and
122   /// hold the same value.
123   TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override {
124     if (X == getUntrackedVal() || Y == getUntrackedVal())
125       return getUntrackedVal();
126     if (X == getOverdefinedVal() || Y == getOverdefinedVal())
127       return getOverdefinedVal();
128     if (X == getUndefVal() && Y == getUndefVal())
129       return getUndefVal();
130     if (X == getUndefVal())
131       return Y;
132     if (Y == getUndefVal())
133       return X;
134     if (X == Y)
135       return X;
136     return getOverdefinedVal();
137   }
138 
139   /// Compute the lattice values that change as a result of executing the given
140   /// instruction. We only handle the few instructions needed for the tests.
141   void ComputeInstructionState(
142       Instruction &I,
143       SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
144       SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override {
145     switch (I.getOpcode()) {
146     case Instruction::Call:
147       return visitCallBase(cast<CallBase>(I), ChangedValues, SS);
148     case Instruction::Ret:
149       return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
150     case Instruction::Store:
151       return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
152     default:
153       return visitInst(I, ChangedValues, SS);
154     }
155   }
156 
157 private:
158   /// Handle call sites. The state of a called function's argument is the merge
159   /// of the current formal argument state with the call site's corresponding
160   /// actual argument state. The call site state is the merge of the call site
161   /// state with the returned value state of the called function.
162   void visitCallBase(CallBase &I,
163                      SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
164                      SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
165     Function *F = I.getCalledFunction();
166     auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
167     if (!F) {
168       ChangedValues[RegI] = getOverdefinedVal();
169       return;
170     }
171     SS.MarkBlockExecutable(&F->front());
172     for (Argument &A : F->args()) {
173       auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register);
174       auto RegActual =
175           TestLatticeKey(I.getArgOperand(A.getArgNo()), IPOGrouping::Register);
176       ChangedValues[RegFormal] =
177           MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
178     }
179     auto RetF = TestLatticeKey(F, IPOGrouping::Return);
180     ChangedValues[RegI] =
181         MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
182   }
183 
184   /// Handle return instructions. The function's return state is the merge of
185   /// the returned value state and the function's current return state.
186   void visitReturn(ReturnInst &I,
187                    SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
188                    SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
189     Function *F = I.getParent()->getParent();
190     if (F->getReturnType()->isVoidTy())
191       return;
192     auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register);
193     auto RetF = TestLatticeKey(F, IPOGrouping::Return);
194     ChangedValues[RetF] =
195         MergeValues(SS.getValueState(RegR), SS.getValueState(RetF));
196   }
197 
198   /// Handle store instructions. If the pointer operand of the store is a
199   /// global variable, we attempt to track the value. The global variable state
200   /// is the merge of the stored value state with the current global variable
201   /// state.
202   void visitStore(StoreInst &I,
203                   SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
204                   SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
205     auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
206     if (!GV)
207       return;
208     auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register);
209     auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory);
210     ChangedValues[MemPtr] =
211         MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr));
212   }
213 
214   /// Handle all other instructions. All other instructions are marked
215   /// overdefined.
216   void visitInst(Instruction &I,
217                  SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues,
218                  SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
219     auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
220     ChangedValues[RegI] = getOverdefinedVal();
221   }
222 };
223 
224 /// This class defines the common data used for all of the tests. The tests
225 /// should add code to the module and then run the solver.
226 class SparsePropagationTest : public testing::Test {
227 protected:
228   LLVMContext Context;
229   Module M;
230   IRBuilder<> Builder;
231   TestLatticeFunc Lattice;
232   SparseSolver<TestLatticeKey, TestLatticeVal> Solver;
233 
234 public:
235   SparsePropagationTest()
236       : M("", Context), Builder(Context), Solver(&Lattice) {}
237 };
238 } // namespace
239 
240 /// Test that we mark discovered functions executable.
241 ///
242 /// define internal void @f() {
243 ///   call void @g()
244 ///   ret void
245 /// }
246 ///
247 /// define internal void @g() {
248 ///   call void @f()
249 ///   ret void
250 /// }
251 ///
252 /// For this test, we initially mark "f" executable, and the solver discovers
253 /// "g" because of the call in "f". The mutually recursive call in "g" also
254 /// tests that we don't add a block to the basic block work list if it is
255 /// already executable. Doing so would put the solver into an infinite loop.
256 TEST_F(SparsePropagationTest, MarkBlockExecutable) {
257   Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
258                                  GlobalValue::InternalLinkage, "f", &M);
259   Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
260                                  GlobalValue::InternalLinkage, "g", &M);
261   BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
262   BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
263   Builder.SetInsertPoint(FEntry);
264   Builder.CreateCall(G);
265   Builder.CreateRetVoid();
266   Builder.SetInsertPoint(GEntry);
267   Builder.CreateCall(F);
268   Builder.CreateRetVoid();
269 
270   Solver.MarkBlockExecutable(FEntry);
271   Solver.Solve();
272 
273   EXPECT_TRUE(Solver.isBlockExecutable(GEntry));
274 }
275 
276 /// Test that we propagate information through global variables.
277 ///
278 /// @gv = internal global i64
279 ///
280 /// define internal void @f() {
281 ///   store i64 1, i64* @gv
282 ///   ret void
283 /// }
284 ///
285 /// define internal void @g() {
286 ///   store i64 1, i64* @gv
287 ///   ret void
288 /// }
289 ///
290 /// For this test, we initially mark both "f" and "g" executable, and the
291 /// solver computes the lattice state of the global variable as constant.
292 TEST_F(SparsePropagationTest, GlobalVariableConstant) {
293   Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
294                                  GlobalValue::InternalLinkage, "f", &M);
295   Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
296                                  GlobalValue::InternalLinkage, "g", &M);
297   GlobalVariable *GV =
298       new GlobalVariable(M, Builder.getInt64Ty(), false,
299                          GlobalValue::InternalLinkage, nullptr, "gv");
300   BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
301   BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
302   Builder.SetInsertPoint(FEntry);
303   Builder.CreateStore(Builder.getInt64(1), GV);
304   Builder.CreateRetVoid();
305   Builder.SetInsertPoint(GEntry);
306   Builder.CreateStore(Builder.getInt64(1), GV);
307   Builder.CreateRetVoid();
308 
309   Solver.MarkBlockExecutable(FEntry);
310   Solver.MarkBlockExecutable(GEntry);
311   Solver.Solve();
312 
313   auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
314   EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant());
315 }
316 
317 /// Test that we propagate information through global variables.
318 ///
319 /// @gv = internal global i64
320 ///
321 /// define internal void @f() {
322 ///   store i64 0, i64* @gv
323 ///   ret void
324 /// }
325 ///
326 /// define internal void @g() {
327 ///   store i64 1, i64* @gv
328 ///   ret void
329 /// }
330 ///
331 /// For this test, we initially mark both "f" and "g" executable, and the
332 /// solver computes the lattice state of the global variable as overdefined.
333 TEST_F(SparsePropagationTest, GlobalVariableOverDefined) {
334   Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
335                                  GlobalValue::InternalLinkage, "f", &M);
336   Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
337                                  GlobalValue::InternalLinkage, "g", &M);
338   GlobalVariable *GV =
339       new GlobalVariable(M, Builder.getInt64Ty(), false,
340                          GlobalValue::InternalLinkage, nullptr, "gv");
341   BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
342   BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
343   Builder.SetInsertPoint(FEntry);
344   Builder.CreateStore(Builder.getInt64(0), GV);
345   Builder.CreateRetVoid();
346   Builder.SetInsertPoint(GEntry);
347   Builder.CreateStore(Builder.getInt64(1), GV);
348   Builder.CreateRetVoid();
349 
350   Solver.MarkBlockExecutable(FEntry);
351   Solver.MarkBlockExecutable(GEntry);
352   Solver.Solve();
353 
354   auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
355   EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined());
356 }
357 
358 /// Test that we propagate information through function returns.
359 ///
360 /// define internal i64 @f(i1* %cond) {
361 /// if:
362 ///   %0 = load i1, i1* %cond
363 ///   br i1 %0, label %then, label %else
364 ///
365 /// then:
366 ///   ret i64 1
367 ///
368 /// else:
369 ///   ret i64 1
370 /// }
371 ///
372 /// For this test, we initially mark "f" executable, and the solver computes
373 /// the return value of the function as constant.
374 TEST_F(SparsePropagationTest, FunctionDefined) {
375   Function *F =
376       Function::Create(FunctionType::get(Builder.getInt64Ty(),
377                                          {PointerType::get(Context, 0)}, false),
378                        GlobalValue::InternalLinkage, "f", &M);
379   BasicBlock *If = BasicBlock::Create(Context, "if", F);
380   BasicBlock *Then = BasicBlock::Create(Context, "then", F);
381   BasicBlock *Else = BasicBlock::Create(Context, "else", F);
382   F->arg_begin()->setName("cond");
383   Builder.SetInsertPoint(If);
384   LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), F->arg_begin());
385   Builder.CreateCondBr(Cond, Then, Else);
386   Builder.SetInsertPoint(Then);
387   Builder.CreateRet(Builder.getInt64(1));
388   Builder.SetInsertPoint(Else);
389   Builder.CreateRet(Builder.getInt64(1));
390 
391   Solver.MarkBlockExecutable(If);
392   Solver.Solve();
393 
394   auto RetF = TestLatticeKey(F, IPOGrouping::Return);
395   EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant());
396 }
397 
398 /// Test that we propagate information through function returns.
399 ///
400 /// define internal i64 @f(i1* %cond) {
401 /// if:
402 ///   %0 = load i1, i1* %cond
403 ///   br i1 %0, label %then, label %else
404 ///
405 /// then:
406 ///   ret i64 0
407 ///
408 /// else:
409 ///   ret i64 1
410 /// }
411 ///
412 /// For this test, we initially mark "f" executable, and the solver computes
413 /// the return value of the function as overdefined.
414 TEST_F(SparsePropagationTest, FunctionOverDefined) {
415   Function *F =
416       Function::Create(FunctionType::get(Builder.getInt64Ty(),
417                                          {PointerType::get(Context, 0)}, false),
418                        GlobalValue::InternalLinkage, "f", &M);
419   BasicBlock *If = BasicBlock::Create(Context, "if", F);
420   BasicBlock *Then = BasicBlock::Create(Context, "then", F);
421   BasicBlock *Else = BasicBlock::Create(Context, "else", F);
422   F->arg_begin()->setName("cond");
423   Builder.SetInsertPoint(If);
424   LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), F->arg_begin());
425   Builder.CreateCondBr(Cond, Then, Else);
426   Builder.SetInsertPoint(Then);
427   Builder.CreateRet(Builder.getInt64(0));
428   Builder.SetInsertPoint(Else);
429   Builder.CreateRet(Builder.getInt64(1));
430 
431   Solver.MarkBlockExecutable(If);
432   Solver.Solve();
433 
434   auto RetF = TestLatticeKey(F, IPOGrouping::Return);
435   EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined());
436 }
437 
438 /// Test that we propagate information through arguments.
439 ///
440 /// define internal void @f() {
441 ///   call void @g(i64 0, i64 1)
442 ///   call void @g(i64 1, i64 1)
443 ///   ret void
444 /// }
445 ///
446 /// define internal void @g(i64 %a, i64 %b) {
447 ///   ret void
448 /// }
449 ///
450 /// For this test, we initially mark "f" executable, and the solver discovers
451 /// "g" because of the calls in "f". The solver computes the state of argument
452 /// "a" as overdefined and the state of "b" as constant.
453 ///
454 /// In addition, this test demonstrates that ComputeInstructionState can alter
455 /// the state of multiple lattice values, in addition to the one associated
456 /// with the instruction definition. Each call instruction in this test updates
457 /// the state of arguments "a" and "b".
458 TEST_F(SparsePropagationTest, ComputeInstructionState) {
459   Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
460                                  GlobalValue::InternalLinkage, "f", &M);
461   Function *G = Function::Create(
462       FunctionType::get(Builder.getVoidTy(),
463                         {Builder.getInt64Ty(), Builder.getInt64Ty()}, false),
464       GlobalValue::InternalLinkage, "g", &M);
465   Argument *A = G->arg_begin();
466   Argument *B = std::next(G->arg_begin());
467   A->setName("a");
468   B->setName("b");
469   BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
470   BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
471   Builder.SetInsertPoint(FEntry);
472   Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)});
473   Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)});
474   Builder.CreateRetVoid();
475   Builder.SetInsertPoint(GEntry);
476   Builder.CreateRetVoid();
477 
478   Solver.MarkBlockExecutable(FEntry);
479   Solver.Solve();
480 
481   auto RegA = TestLatticeKey(A, IPOGrouping::Register);
482   auto RegB = TestLatticeKey(B, IPOGrouping::Register);
483   EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined());
484   EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant());
485 }
486 
487 /// Test that we can handle exceptional terminator instructions.
488 ///
489 /// declare internal void @p()
490 ///
491 /// declare internal void @g()
492 ///
493 /// define internal void @f() personality ptr @p {
494 /// entry:
495 ///   invoke void @g()
496 ///           to label %exit unwind label %catch.pad
497 ///
498 /// catch.pad:
499 ///   %0 = catchswitch within none [label %catch.body] unwind to caller
500 ///
501 /// catch.body:
502 ///   %1 = catchpad within %0 []
503 ///   catchret from %1 to label %exit
504 ///
505 /// exit:
506 ///   ret void
507 /// }
508 ///
509 /// For this test, we initially mark the entry block executable. The solver
510 /// then discovers the rest of the blocks in the function are executable.
511 TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
512   Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
513                                  GlobalValue::InternalLinkage, "p", &M);
514   Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
515                                  GlobalValue::InternalLinkage, "g", &M);
516   Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
517                                  GlobalValue::InternalLinkage, "f", &M);
518   F->setPersonalityFn(P);
519   BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
520   BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F);
521   BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F);
522   BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
523   Builder.SetInsertPoint(Entry);
524   Builder.CreateInvoke(G, Exit, Pad);
525   Builder.SetInsertPoint(Pad);
526   CatchSwitchInst *CatchSwitch =
527       Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1);
528   CatchSwitch->addHandler(Body);
529   Builder.SetInsertPoint(Body);
530   CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {});
531   Builder.CreateCatchRet(CatchPad, Exit);
532   Builder.SetInsertPoint(Exit);
533   Builder.CreateRetVoid();
534 
535   Solver.MarkBlockExecutable(Entry);
536   Solver.Solve();
537 
538   EXPECT_TRUE(Solver.isBlockExecutable(Pad));
539   EXPECT_TRUE(Solver.isBlockExecutable(Body));
540   EXPECT_TRUE(Solver.isBlockExecutable(Exit));
541 }
542