xref: /llvm-project/llvm/tools/llvm-reduce/deltas/ReduceOperandsToArgs.cpp (revision dd71b65ca85dd2f80aca3ad0f0235180580f5824)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ReduceOperandsToArgs.h"
10 #include "Delta.h"
11 #include "llvm/ADT/Sequence.h"
12 #include "llvm/IR/InstIterator.h"
13 #include "llvm/IR/InstrTypes.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
16 #include "llvm/Transforms/Utils/Cloning.h"
17 
18 using namespace llvm;
19 
20 static bool canReplaceFunction(Function *F) {
21   return all_of(F->uses(), [](Use &Op) {
22     if (auto *CI = dyn_cast<CallBase>(Op.getUser()))
23       return &CI->getCalledOperandUse() == &Op;
24     return false;
25   });
26 }
27 
28 static bool canReduceUse(Use &Op) {
29   Value *Val = Op.get();
30   Type *Ty = Val->getType();
31 
32   // Only replace operands that can be passed-by-value.
33   if (!Ty->isFirstClassType())
34     return false;
35 
36   // Don't pass labels as arguments.
37   if (Ty->isLabelTy())
38     return false;
39 
40   // No need to replace values that are already arguments.
41   if (isa<Argument>(Val))
42     return false;
43 
44   // Do not replace literals.
45   if (isa<ConstantData>(Val))
46     return false;
47 
48   // Do not convert direct function calls to indirect calls.
49   if (auto *CI = dyn_cast<CallBase>(Op.getUser()))
50     if (&CI->getCalledOperandUse() == &Op)
51       return false;
52 
53   return true;
54 }
55 
56 /// Goes over OldF calls and replaces them with a call to NewF.
57 static void replaceFunctionCalls(Function *OldF, Function *NewF) {
58   SmallVector<CallBase *> Callers;
59   for (Use &U : OldF->uses()) {
60     auto *CI = cast<CallBase>(U.getUser());
61     assert(&U == &CI->getCalledOperandUse());
62     assert(CI->getCalledFunction() == OldF);
63     Callers.push_back(CI);
64   }
65 
66   // Call arguments for NewF.
67   SmallVector<Value *> Args(NewF->arg_size(), nullptr);
68 
69   // Fill up the additional parameters with undef values.
70   for (auto ArgIdx : llvm::seq<size_t>(OldF->arg_size(), NewF->arg_size())) {
71     Type *NewArgTy = NewF->getArg(ArgIdx)->getType();
72     Args[ArgIdx] = UndefValue::get(NewArgTy);
73   }
74 
75   for (CallBase *CI : Callers) {
76     // Preserve the original function arguments.
77     for (auto Z : zip_first(CI->args(), Args))
78       std::get<1>(Z) = std::get<0>(Z);
79 
80     // Also preserve operand bundles.
81     SmallVector<OperandBundleDef> OperandBundles;
82     CI->getOperandBundlesAsDefs(OperandBundles);
83 
84     // Create the new function call.
85     CallBase *NewCI;
86     if (auto *II = dyn_cast<InvokeInst>(CI)) {
87       NewCI = InvokeInst::Create(NewF, cast<InvokeInst>(II)->getNormalDest(),
88                                  cast<InvokeInst>(II)->getUnwindDest(), Args,
89                                  OperandBundles, CI->getName());
90     } else {
91       assert(isa<CallInst>(CI));
92       NewCI = CallInst::Create(NewF, Args, OperandBundles, CI->getName());
93     }
94     NewCI->setCallingConv(NewF->getCallingConv());
95 
96     // Do the replacement for this use.
97     if (!CI->use_empty())
98       CI->replaceAllUsesWith(NewCI);
99     ReplaceInstWithInst(CI, NewCI);
100   }
101 }
102 
103 /// Add a new function argument to @p F for each use in @OpsToReplace, and
104 /// replace those operand values with the new function argument.
105 static void substituteOperandWithArgument(Function *OldF,
106                                           ArrayRef<Use *> OpsToReplace) {
107   if (OpsToReplace.empty())
108     return;
109 
110   SetVector<Value *> UniqueValues;
111   for (Use *Op : OpsToReplace)
112     UniqueValues.insert(Op->get());
113 
114   // Determine the new function's signature.
115   SmallVector<Type *> NewArgTypes;
116   llvm::append_range(NewArgTypes, OldF->getFunctionType()->params());
117   size_t ArgOffset = NewArgTypes.size();
118   for (Value *V : UniqueValues)
119     NewArgTypes.push_back(V->getType());
120   FunctionType *FTy =
121       FunctionType::get(OldF->getFunctionType()->getReturnType(), NewArgTypes,
122                         OldF->getFunctionType()->isVarArg());
123 
124   // Create the new function...
125   Function *NewF =
126       Function::Create(FTy, OldF->getLinkage(), OldF->getAddressSpace(),
127                        OldF->getName(), OldF->getParent());
128 
129   // In order to preserve function order, we move NewF behind OldF
130   NewF->removeFromParent();
131   OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF);
132 
133   // Preserve the parameters of OldF.
134   ValueToValueMapTy VMap;
135   for (auto Z : zip_first(OldF->args(), NewF->args())) {
136     Argument &OldArg = std::get<0>(Z);
137     Argument &NewArg = std::get<1>(Z);
138 
139     NewArg.setName(OldArg.getName()); // Copy the name over...
140     VMap[&OldArg] = &NewArg;          // Add mapping to VMap
141   }
142 
143   // Adjust the new parameters.
144   ValueToValueMapTy OldValMap;
145   for (auto Z : zip_first(UniqueValues, drop_begin(NewF->args(), ArgOffset))) {
146     Value *OldVal = std::get<0>(Z);
147     Argument &NewArg = std::get<1>(Z);
148 
149     NewArg.setName(OldVal->getName());
150     OldValMap[OldVal] = &NewArg;
151   }
152 
153   SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
154   CloneFunctionInto(NewF, OldF, VMap, CloneFunctionChangeType::LocalChangesOnly,
155                     Returns, "", /*CodeInfo=*/nullptr);
156 
157   // Replace the actual operands.
158   for (Use *Op : OpsToReplace) {
159     Value *NewArg = OldValMap.lookup(Op->get());
160     auto *NewUser = cast<Instruction>(VMap.lookup(Op->getUser()));
161     NewUser->setOperand(Op->getOperandNo(), NewArg);
162   }
163 
164   // Replace all OldF uses with NewF.
165   replaceFunctionCalls(OldF, NewF);
166 
167   // Rename NewF to OldF's name.
168   std::string FName = OldF->getName().str();
169   OldF->replaceAllUsesWith(ConstantExpr::getBitCast(NewF, OldF->getType()));
170   OldF->eraseFromParent();
171   NewF->setName(FName);
172 }
173 
174 static void reduceOperandsToArgs(Oracle &O, Module &Program) {
175   SmallVector<Use *> OperandsToReduce;
176   for (Function &F : make_early_inc_range(Program.functions())) {
177     OperandsToReduce.clear();
178     for (Instruction &I : instructions(&F)) {
179       for (Use &Op : I.operands()) {
180         if (!canReduceUse(Op))
181           continue;
182         if (O.shouldKeep())
183           continue;
184 
185         OperandsToReduce.push_back(&Op);
186       }
187     }
188 
189     substituteOperandWithArgument(&F, OperandsToReduce);
190   }
191 }
192 
193 /// Counts the amount of operands in the module that can be reduced.
194 static int countOperands(Module &Program) {
195   int Count = 0;
196 
197   for (Function &F : Program.functions()) {
198     if (!canReplaceFunction(&F))
199       continue;
200     for (Instruction &I : instructions(&F)) {
201       for (Use &Op : I.operands()) {
202         if (!canReduceUse(Op))
203           continue;
204         Count += 1;
205       }
206     }
207   }
208 
209   return Count;
210 }
211 
212 void llvm::reduceOperandsToArgsDeltaPass(TestRunner &Test) {
213   outs() << "*** Converting operands to function arguments ...\n";
214   int ArgCount = countOperands(Test.getProgram());
215   return runDeltaPass(Test, ArgCount, reduceOperandsToArgs);
216 }
217