// Source: llvm-project/llvm/unittests/Transforms/Utils/FunctionComparatorTest.cpp (revision dd3184c30ff531b8aecea280e65233337dd02815)
//===- FunctionComparatorTest.cpp - Unit tests for FunctionComparator -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/FunctionComparator.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "gtest/gtest.h"

using namespace llvm;

18 /// Generates a simple test function.
19 struct TestFunction {
20   Function *F;
21   BasicBlock *BB;
22   Constant *C;
23   Instruction *I;
24   Type *T;
25 
TestFunctionTestFunction26   TestFunction(LLVMContext &Ctx, Module &M, int addVal) {
27     IRBuilder<> B(Ctx);
28     T = B.getInt8Ty();
29     F = Function::Create(FunctionType::get(T, {B.getPtrTy()}, false),
30                          GlobalValue::ExternalLinkage, "F", &M);
31     BB = BasicBlock::Create(Ctx, "", F);
32     B.SetInsertPoint(BB);
33     Argument *PointerArg = &*F->arg_begin();
34     LoadInst *LoadInst = B.CreateLoad(T, PointerArg);
35     C = B.getInt8(addVal);
36     I = cast<Instruction>(B.CreateAdd(LoadInst, C));
37     B.CreateRet(I);
38   }
39 };
40 
41 /// A class for testing the FunctionComparator API.
42 ///
43 /// The main purpose is to test if the required protected functions are
44 /// accessible from a derived class of FunctionComparator.
45 class TestComparator : public FunctionComparator {
46 public:
TestComparator(const Function * F1,const Function * F2,GlobalNumberState * GN)47   TestComparator(const Function *F1, const Function *F2,
48                  GlobalNumberState *GN)
49         : FunctionComparator(F1, F2, GN) {
50   }
51 
testFunctionAccess(const Function * F1,const Function * F2)52   bool testFunctionAccess(const Function *F1, const Function *F2) {
53     // Test if FnL and FnR are accessible.
54     return F1 == FnL && F2 == FnR;
55   }
56 
testCompare()57   int testCompare() {
58     return compare();
59   }
60 
testCompareSignature()61   int testCompareSignature() {
62     beginCompare();
63     return compareSignature();
64   }
65 
testCmpBasicBlocks(BasicBlock * BBL,BasicBlock * BBR)66   int testCmpBasicBlocks(BasicBlock *BBL, BasicBlock *BBR) {
67     beginCompare();
68     return cmpBasicBlocks(BBL, BBR);
69   }
70 
testCmpConstants(const Constant * L,const Constant * R)71   int testCmpConstants(const Constant *L, const Constant *R) {
72     beginCompare();
73     return cmpConstants(L, R);
74   }
75 
testCmpGlobalValues(GlobalValue * L,GlobalValue * R)76   int testCmpGlobalValues(GlobalValue *L, GlobalValue *R) {
77     beginCompare();
78     return cmpGlobalValues(L, R);
79   }
80 
testCmpValues(const Value * L,const Value * R)81   int testCmpValues(const Value *L, const Value *R) {
82     beginCompare();
83     return cmpValues(L, R);
84   }
85 
testCmpOperations(const Instruction * L,const Instruction * R,bool & needToCmpOperands)86   int testCmpOperations(const Instruction *L, const Instruction *R,
87                         bool &needToCmpOperands) {
88     beginCompare();
89     return cmpOperations(L, R, needToCmpOperands);
90   }
91 
testCmpTypes(Type * TyL,Type * TyR)92   int testCmpTypes(Type *TyL, Type *TyR) {
93     beginCompare();
94     return cmpTypes(TyL, TyR);
95   }
96 
testCmpPrimitives()97   int testCmpPrimitives() {
98     beginCompare();
99     return
100       cmpNumbers(2, 3) +
101       cmpAPInts(APInt(32, 2), APInt(32, 3)) +
102       cmpAPFloats(APFloat(2.0), APFloat(3.0)) +
103       cmpMem("2", "3");
104   }
105 };
106 
107 /// A sanity check for the FunctionComparator API.
TEST(FunctionComparatorTest,TestAPI)108 TEST(FunctionComparatorTest, TestAPI) {
109   LLVMContext C;
110   Module M("test", C);
111   TestFunction F1(C, M, 27);
112   TestFunction F2(C, M, 28);
113 
114   GlobalNumberState GN;
115   TestComparator Cmp(F1.F, F2.F, &GN);
116 
117   EXPECT_TRUE(Cmp.testFunctionAccess(F1.F, F2.F));
118   EXPECT_EQ(Cmp.testCompare(), -1);
119   EXPECT_EQ(Cmp.testCompareSignature(), 0);
120   EXPECT_EQ(Cmp.testCmpBasicBlocks(F1.BB, F2.BB), -1);
121   EXPECT_EQ(Cmp.testCmpConstants(F1.C, F2.C), -1);
122   EXPECT_EQ(Cmp.testCmpGlobalValues(F1.F, F2.F), -1);
123   EXPECT_EQ(Cmp.testCmpValues(F1.I, F2.I), 0);
124   bool needToCmpOperands = false;
125   EXPECT_EQ(Cmp.testCmpOperations(F1.I, F2.I, needToCmpOperands), 0);
126   EXPECT_TRUE(needToCmpOperands);
127   EXPECT_EQ(Cmp.testCmpTypes(F1.T, F2.T), 0);
128   EXPECT_EQ(Cmp.testCmpPrimitives(), -4);
129 }
130