1 //===- SparsePropagation.cpp - Unit tests for the generic solver ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/SparsePropagation.h" 10 #include "llvm/ADT/PointerIntPair.h" 11 #include "llvm/IR/IRBuilder.h" 12 #include "gtest/gtest.h" 13 using namespace llvm; 14 15 namespace { 16 /// To enable interprocedural analysis, we assign LLVM values to the following 17 /// groups. The register group represents SSA registers, the return group 18 /// represents the return values of functions, and the memory group represents 19 /// in-memory values. An LLVM Value can technically be in more than one group. 20 /// It's necessary to distinguish these groups so we can, for example, track a 21 /// global variable separately from the value stored at its location. 22 enum class IPOGrouping { Register, Return, Memory }; 23 24 /// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings. 25 /// The PointerIntPair header provides a DenseMapInfo specialization, so using 26 /// these as LatticeKeys is fine. 27 using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>; 28 } // namespace 29 30 namespace llvm { 31 /// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver 32 /// must translate between LatticeKeys and LLVM Values when adding Values to 33 /// its work list and inspecting the state of control-flow related values. 34 template <> struct LatticeKeyInfo<TestLatticeKey> { 35 static inline Value *getValueFromLatticeKey(TestLatticeKey Key) { 36 return Key.getPointer(); 37 } 38 static inline TestLatticeKey getLatticeKeyFromValue(Value *V) { 39 return TestLatticeKey(V, IPOGrouping::Register); 40 } 41 }; 42 } // namespace llvm 43 44 namespace { 45 /// This class defines a simple test lattice value that could be used for 46 /// solving problems similar to constant propagation. The value is maintained 47 /// as a PointerIntPair. 48 class TestLatticeVal { 49 public: 50 /// The states of the lattices value. Only the ConstantVal state is 51 /// interesting; the rest are special states used by the generic solver. The 52 /// UntrackedVal state differs from the other three in that the generic 53 /// solver uses it to avoid doing unnecessary work. In particular, when a 54 /// value moves to the UntrackedVal state, it's users are not notified. 55 enum TestLatticeStateTy { 56 UndefinedVal, 57 ConstantVal, 58 OverdefinedVal, 59 UntrackedVal 60 }; 61 62 TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {} 63 TestLatticeVal(Constant *C, TestLatticeStateTy State) 64 : LatticeVal(C, State) {} 65 66 /// Return true if this lattice value is in the Constant state. This is used 67 /// for checking the solver results. 68 bool isConstant() const { return LatticeVal.getInt() == ConstantVal; } 69 70 /// Return true if this lattice value is in the Overdefined state. This is 71 /// used for checking the solver results. 72 bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; } 73 74 bool operator==(const TestLatticeVal &RHS) const { 75 return LatticeVal == RHS.LatticeVal; 76 } 77 78 bool operator!=(const TestLatticeVal &RHS) const { 79 return LatticeVal != RHS.LatticeVal; 80 } 81 82 private: 83 /// A simple lattice value type for problems similar to constant propagation. 84 /// It holds the constant value and the lattice state. 85 PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal; 86 }; 87 88 /// This class defines a simple test lattice function that could be used for 89 /// solving problems similar to constant propagation. The test lattice differs 90 /// from a "real" lattice in a few ways. First, it initializes all return 91 /// values, values stored in global variables, and arguments in the undefined 92 /// state. This means that there are no limitations on what we can track 93 /// interprocedurally. For simplicity, all global values in the tests will be 94 /// given internal linkage, since this is not something this lattice function 95 /// tracks. Second, it only handles the few instructions necessary for the 96 /// tests. 97 class TestLatticeFunc 98 : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> { 99 public: 100 /// Construct a new test lattice function with special values for the 101 /// Undefined, Overdefined, and Untracked states. 102 TestLatticeFunc() 103 : AbstractLatticeFunction( 104 TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal), 105 TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal), 106 TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {} 107 108 /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the 109 /// test analysis, a LatticeKey will begin in the undefined state, unless it 110 /// represents an LLVM Constant in the register grouping. 111 TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override { 112 if (Key.getInt() == IPOGrouping::Register) 113 if (auto *C = dyn_cast<Constant>(Key.getPointer())) 114 return TestLatticeVal(C, TestLatticeVal::ConstantVal); 115 return getUndefVal(); 116 } 117 118 /// Merge the two given lattice values. This merge should be equivalent to 119 /// what is done for constant propagation. That is, the resulting lattice 120 /// value is constant only if the two given lattice values are constant and 121 /// hold the same value. 122 TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override { 123 if (X == getUntrackedVal() || Y == getUntrackedVal()) 124 return getUntrackedVal(); 125 if (X == getOverdefinedVal() || Y == getOverdefinedVal()) 126 return getOverdefinedVal(); 127 if (X == getUndefVal() && Y == getUndefVal()) 128 return getUndefVal(); 129 if (X == getUndefVal()) 130 return Y; 131 if (Y == getUndefVal()) 132 return X; 133 if (X == Y) 134 return X; 135 return getOverdefinedVal(); 136 } 137 138 /// Compute the lattice values that change as a result of executing the given 139 /// instruction. We only handle the few instructions needed for the tests. 140 void ComputeInstructionState( 141 Instruction &I, DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, 142 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override { 143 switch (I.getOpcode()) { 144 case Instruction::Call: 145 return visitCallBase(cast<CallBase>(I), ChangedValues, SS); 146 case Instruction::Ret: 147 return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS); 148 case Instruction::Store: 149 return visitStore(*cast<StoreInst>(&I), ChangedValues, SS); 150 default: 151 return visitInst(I, ChangedValues, SS); 152 } 153 } 154 155 private: 156 /// Handle call sites. The state of a called function's argument is the merge 157 /// of the current formal argument state with the call site's corresponding 158 /// actual argument state. The call site state is the merge of the call site 159 /// state with the returned value state of the called function. 160 void visitCallBase(CallBase &I, 161 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, 162 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { 163 Function *F = I.getCalledFunction(); 164 auto RegI = TestLatticeKey(&I, IPOGrouping::Register); 165 if (!F) { 166 ChangedValues[RegI] = getOverdefinedVal(); 167 return; 168 } 169 SS.MarkBlockExecutable(&F->front()); 170 for (Argument &A : F->args()) { 171 auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register); 172 auto RegActual = 173 TestLatticeKey(I.getArgOperand(A.getArgNo()), IPOGrouping::Register); 174 ChangedValues[RegFormal] = 175 MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual)); 176 } 177 auto RetF = TestLatticeKey(F, IPOGrouping::Return); 178 ChangedValues[RegI] = 179 MergeValues(SS.getValueState(RegI), SS.getValueState(RetF)); 180 } 181 182 /// Handle return instructions. The function's return state is the merge of 183 /// the returned value state and the function's current return state. 184 void visitReturn(ReturnInst &I, 185 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, 186 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { 187 Function *F = I.getParent()->getParent(); 188 if (F->getReturnType()->isVoidTy()) 189 return; 190 auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register); 191 auto RetF = TestLatticeKey(F, IPOGrouping::Return); 192 ChangedValues[RetF] = 193 MergeValues(SS.getValueState(RegR), SS.getValueState(RetF)); 194 } 195 196 /// Handle store instructions. If the pointer operand of the store is a 197 /// global variable, we attempt to track the value. The global variable state 198 /// is the merge of the stored value state with the current global variable 199 /// state. 200 void visitStore(StoreInst &I, 201 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, 202 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { 203 auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand()); 204 if (!GV) 205 return; 206 auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register); 207 auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory); 208 ChangedValues[MemPtr] = 209 MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr)); 210 } 211 212 /// Handle all other instructions. All other instructions are marked 213 /// overdefined. 214 void visitInst(Instruction &I, 215 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, 216 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { 217 auto RegI = TestLatticeKey(&I, IPOGrouping::Register); 218 ChangedValues[RegI] = getOverdefinedVal(); 219 } 220 }; 221 222 /// This class defines the common data used for all of the tests. The tests 223 /// should add code to the module and then run the solver. 224 class SparsePropagationTest : public testing::Test { 225 protected: 226 LLVMContext Context; 227 Module M; 228 IRBuilder<> Builder; 229 TestLatticeFunc Lattice; 230 SparseSolver<TestLatticeKey, TestLatticeVal> Solver; 231 232 public: 233 SparsePropagationTest() 234 : M("", Context), Builder(Context), Solver(&Lattice) {} 235 }; 236 } // namespace 237 238 /// Test that we mark discovered functions executable. 239 /// 240 /// define internal void @f() { 241 /// call void @g() 242 /// ret void 243 /// } 244 /// 245 /// define internal void @g() { 246 /// call void @f() 247 /// ret void 248 /// } 249 /// 250 /// For this test, we initially mark "f" executable, and the solver discovers 251 /// "g" because of the call in "f". The mutually recursive call in "g" also 252 /// tests that we don't add a block to the basic block work list if it is 253 /// already executable. Doing so would put the solver into an infinite loop. 254 TEST_F(SparsePropagationTest, MarkBlockExecutable) { 255 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 256 GlobalValue::InternalLinkage, "f", &M); 257 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 258 GlobalValue::InternalLinkage, "g", &M); 259 BasicBlock *FEntry = BasicBlock::Create(Context, "", F); 260 BasicBlock *GEntry = BasicBlock::Create(Context, "", G); 261 Builder.SetInsertPoint(FEntry); 262 Builder.CreateCall(G); 263 Builder.CreateRetVoid(); 264 Builder.SetInsertPoint(GEntry); 265 Builder.CreateCall(F); 266 Builder.CreateRetVoid(); 267 268 Solver.MarkBlockExecutable(FEntry); 269 Solver.Solve(); 270 271 EXPECT_TRUE(Solver.isBlockExecutable(GEntry)); 272 } 273 274 /// Test that we propagate information through global variables. 275 /// 276 /// @gv = internal global i64 277 /// 278 /// define internal void @f() { 279 /// store i64 1, i64* @gv 280 /// ret void 281 /// } 282 /// 283 /// define internal void @g() { 284 /// store i64 1, i64* @gv 285 /// ret void 286 /// } 287 /// 288 /// For this test, we initially mark both "f" and "g" executable, and the 289 /// solver computes the lattice state of the global variable as constant. 290 TEST_F(SparsePropagationTest, GlobalVariableConstant) { 291 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 292 GlobalValue::InternalLinkage, "f", &M); 293 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 294 GlobalValue::InternalLinkage, "g", &M); 295 GlobalVariable *GV = 296 new GlobalVariable(M, Builder.getInt64Ty(), false, 297 GlobalValue::InternalLinkage, nullptr, "gv"); 298 BasicBlock *FEntry = BasicBlock::Create(Context, "", F); 299 BasicBlock *GEntry = BasicBlock::Create(Context, "", G); 300 Builder.SetInsertPoint(FEntry); 301 Builder.CreateStore(Builder.getInt64(1), GV); 302 Builder.CreateRetVoid(); 303 Builder.SetInsertPoint(GEntry); 304 Builder.CreateStore(Builder.getInt64(1), GV); 305 Builder.CreateRetVoid(); 306 307 Solver.MarkBlockExecutable(FEntry); 308 Solver.MarkBlockExecutable(GEntry); 309 Solver.Solve(); 310 311 auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory); 312 EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant()); 313 } 314 315 /// Test that we propagate information through global variables. 316 /// 317 /// @gv = internal global i64 318 /// 319 /// define internal void @f() { 320 /// store i64 0, i64* @gv 321 /// ret void 322 /// } 323 /// 324 /// define internal void @g() { 325 /// store i64 1, i64* @gv 326 /// ret void 327 /// } 328 /// 329 /// For this test, we initially mark both "f" and "g" executable, and the 330 /// solver computes the lattice state of the global variable as overdefined. 331 TEST_F(SparsePropagationTest, GlobalVariableOverDefined) { 332 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 333 GlobalValue::InternalLinkage, "f", &M); 334 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 335 GlobalValue::InternalLinkage, "g", &M); 336 GlobalVariable *GV = 337 new GlobalVariable(M, Builder.getInt64Ty(), false, 338 GlobalValue::InternalLinkage, nullptr, "gv"); 339 BasicBlock *FEntry = BasicBlock::Create(Context, "", F); 340 BasicBlock *GEntry = BasicBlock::Create(Context, "", G); 341 Builder.SetInsertPoint(FEntry); 342 Builder.CreateStore(Builder.getInt64(0), GV); 343 Builder.CreateRetVoid(); 344 Builder.SetInsertPoint(GEntry); 345 Builder.CreateStore(Builder.getInt64(1), GV); 346 Builder.CreateRetVoid(); 347 348 Solver.MarkBlockExecutable(FEntry); 349 Solver.MarkBlockExecutable(GEntry); 350 Solver.Solve(); 351 352 auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory); 353 EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined()); 354 } 355 356 /// Test that we propagate information through function returns. 357 /// 358 /// define internal i64 @f(i1* %cond) { 359 /// if: 360 /// %0 = load i1, i1* %cond 361 /// br i1 %0, label %then, label %else 362 /// 363 /// then: 364 /// ret i64 1 365 /// 366 /// else: 367 /// ret i64 1 368 /// } 369 /// 370 /// For this test, we initially mark "f" executable, and the solver computes 371 /// the return value of the function as constant. 372 TEST_F(SparsePropagationTest, FunctionDefined) { 373 Function *F = 374 Function::Create(FunctionType::get(Builder.getInt64Ty(), 375 {Type::getInt1PtrTy(Context)}, false), 376 GlobalValue::InternalLinkage, "f", &M); 377 BasicBlock *If = BasicBlock::Create(Context, "if", F); 378 BasicBlock *Then = BasicBlock::Create(Context, "then", F); 379 BasicBlock *Else = BasicBlock::Create(Context, "else", F); 380 F->arg_begin()->setName("cond"); 381 Builder.SetInsertPoint(If); 382 LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), F->arg_begin()); 383 Builder.CreateCondBr(Cond, Then, Else); 384 Builder.SetInsertPoint(Then); 385 Builder.CreateRet(Builder.getInt64(1)); 386 Builder.SetInsertPoint(Else); 387 Builder.CreateRet(Builder.getInt64(1)); 388 389 Solver.MarkBlockExecutable(If); 390 Solver.Solve(); 391 392 auto RetF = TestLatticeKey(F, IPOGrouping::Return); 393 EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant()); 394 } 395 396 /// Test that we propagate information through function returns. 397 /// 398 /// define internal i64 @f(i1* %cond) { 399 /// if: 400 /// %0 = load i1, i1* %cond 401 /// br i1 %0, label %then, label %else 402 /// 403 /// then: 404 /// ret i64 0 405 /// 406 /// else: 407 /// ret i64 1 408 /// } 409 /// 410 /// For this test, we initially mark "f" executable, and the solver computes 411 /// the return value of the function as overdefined. 412 TEST_F(SparsePropagationTest, FunctionOverDefined) { 413 Function *F = 414 Function::Create(FunctionType::get(Builder.getInt64Ty(), 415 {Type::getInt1PtrTy(Context)}, false), 416 GlobalValue::InternalLinkage, "f", &M); 417 BasicBlock *If = BasicBlock::Create(Context, "if", F); 418 BasicBlock *Then = BasicBlock::Create(Context, "then", F); 419 BasicBlock *Else = BasicBlock::Create(Context, "else", F); 420 F->arg_begin()->setName("cond"); 421 Builder.SetInsertPoint(If); 422 LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), F->arg_begin()); 423 Builder.CreateCondBr(Cond, Then, Else); 424 Builder.SetInsertPoint(Then); 425 Builder.CreateRet(Builder.getInt64(0)); 426 Builder.SetInsertPoint(Else); 427 Builder.CreateRet(Builder.getInt64(1)); 428 429 Solver.MarkBlockExecutable(If); 430 Solver.Solve(); 431 432 auto RetF = TestLatticeKey(F, IPOGrouping::Return); 433 EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined()); 434 } 435 436 /// Test that we propagate information through arguments. 437 /// 438 /// define internal void @f() { 439 /// call void @g(i64 0, i64 1) 440 /// call void @g(i64 1, i64 1) 441 /// ret void 442 /// } 443 /// 444 /// define internal void @g(i64 %a, i64 %b) { 445 /// ret void 446 /// } 447 /// 448 /// For this test, we initially mark "f" executable, and the solver discovers 449 /// "g" because of the calls in "f". The solver computes the state of argument 450 /// "a" as overdefined and the state of "b" as constant. 451 /// 452 /// In addition, this test demonstrates that ComputeInstructionState can alter 453 /// the state of multiple lattice values, in addition to the one associated 454 /// with the instruction definition. Each call instruction in this test updates 455 /// the state of arguments "a" and "b". 456 TEST_F(SparsePropagationTest, ComputeInstructionState) { 457 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 458 GlobalValue::InternalLinkage, "f", &M); 459 Function *G = Function::Create( 460 FunctionType::get(Builder.getVoidTy(), 461 {Builder.getInt64Ty(), Builder.getInt64Ty()}, false), 462 GlobalValue::InternalLinkage, "g", &M); 463 Argument *A = G->arg_begin(); 464 Argument *B = std::next(G->arg_begin()); 465 A->setName("a"); 466 B->setName("b"); 467 BasicBlock *FEntry = BasicBlock::Create(Context, "", F); 468 BasicBlock *GEntry = BasicBlock::Create(Context, "", G); 469 Builder.SetInsertPoint(FEntry); 470 Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)}); 471 Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)}); 472 Builder.CreateRetVoid(); 473 Builder.SetInsertPoint(GEntry); 474 Builder.CreateRetVoid(); 475 476 Solver.MarkBlockExecutable(FEntry); 477 Solver.Solve(); 478 479 auto RegA = TestLatticeKey(A, IPOGrouping::Register); 480 auto RegB = TestLatticeKey(B, IPOGrouping::Register); 481 EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined()); 482 EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant()); 483 } 484 485 /// Test that we can handle exceptional terminator instructions. 486 /// 487 /// declare internal void @p() 488 /// 489 /// declare internal void @g() 490 /// 491 /// define internal void @f() personality i8* bitcast (void ()* @p to i8*) { 492 /// entry: 493 /// invoke void @g() 494 /// to label %exit unwind label %catch.pad 495 /// 496 /// catch.pad: 497 /// %0 = catchswitch within none [label %catch.body] unwind to caller 498 /// 499 /// catch.body: 500 /// %1 = catchpad within %0 [] 501 /// catchret from %1 to label %exit 502 /// 503 /// exit: 504 /// ret void 505 /// } 506 /// 507 /// For this test, we initially mark the entry block executable. The solver 508 /// then discovers the rest of the blocks in the function are executable. 509 TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) { 510 Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 511 GlobalValue::InternalLinkage, "p", &M); 512 Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 513 GlobalValue::InternalLinkage, "g", &M); 514 Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), 515 GlobalValue::InternalLinkage, "f", &M); 516 Constant *C = 517 ConstantExpr::getCast(Instruction::BitCast, P, Builder.getInt8PtrTy()); 518 F->setPersonalityFn(C); 519 BasicBlock *Entry = BasicBlock::Create(Context, "entry", F); 520 BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F); 521 BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F); 522 BasicBlock *Exit = BasicBlock::Create(Context, "exit", F); 523 Builder.SetInsertPoint(Entry); 524 Builder.CreateInvoke(G, Exit, Pad); 525 Builder.SetInsertPoint(Pad); 526 CatchSwitchInst *CatchSwitch = 527 Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1); 528 CatchSwitch->addHandler(Body); 529 Builder.SetInsertPoint(Body); 530 CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {}); 531 Builder.CreateCatchRet(CatchPad, Exit); 532 Builder.SetInsertPoint(Exit); 533 Builder.CreateRetVoid(); 534 535 Solver.MarkBlockExecutable(Entry); 536 Solver.Solve(); 537 538 EXPECT_TRUE(Solver.isBlockExecutable(Pad)); 539 EXPECT_TRUE(Solver.isBlockExecutable(Body)); 540 EXPECT_TRUE(Solver.isBlockExecutable(Exit)); 541 } 542