1 #ifndef MAGIC_MMAP_CTL_FUNCTION_H 2 #define MAGIC_MMAP_CTL_FUNCTION_H 3 4 #include <pass.h> 5 #include <magic/support/TypeInfo.h> 6 7 using namespace llvm; 8 9 namespace llvm { 10 11 class MagicMmapCtlFunction { 12 public: 13 MagicMmapCtlFunction(Function *function, PointerType *voidPointerType, std::string &ptrArgName, std::string &lenArgName); 14 15 Function* getFunction() const; 16 void fixCalls(Module &M, Function *magicGetPageSizeFunc) const; 17 18 void print(raw_ostream &OS) const; 19 void printDescription(raw_ostream &OS) const; 20 const std::string getDescription() const; 21 22 private: 23 Function *function; 24 int ptrArg; 25 int lenArg; 26 }; 27 28 inline raw_ostream &operator<<(raw_ostream &OS, const MagicMmapCtlFunction &aMagicMmapCtlFunction) { 29 aMagicMmapCtlFunction.print(OS); 30 return OS; 31 } 32 33 inline void MagicMmapCtlFunction::print(raw_ostream &OS) const { 34 OS << getDescription(); 35 } 36 37 inline void MagicMmapCtlFunction::printDescription(raw_ostream &OS) const { 38 OS << "[ function = "; OS << function->getName() << "(" << TypeUtil::getDescription(function->getFunctionType()) << ")" 39 << ", ptr arg = "; OS << ptrArg 40 << ", len arg = "; OS << lenArg 41 << "]"; 42 } 43 44 inline const std::string MagicMmapCtlFunction::getDescription() const { 45 std::string string; 46 raw_string_ostream ostream(string); 47 printDescription(ostream); 48 ostream.flush(); 49 return string; 50 } 51 52 inline MagicMmapCtlFunction::MagicMmapCtlFunction(Function *function, PointerType *voidPointerType, std::string &ptrArgName, std::string &lenArgName) { 53 this->function = function; 54 this->ptrArg = -1; 55 this->lenArg = -1; 56 bool lookupPtrArg = ptrArgName.compare(""); 57 bool lookupLenArg = lenArgName.compare(""); 58 assert((lookupPtrArg || lookupLenArg) && "No valid argument name specified!"); 59 unsigned i=0; 60 for (Function::arg_iterator it = function->arg_begin(), E = function->arg_end(); 61 it != E; ++it) { 62 std::string argName = it->getName(); 63 if(lookupPtrArg && !argName.compare(ptrArgName)) { 64 this->ptrArg = i; 65 } 66 else if(lookupLenArg && !argName.compare(lenArgName)) { 67 this->lenArg = i; 68 } 69 i++; 70 } 71 if(this->ptrArg >= 0) { 72 assert(function->getFunctionType()->getContainedType(this->ptrArg+1) == voidPointerType && "Invalid ptr argument specified!"); 73 } 74 else { 75 assert(!lookupPtrArg && "Invalid ptr argument name specified!"); 76 } 77 if(this->lenArg >= 0) { 78 assert(isa<IntegerType>(function->getFunctionType()->getContainedType(this->lenArg+1)) && "Invalid len argument specified!"); 79 } 80 else { 81 assert(!lookupLenArg && "Invalid len argument name specified!"); 82 } 83 } 84 85 inline Function* MagicMmapCtlFunction::getFunction() const { 86 return function; 87 } 88 89 /* This assumes in-band metadata of 1 page before every mmapped region. */ 90 inline void MagicMmapCtlFunction::fixCalls(Module &M, Function *magicGetPageSizeFunc) const { 91 std::vector<User*> Users(function->user_begin(), function->user_end()); 92 while (!Users.empty()) { 93 User *U = Users.back(); 94 Users.pop_back(); 95 96 if (Instruction *I = dyn_cast<Instruction>(U)) { 97 Function *parent = I->getParent()->getParent(); 98 if(parent->getName().startswith("magic") || parent->getName().startswith("_magic")) { 99 continue; 100 } 101 CallSite CS = MagicUtil::getCallSiteFromInstruction(I); 102 103 std::vector<Value*> args; 104 CallInst* magicGetPageSizeCall = MagicUtil::createCallInstruction(magicGetPageSizeFunc, args, "", I); 105 magicGetPageSizeCall->setCallingConv(CallingConv::C); 106 magicGetPageSizeCall->setTailCall(false); 107 TYPECONST IntegerType *type = dyn_cast<IntegerType>(magicGetPageSizeCall->getType()); 108 assert(type); 109 110 if(this->ptrArg >= 0) { 111 Value *ptrValue = CS.getArgument(this->ptrArg); 112 BinaryOperator* negativePageSize = BinaryOperator::Create(Instruction::Sub, ConstantInt::get(M.getContext(), APInt(type->getBitWidth(), 0)), magicGetPageSizeCall, "", I); 113 GetElementPtrInst* ptrValueWithOffset = GetElementPtrInst::Create(ptrValue, negativePageSize, "", I); 114 115 CS.setArgument(this->ptrArg, ptrValueWithOffset); 116 } 117 if(this->lenArg >= 0) { 118 Value *lenValue = CS.getArgument(this->lenArg); 119 BinaryOperator* lenValuePlusPageSize = BinaryOperator::Create(Instruction::Add, lenValue, magicGetPageSizeCall, "", I); 120 121 CS.setArgument(this->lenArg, lenValuePlusPageSize); 122 } 123 } 124 } 125 } 126 127 } 128 129 #endif 130 131