1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "ReduceOperandsToArgs.h" 10 #include "Delta.h" 11 #include "llvm/ADT/Sequence.h" 12 #include "llvm/IR/InstIterator.h" 13 #include "llvm/IR/InstrTypes.h" 14 #include "llvm/IR/Instructions.h" 15 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 16 #include "llvm/Transforms/Utils/Cloning.h" 17 18 using namespace llvm; 19 20 static bool canReplaceFunction(Function *F) { 21 return all_of(F->uses(), [](Use &Op) { 22 if (auto *CI = dyn_cast<CallBase>(Op.getUser())) 23 return &CI->getCalledOperandUse() == &Op; 24 return false; 25 }); 26 } 27 28 static bool canReduceUse(Use &Op) { 29 Value *Val = Op.get(); 30 Type *Ty = Val->getType(); 31 32 // Only replace operands that can be passed-by-value. 33 if (!Ty->isFirstClassType()) 34 return false; 35 36 // Don't pass labels as arguments. 37 if (Ty->isLabelTy()) 38 return false; 39 40 // No need to replace values that are already arguments. 41 if (isa<Argument>(Val)) 42 return false; 43 44 // Do not replace literals. 45 if (isa<ConstantData>(Val)) 46 return false; 47 48 // Do not convert direct function calls to indirect calls. 49 if (auto *CI = dyn_cast<CallBase>(Op.getUser())) 50 if (&CI->getCalledOperandUse() == &Op) 51 return false; 52 53 return true; 54 } 55 56 /// Goes over OldF calls and replaces them with a call to NewF. 57 static void replaceFunctionCalls(Function *OldF, Function *NewF) { 58 SmallVector<CallBase *> Callers; 59 for (Use &U : OldF->uses()) { 60 auto *CI = cast<CallBase>(U.getUser()); 61 assert(&U == &CI->getCalledOperandUse()); 62 assert(CI->getCalledFunction() == OldF); 63 Callers.push_back(CI); 64 } 65 66 // Call arguments for NewF. 67 SmallVector<Value *> Args(NewF->arg_size(), nullptr); 68 69 // Fill up the additional parameters with undef values. 70 for (auto ArgIdx : llvm::seq<size_t>(OldF->arg_size(), NewF->arg_size())) { 71 Type *NewArgTy = NewF->getArg(ArgIdx)->getType(); 72 Args[ArgIdx] = UndefValue::get(NewArgTy); 73 } 74 75 for (CallBase *CI : Callers) { 76 // Preserve the original function arguments. 77 for (auto Z : zip_first(CI->args(), Args)) 78 std::get<1>(Z) = std::get<0>(Z); 79 80 // Also preserve operand bundles. 81 SmallVector<OperandBundleDef> OperandBundles; 82 CI->getOperandBundlesAsDefs(OperandBundles); 83 84 // Create the new function call. 85 CallBase *NewCI; 86 if (auto *II = dyn_cast<InvokeInst>(CI)) { 87 NewCI = InvokeInst::Create(NewF, cast<InvokeInst>(II)->getNormalDest(), 88 cast<InvokeInst>(II)->getUnwindDest(), Args, 89 OperandBundles, CI->getName()); 90 } else { 91 assert(isa<CallInst>(CI)); 92 NewCI = CallInst::Create(NewF, Args, OperandBundles, CI->getName()); 93 } 94 NewCI->setCallingConv(NewF->getCallingConv()); 95 96 // Do the replacement for this use. 97 if (!CI->use_empty()) 98 CI->replaceAllUsesWith(NewCI); 99 ReplaceInstWithInst(CI, NewCI); 100 } 101 } 102 103 /// Add a new function argument to @p F for each use in @OpsToReplace, and 104 /// replace those operand values with the new function argument. 105 static void substituteOperandWithArgument(Function *OldF, 106 ArrayRef<Use *> OpsToReplace) { 107 if (OpsToReplace.empty()) 108 return; 109 110 SetVector<Value *> UniqueValues; 111 for (Use *Op : OpsToReplace) 112 UniqueValues.insert(Op->get()); 113 114 // Determine the new function's signature. 115 SmallVector<Type *> NewArgTypes; 116 llvm::append_range(NewArgTypes, OldF->getFunctionType()->params()); 117 size_t ArgOffset = NewArgTypes.size(); 118 for (Value *V : UniqueValues) 119 NewArgTypes.push_back(V->getType()); 120 FunctionType *FTy = 121 FunctionType::get(OldF->getFunctionType()->getReturnType(), NewArgTypes, 122 OldF->getFunctionType()->isVarArg()); 123 124 // Create the new function... 125 Function *NewF = 126 Function::Create(FTy, OldF->getLinkage(), OldF->getAddressSpace(), 127 OldF->getName(), OldF->getParent()); 128 129 // In order to preserve function order, we move NewF behind OldF 130 NewF->removeFromParent(); 131 OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF); 132 133 // Preserve the parameters of OldF. 134 ValueToValueMapTy VMap; 135 for (auto Z : zip_first(OldF->args(), NewF->args())) { 136 Argument &OldArg = std::get<0>(Z); 137 Argument &NewArg = std::get<1>(Z); 138 139 NewArg.setName(OldArg.getName()); // Copy the name over... 140 VMap[&OldArg] = &NewArg; // Add mapping to VMap 141 } 142 143 // Adjust the new parameters. 144 ValueToValueMapTy OldValMap; 145 for (auto Z : zip_first(UniqueValues, drop_begin(NewF->args(), ArgOffset))) { 146 Value *OldVal = std::get<0>(Z); 147 Argument &NewArg = std::get<1>(Z); 148 149 NewArg.setName(OldVal->getName()); 150 OldValMap[OldVal] = &NewArg; 151 } 152 153 SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned. 154 CloneFunctionInto(NewF, OldF, VMap, CloneFunctionChangeType::LocalChangesOnly, 155 Returns, "", /*CodeInfo=*/nullptr); 156 157 // Replace the actual operands. 158 for (Use *Op : OpsToReplace) { 159 Value *NewArg = OldValMap.lookup(Op->get()); 160 auto *NewUser = cast<Instruction>(VMap.lookup(Op->getUser())); 161 NewUser->setOperand(Op->getOperandNo(), NewArg); 162 } 163 164 // Replace all OldF uses with NewF. 165 replaceFunctionCalls(OldF, NewF); 166 167 // Rename NewF to OldF's name. 168 std::string FName = OldF->getName().str(); 169 OldF->replaceAllUsesWith(ConstantExpr::getBitCast(NewF, OldF->getType())); 170 OldF->eraseFromParent(); 171 NewF->setName(FName); 172 } 173 174 static void reduceOperandsToArgs(Oracle &O, Module &Program) { 175 SmallVector<Use *> OperandsToReduce; 176 for (Function &F : make_early_inc_range(Program.functions())) { 177 OperandsToReduce.clear(); 178 for (Instruction &I : instructions(&F)) { 179 for (Use &Op : I.operands()) { 180 if (!canReduceUse(Op)) 181 continue; 182 if (O.shouldKeep()) 183 continue; 184 185 OperandsToReduce.push_back(&Op); 186 } 187 } 188 189 substituteOperandWithArgument(&F, OperandsToReduce); 190 } 191 } 192 193 /// Counts the amount of operands in the module that can be reduced. 194 static int countOperands(Module &Program) { 195 int Count = 0; 196 197 for (Function &F : Program.functions()) { 198 if (!canReplaceFunction(&F)) 199 continue; 200 for (Instruction &I : instructions(&F)) { 201 for (Use &Op : I.operands()) { 202 if (!canReduceUse(Op)) 203 continue; 204 Count += 1; 205 } 206 } 207 } 208 209 return Count; 210 } 211 212 void llvm::reduceOperandsToArgsDeltaPass(TestRunner &Test) { 213 outs() << "*** Converting operands to function arguments ...\n"; 214 int ArgCount = countOperands(Test.getProgram()); 215 return runDeltaPass(Test, ArgCount, reduceOperandsToArgs); 216 } 217