1 //===- SparsePropagation.cpp - Unit tests for the generic solver ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/SparsePropagation.h" 10 #include "llvm/ADT/PointerIntPair.h" 11 #include "llvm/IR/IRBuilder.h" 12 #include "llvm/IR/Module.h" 13 #include "gtest/gtest.h" 14 using namespace llvm; 15 16 namespace { 17 /// To enable interprocedural analysis, we assign LLVM values to the following 18 /// groups. The register group represents SSA registers, the return group 19 /// represents the return values of functions, and the memory group represents 20 /// in-memory values. An LLVM Value can technically be in more than one group. 21 /// It's necessary to distinguish these groups so we can, for example, track a 22 /// global variable separately from the value stored at its location. 23 enum class IPOGrouping { Register, Return, Memory }; 24 25 /// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings. 26 /// The PointerIntPair header provides a DenseMapInfo specialization, so using 27 /// these as LatticeKeys is fine. 28 using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>; 29 } // namespace 30 31 namespace llvm { 32 /// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver 33 /// must translate between LatticeKeys and LLVM Values when adding Values to 34 /// its work list and inspecting the state of control-flow related values. 35 template <> struct LatticeKeyInfo<TestLatticeKey> { 36 static inline Value *getValueFromLatticeKey(TestLatticeKey Key) { 37 return Key.getPointer(); 38 } 39 static inline TestLatticeKey getLatticeKeyFromValue(Value *V) { 40 return TestLatticeKey(V, IPOGrouping::Register); 41 } 42 }; 43 } // namespace llvm 44 45 namespace { 46 /// This class defines a simple test lattice value that could be used for 47 /// solving problems similar to constant propagation. The value is maintained 48 /// as a PointerIntPair. 49 class TestLatticeVal { 50 public: 51 /// The states of the lattices value. Only the ConstantVal state is 52 /// interesting; the rest are special states used by the generic solver. The 53 /// UntrackedVal state differs from the other three in that the generic 54 /// solver uses it to avoid doing unnecessary work. In particular, when a 55 /// value moves to the UntrackedVal state, it's users are not notified. 56 enum TestLatticeStateTy { 57 UndefinedVal, 58 ConstantVal, 59 OverdefinedVal, 60 UntrackedVal 61 }; 62 63 TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {} 64 TestLatticeVal(Constant *C, TestLatticeStateTy State) 65 : LatticeVal(C, State) {} 66 67 /// Return true if this lattice value is in the Constant state. This is used 68 /// for checking the solver results. 69 bool isConstant() const { return LatticeVal.getInt() == ConstantVal; } 70 71 /// Return true if this lattice value is in the Overdefined state. This is 72 /// used for checking the solver results. 73 bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; } 74 75 bool operator==(const TestLatticeVal &RHS) const { 76 return LatticeVal == RHS.LatticeVal; 77 } 78 79 bool operator!=(const TestLatticeVal &RHS) const { 80 return LatticeVal != RHS.LatticeVal; 81 } 82 83 private: 84 /// A simple lattice value type for problems similar to constant propagation. 85 /// It holds the constant value and the lattice state. 86 PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal; 87 }; 88 89 /// This class defines a simple test lattice function that could be used for 90 /// solving problems similar to constant propagation. The test lattice differs 91 /// from a "real" lattice in a few ways. First, it initializes all return 92 /// values, values stored in global variables, and arguments in the undefined 93 /// state. This means that there are no limitations on what we can track 94 /// interprocedurally. For simplicity, all global values in the tests will be 95 /// given internal linkage, since this is not something this lattice function 96 /// tracks. Second, it only handles the few instructions necessary for the 97 /// tests. 98 class TestLatticeFunc 99 : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> { 100 public: 101 /// Construct a new test lattice function with special values for the 102 /// Undefined, Overdefined, and Untracked states. 103 TestLatticeFunc() 104 : AbstractLatticeFunction( 105 TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal), 106 TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal), 107 TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {} 108 109 /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the 110 /// test analysis, a LatticeKey will begin in the undefined state, unless it 111 /// represents an LLVM Constant in the register grouping. 112 TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override { 113 if (Key.getInt() == IPOGrouping::Register) 114 if (auto *C = dyn_cast<Constant>(Key.getPointer())) 115 return TestLatticeVal(C, TestLatticeVal::ConstantVal); 116 return getUndefVal(); 117 } 118 119 /// Merge the two given lattice values. This merge should be equivalent to 120 /// what is done for constant propagation. That is, the resulting lattice 121 /// value is constant only if the two given lattice values are constant and 122 /// hold the same value. 123 TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override { 124 if (X == getUntrackedVal() || Y == getUntrackedVal()) 125 return getUntrackedVal(); 126 if (X == getOverdefinedVal() || Y == getOverdefinedVal()) 127 return getOverdefinedVal(); 128 if (X == getUndefVal() && Y == getUndefVal()) 129 return getUndefVal(); 130 if (X == getUndefVal()) 131 return Y; 132 if (Y == getUndefVal()) 133 return X; 134 if (X == Y) 135 return X; 136 return getOverdefinedVal(); 137 } 138 139 /// Compute the lattice values that change as a result of executing the given 140 /// instruction. We only handle the few instructions needed for the tests. 141 void ComputeInstructionState( 142 Instruction &I, 143 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues, 144 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override { 145 switch (I.getOpcode()) { 146 case Instruction::Call: 147 return visitCallBase(cast<CallBase>(I), ChangedValues, SS); 148 case Instruction::Ret: 149 return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS); 150 case Instruction::Store: 151 return visitStore(*cast<StoreInst>(&I), ChangedValues, SS); 152 default: 153 return visitInst(I, ChangedValues, SS); 154 } 155 } 156 157 private: 158 /// Handle call sites. The state of a called function's argument is the merge 159 /// of the current formal argument state with the call site's corresponding 160 /// actual argument state. The call site state is the merge of the call site 161 /// state with the returned value state of the called function. 162 void visitCallBase(CallBase &I, 163 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues, 164 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { 165 Function *F = I.getCalledFunction(); 166 auto RegI = TestLatticeKey(&I, IPOGrouping::Register); 167 if (!F) { 168 ChangedValues[RegI] = getOverdefinedVal(); 169 return; 170 } 171 SS.MarkBlockExecutable(&F->front()); 172 for (Argument &A : F->args()) { 173 auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register); 174 auto RegActual = 175 TestLatticeKey(I.getArgOperand(A.getArgNo()), IPOGrouping::Register); 176 ChangedValues[RegFormal] = 177 MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual)); 178 } 179 auto RetF = TestLatticeKey(F, IPOGrouping::Return); 180 ChangedValues[RegI] = 181 MergeValues(SS.getValueState(RegI), SS.getValueState(RetF)); 182 } 183 184 /// Handle return instructions. The function's return state is the merge of 185 /// the returned value state and the function's current return state. 186 void visitReturn(ReturnInst &I, 187 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues, 188 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { 189 Function *F = I.getParent()->getParent(); 190 if (F->getReturnType()->isVoidTy()) 191 return; 192 auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register); 193 auto RetF = TestLatticeKey(F, IPOGrouping::Return); 194 ChangedValues[RetF] = 195 MergeValues(SS.getValueState(RegR), SS.getValueState(RetF)); 196 } 197 198 /// Handle store instructions. If the pointer operand of the store is a 199 /// global variable, we attempt to track the value. The global variable state 200 /// is the merge of the stored value state with the current global variable 201 /// state. 202 void visitStore(StoreInst &I, 203 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues, 204 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { 205 auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand()); 206 if (!GV) 207 return; 208 auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register); 209 auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory); 210 ChangedValues[MemPtr] = 211 MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr)); 212 } 213 214 /// Handle all other instructions. All other instructions are marked 215 /// overdefined. 216 void visitInst(Instruction &I, 217 SmallDenseMap<TestLatticeKey, TestLatticeVal, 16> &ChangedValues, 218 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { 219 auto RegI = TestLatticeKey(&I, IPOGrouping::Register); 220 ChangedValues[RegI] = getOverdefinedVal(); 221 } 222 }; 223 224 /// This class defines the common data used for all of the tests. The tests 225 /// should add code to the module and then run the solver. 226 class SparsePropagationTest : public testing::Test { 227 protected: 228 LLVMContext Context; 229 Module M; 230 IRBuilder<> Builder; 231 TestLatticeFunc Lattice; 232 SparseSolver<TestLatticeKey, TestLatticeVal> Solver; 233 234 public: 235 SparsePropagationTest() 236 : M("", Context), Builder(Context), Solver(&Lattice) {} 237 }; 238 } // namespace 239 240 /// Test that we mark discovered functions executable. 241 /// 242 /// define internal void @f() { 243 /// call void @g() 244 /// ret void 245 /// } 246 /// 247 /// define internal void @g() { 248 /// call void @f() 249 /// ret void 250 /// } 251 /// 252 /// For this test, we initially mark "f" executable, and the solver discovers 253 /// "g" because of the call in "f". The mutually recursive call in "g" also 254 /// tests that we don't add a block to the basic block work list if it is 255 /// already executable. Doing so would put the solver into an infinite loop. 256 TEST_F(SparsePropagationTest, MarkBlockExecutable) { 257 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 258 GlobalValue::InternalLinkage, "f", &M); 259 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 260 GlobalValue::InternalLinkage, "g", &M); 261 BasicBlock *FEntry = BasicBlock::Create(Context, "", F); 262 BasicBlock *GEntry = BasicBlock::Create(Context, "", G); 263 Builder.SetInsertPoint(FEntry); 264 Builder.CreateCall(G); 265 Builder.CreateRetVoid(); 266 Builder.SetInsertPoint(GEntry); 267 Builder.CreateCall(F); 268 Builder.CreateRetVoid(); 269 270 Solver.MarkBlockExecutable(FEntry); 271 Solver.Solve(); 272 273 EXPECT_TRUE(Solver.isBlockExecutable(GEntry)); 274 } 275 276 /// Test that we propagate information through global variables. 277 /// 278 /// @gv = internal global i64 279 /// 280 /// define internal void @f() { 281 /// store i64 1, i64* @gv 282 /// ret void 283 /// } 284 /// 285 /// define internal void @g() { 286 /// store i64 1, i64* @gv 287 /// ret void 288 /// } 289 /// 290 /// For this test, we initially mark both "f" and "g" executable, and the 291 /// solver computes the lattice state of the global variable as constant. 292 TEST_F(SparsePropagationTest, GlobalVariableConstant) { 293 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 294 GlobalValue::InternalLinkage, "f", &M); 295 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 296 GlobalValue::InternalLinkage, "g", &M); 297 GlobalVariable *GV = 298 new GlobalVariable(M, Builder.getInt64Ty(), false, 299 GlobalValue::InternalLinkage, nullptr, "gv"); 300 BasicBlock *FEntry = BasicBlock::Create(Context, "", F); 301 BasicBlock *GEntry = BasicBlock::Create(Context, "", G); 302 Builder.SetInsertPoint(FEntry); 303 Builder.CreateStore(Builder.getInt64(1), GV); 304 Builder.CreateRetVoid(); 305 Builder.SetInsertPoint(GEntry); 306 Builder.CreateStore(Builder.getInt64(1), GV); 307 Builder.CreateRetVoid(); 308 309 Solver.MarkBlockExecutable(FEntry); 310 Solver.MarkBlockExecutable(GEntry); 311 Solver.Solve(); 312 313 auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory); 314 EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant()); 315 } 316 317 /// Test that we propagate information through global variables. 318 /// 319 /// @gv = internal global i64 320 /// 321 /// define internal void @f() { 322 /// store i64 0, i64* @gv 323 /// ret void 324 /// } 325 /// 326 /// define internal void @g() { 327 /// store i64 1, i64* @gv 328 /// ret void 329 /// } 330 /// 331 /// For this test, we initially mark both "f" and "g" executable, and the 332 /// solver computes the lattice state of the global variable as overdefined. 333 TEST_F(SparsePropagationTest, GlobalVariableOverDefined) { 334 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 335 GlobalValue::InternalLinkage, "f", &M); 336 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 337 GlobalValue::InternalLinkage, "g", &M); 338 GlobalVariable *GV = 339 new GlobalVariable(M, Builder.getInt64Ty(), false, 340 GlobalValue::InternalLinkage, nullptr, "gv"); 341 BasicBlock *FEntry = BasicBlock::Create(Context, "", F); 342 BasicBlock *GEntry = BasicBlock::Create(Context, "", G); 343 Builder.SetInsertPoint(FEntry); 344 Builder.CreateStore(Builder.getInt64(0), GV); 345 Builder.CreateRetVoid(); 346 Builder.SetInsertPoint(GEntry); 347 Builder.CreateStore(Builder.getInt64(1), GV); 348 Builder.CreateRetVoid(); 349 350 Solver.MarkBlockExecutable(FEntry); 351 Solver.MarkBlockExecutable(GEntry); 352 Solver.Solve(); 353 354 auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory); 355 EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined()); 356 } 357 358 /// Test that we propagate information through function returns. 359 /// 360 /// define internal i64 @f(i1* %cond) { 361 /// if: 362 /// %0 = load i1, i1* %cond 363 /// br i1 %0, label %then, label %else 364 /// 365 /// then: 366 /// ret i64 1 367 /// 368 /// else: 369 /// ret i64 1 370 /// } 371 /// 372 /// For this test, we initially mark "f" executable, and the solver computes 373 /// the return value of the function as constant. 374 TEST_F(SparsePropagationTest, FunctionDefined) { 375 Function *F = 376 Function::Create(FunctionType::get(Builder.getInt64Ty(), 377 {PointerType::get(Context, 0)}, false), 378 GlobalValue::InternalLinkage, "f", &M); 379 BasicBlock *If = BasicBlock::Create(Context, "if", F); 380 BasicBlock *Then = BasicBlock::Create(Context, "then", F); 381 BasicBlock *Else = BasicBlock::Create(Context, "else", F); 382 F->arg_begin()->setName("cond"); 383 Builder.SetInsertPoint(If); 384 LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), F->arg_begin()); 385 Builder.CreateCondBr(Cond, Then, Else); 386 Builder.SetInsertPoint(Then); 387 Builder.CreateRet(Builder.getInt64(1)); 388 Builder.SetInsertPoint(Else); 389 Builder.CreateRet(Builder.getInt64(1)); 390 391 Solver.MarkBlockExecutable(If); 392 Solver.Solve(); 393 394 auto RetF = TestLatticeKey(F, IPOGrouping::Return); 395 EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant()); 396 } 397 398 /// Test that we propagate information through function returns. 399 /// 400 /// define internal i64 @f(i1* %cond) { 401 /// if: 402 /// %0 = load i1, i1* %cond 403 /// br i1 %0, label %then, label %else 404 /// 405 /// then: 406 /// ret i64 0 407 /// 408 /// else: 409 /// ret i64 1 410 /// } 411 /// 412 /// For this test, we initially mark "f" executable, and the solver computes 413 /// the return value of the function as overdefined. 414 TEST_F(SparsePropagationTest, FunctionOverDefined) { 415 Function *F = 416 Function::Create(FunctionType::get(Builder.getInt64Ty(), 417 {PointerType::get(Context, 0)}, false), 418 GlobalValue::InternalLinkage, "f", &M); 419 BasicBlock *If = BasicBlock::Create(Context, "if", F); 420 BasicBlock *Then = BasicBlock::Create(Context, "then", F); 421 BasicBlock *Else = BasicBlock::Create(Context, "else", F); 422 F->arg_begin()->setName("cond"); 423 Builder.SetInsertPoint(If); 424 LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), F->arg_begin()); 425 Builder.CreateCondBr(Cond, Then, Else); 426 Builder.SetInsertPoint(Then); 427 Builder.CreateRet(Builder.getInt64(0)); 428 Builder.SetInsertPoint(Else); 429 Builder.CreateRet(Builder.getInt64(1)); 430 431 Solver.MarkBlockExecutable(If); 432 Solver.Solve(); 433 434 auto RetF = TestLatticeKey(F, IPOGrouping::Return); 435 EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined()); 436 } 437 438 /// Test that we propagate information through arguments. 439 /// 440 /// define internal void @f() { 441 /// call void @g(i64 0, i64 1) 442 /// call void @g(i64 1, i64 1) 443 /// ret void 444 /// } 445 /// 446 /// define internal void @g(i64 %a, i64 %b) { 447 /// ret void 448 /// } 449 /// 450 /// For this test, we initially mark "f" executable, and the solver discovers 451 /// "g" because of the calls in "f". The solver computes the state of argument 452 /// "a" as overdefined and the state of "b" as constant. 453 /// 454 /// In addition, this test demonstrates that ComputeInstructionState can alter 455 /// the state of multiple lattice values, in addition to the one associated 456 /// with the instruction definition. Each call instruction in this test updates 457 /// the state of arguments "a" and "b". 458 TEST_F(SparsePropagationTest, ComputeInstructionState) { 459 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 460 GlobalValue::InternalLinkage, "f", &M); 461 Function *G = Function::Create( 462 FunctionType::get(Builder.getVoidTy(), 463 {Builder.getInt64Ty(), Builder.getInt64Ty()}, false), 464 GlobalValue::InternalLinkage, "g", &M); 465 Argument *A = G->arg_begin(); 466 Argument *B = std::next(G->arg_begin()); 467 A->setName("a"); 468 B->setName("b"); 469 BasicBlock *FEntry = BasicBlock::Create(Context, "", F); 470 BasicBlock *GEntry = BasicBlock::Create(Context, "", G); 471 Builder.SetInsertPoint(FEntry); 472 Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)}); 473 Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)}); 474 Builder.CreateRetVoid(); 475 Builder.SetInsertPoint(GEntry); 476 Builder.CreateRetVoid(); 477 478 Solver.MarkBlockExecutable(FEntry); 479 Solver.Solve(); 480 481 auto RegA = TestLatticeKey(A, IPOGrouping::Register); 482 auto RegB = TestLatticeKey(B, IPOGrouping::Register); 483 EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined()); 484 EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant()); 485 } 486 487 /// Test that we can handle exceptional terminator instructions. 488 /// 489 /// declare internal void @p() 490 /// 491 /// declare internal void @g() 492 /// 493 /// define internal void @f() personality ptr @p { 494 /// entry: 495 /// invoke void @g() 496 /// to label %exit unwind label %catch.pad 497 /// 498 /// catch.pad: 499 /// %0 = catchswitch within none [label %catch.body] unwind to caller 500 /// 501 /// catch.body: 502 /// %1 = catchpad within %0 [] 503 /// catchret from %1 to label %exit 504 /// 505 /// exit: 506 /// ret void 507 /// } 508 /// 509 /// For this test, we initially mark the entry block executable. The solver 510 /// then discovers the rest of the blocks in the function are executable. 511 TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) { 512 Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 513 GlobalValue::InternalLinkage, "p", &M); 514 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 515 GlobalValue::InternalLinkage, "g", &M); 516 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 517 GlobalValue::InternalLinkage, "f", &M); 518 F->setPersonalityFn(P); 519 BasicBlock *Entry = BasicBlock::Create(Context, "entry", F); 520 BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F); 521 BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F); 522 BasicBlock *Exit = BasicBlock::Create(Context, "exit", F); 523 Builder.SetInsertPoint(Entry); 524 Builder.CreateInvoke(G, Exit, Pad); 525 Builder.SetInsertPoint(Pad); 526 CatchSwitchInst *CatchSwitch = 527 Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1); 528 CatchSwitch->addHandler(Body); 529 Builder.SetInsertPoint(Body); 530 CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {}); 531 Builder.CreateCatchRet(CatchPad, Exit); 532 Builder.SetInsertPoint(Exit); 533 Builder.CreateRetVoid(); 534 535 Solver.MarkBlockExecutable(Entry); 536 Solver.Solve(); 537 538 EXPECT_TRUE(Solver.isBlockExecutable(Pad)); 539 EXPECT_TRUE(Solver.isBlockExecutable(Body)); 540 EXPECT_TRUE(Solver.isBlockExecutable(Exit)); 541 } 542