xref: /minix3/minix/llvm/passes/include/magic/support/MagicMmapCtlFunction.h (revision d2532d3d42d764c9ef9816851cdb17eda7e08d36)
1 #ifndef MAGIC_MMAP_CTL_FUNCTION_H
2 #define MAGIC_MMAP_CTL_FUNCTION_H
3 
4 #include <pass.h>
5 #include <magic/support/TypeInfo.h>
6 
7 using namespace llvm;
8 
9 namespace llvm {
10 
11 class MagicMmapCtlFunction {
12   public:
13       MagicMmapCtlFunction(Function *function, PointerType *voidPointerType, std::string &ptrArgName, std::string &lenArgName);
14 
15       Function* getFunction() const;
16       void fixCalls(Module &M, Function *magicGetPageSizeFunc) const;
17 
18       void print(raw_ostream &OS) const;
19       void printDescription(raw_ostream &OS) const;
20       const std::string getDescription() const;
21 
22   private:
23       Function *function;
24       int ptrArg;
25       int lenArg;
26 };
27 
28 inline raw_ostream &operator<<(raw_ostream &OS, const MagicMmapCtlFunction &aMagicMmapCtlFunction) {
29     aMagicMmapCtlFunction.print(OS);
30     return OS;
31 }
32 
33 inline void MagicMmapCtlFunction::print(raw_ostream &OS) const {
34      OS << getDescription();
35 }
36 
37 inline void MagicMmapCtlFunction::printDescription(raw_ostream &OS) const {
38     OS << "[ function = "; OS << function->getName() << "(" << TypeUtil::getDescription(function->getFunctionType()) << ")"
39        << ", ptr arg = "; OS << ptrArg
40        << ", len arg = "; OS << lenArg
41        << "]";
42 }
43 
44 inline const std::string MagicMmapCtlFunction::getDescription() const {
45     std::string string;
46     raw_string_ostream ostream(string);
47     printDescription(ostream);
48     ostream.flush();
49     return string;
50 }
51 
52 inline MagicMmapCtlFunction::MagicMmapCtlFunction(Function *function, PointerType *voidPointerType, std::string &ptrArgName, std::string &lenArgName) {
53     this->function = function;
54     this->ptrArg = -1;
55     this->lenArg = -1;
56     bool lookupPtrArg = ptrArgName.compare("");
57     bool lookupLenArg = lenArgName.compare("");
58     assert((lookupPtrArg || lookupLenArg) && "No valid argument name specified!");
59     unsigned i=0;
60     for (Function::arg_iterator it = function->arg_begin(), E = function->arg_end();
61         it != E; ++it) {
62         std::string argName = it->getName();
63         if(lookupPtrArg && !argName.compare(ptrArgName)) {
64             this->ptrArg = i;
65         }
66         else if(lookupLenArg && !argName.compare(lenArgName)) {
67             this->lenArg = i;
68         }
69         i++;
70     }
71     if(this->ptrArg >= 0) {
72         assert(function->getFunctionType()->getContainedType(this->ptrArg+1) == voidPointerType && "Invalid ptr argument specified!");
73     }
74     else {
75         assert(!lookupPtrArg && "Invalid ptr argument name specified!");
76     }
77     if(this->lenArg >= 0) {
78         assert(isa<IntegerType>(function->getFunctionType()->getContainedType(this->lenArg+1)) && "Invalid len argument specified!");
79     }
80     else {
81         assert(!lookupLenArg && "Invalid len argument name specified!");
82     }
83 }
84 
85 inline Function* MagicMmapCtlFunction::getFunction() const {
86     return function;
87 }
88 
89 /* This assumes in-band metadata of 1 page before every mmapped region. */
90 inline void MagicMmapCtlFunction::fixCalls(Module &M, Function *magicGetPageSizeFunc) const {
91     std::vector<User*> Users(function->user_begin(), function->user_end());
92     while (!Users.empty()) {
93         User *U = Users.back();
94         Users.pop_back();
95 
96         if (Instruction *I = dyn_cast<Instruction>(U)) {
97             Function *parent = I->getParent()->getParent();
98             if(parent->getName().startswith("magic") || parent->getName().startswith("_magic")) {
99                 continue;
100             }
101             CallSite CS = MagicUtil::getCallSiteFromInstruction(I);
102 
103             std::vector<Value*> args;
104             CallInst* magicGetPageSizeCall = MagicUtil::createCallInstruction(magicGetPageSizeFunc, args, "", I);
105             magicGetPageSizeCall->setCallingConv(CallingConv::C);
106             magicGetPageSizeCall->setTailCall(false);
107             TYPECONST IntegerType *type = dyn_cast<IntegerType>(magicGetPageSizeCall->getType());
108             assert(type);
109 
110             if(this->ptrArg >= 0) {
111                 Value *ptrValue = CS.getArgument(this->ptrArg);
112                 BinaryOperator* negativePageSize = BinaryOperator::Create(Instruction::Sub, ConstantInt::get(M.getContext(), APInt(type->getBitWidth(), 0)), magicGetPageSizeCall, "", I);
113                 GetElementPtrInst* ptrValueWithOffset = GetElementPtrInst::Create(ptrValue, negativePageSize, "", I);
114 
115                 CS.setArgument(this->ptrArg, ptrValueWithOffset);
116             }
117             if(this->lenArg >= 0) {
118                 Value *lenValue = CS.getArgument(this->lenArg);
119                 BinaryOperator* lenValuePlusPageSize = BinaryOperator::Create(Instruction::Add, lenValue, magicGetPageSizeCall, "", I);
120 
121                 CS.setArgument(this->lenArg, lenValuePlusPageSize);
122             }
123         }
124     }
125 }
126 
127 }
128 
129 #endif
130 
131