xref: /llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp (revision beda7d517dfb06ea4a3523b907fe80afe438d499)
1 //===- CallSiteSplitting.cpp ----------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a transformation that tries to split a call-site to pass
11 // more constrained arguments if its argument is predicated in the control flow
// so that we can expose better context to the later passes (e.g., inliner,
// jump threading, or IPA-CP based function cloning, etc.).
14 // As of now we support two cases :
15 //
16 // 1) If a call site is dominated by an OR condition and if any of its arguments
17 // are predicated on this OR condition, try to split the condition with more
18 // constrained arguments. For example, in the code below, we try to split the
19 // call site since we can predicate the argument(ptr) based on the OR condition.
20 //
21 // Split from :
22 //   if (!ptr || c)
23 //     callee(ptr);
24 // to :
25 //   if (!ptr)
26 //     callee(null)         // set the known constant value
27 //   else if (c)
28 //     callee(nonnull ptr)  // set non-null attribute in the argument
29 //
// 2) We can also split a call-site based on constant incoming values of a PHI.
31 // For example,
32 // from :
33 //   Header:
34 //    %c = icmp eq i32 %i1, %i2
35 //    br i1 %c, label %Tail, label %TBB
36 //   TBB:
//    br label %Tail
38 //   Tail:
39 //    %p = phi i32 [ 0, %Header], [ 1, %TBB]
40 //    call void @bar(i32 %p)
41 // to
42 //   Header:
43 //    %c = icmp eq i32 %i1, %i2
44 //    br i1 %c, label %Tail-split0, label %TBB
45 //   TBB:
46 //    br label %Tail-split1
47 //   Tail-split0:
48 //    call void @bar(i32 0)
49 //    br label %Tail
50 //   Tail-split1:
51 //    call void @bar(i32 1)
52 //    br label %Tail
53 //   Tail:
54 //    %p = phi i32 [ 0, %Tail-split0 ], [ 1, %Tail-split1 ]
55 //
56 //===----------------------------------------------------------------------===//
57 
58 #include "llvm/Transforms/Scalar/CallSiteSplitting.h"
59 #include "llvm/ADT/Statistic.h"
60 #include "llvm/Analysis/TargetLibraryInfo.h"
61 #include "llvm/IR/IntrinsicInst.h"
62 #include "llvm/IR/PatternMatch.h"
63 #include "llvm/Support/Debug.h"
64 #include "llvm/Transforms/Scalar.h"
65 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
66 #include "llvm/Transforms/Utils/Local.h"
67 
68 using namespace llvm;
69 using namespace PatternMatch;
70 
71 #define DEBUG_TYPE "callsite-splitting"
72 
73 STATISTIC(NumCallSiteSplit, "Number of call-site split");
74 
75 static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI,
76                                 Value *Op) {
77   CallSite CS(NewCallI);
78   unsigned ArgNo = 0;
79   for (auto &I : CS.args()) {
80     if (&*I == Op)
81       CS.addParamAttr(ArgNo, Attribute::NonNull);
82     ++ArgNo;
83   }
84 }
85 
86 static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI,
87                                   Value *Op, Constant *ConstValue) {
88   CallSite CS(NewCallI);
89   unsigned ArgNo = 0;
90   for (auto &I : CS.args()) {
91     if (&*I == Op)
92       CS.setArgument(ArgNo, ConstValue);
93     ++ArgNo;
94   }
95 }
96 
97 static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
98   assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand.");
99   Value *Op0 = Cmp->getOperand(0);
100   unsigned ArgNo = 0;
101   for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
102        ++I, ++ArgNo) {
103     // Don't consider constant or arguments that are already known non-null.
104     if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull))
105       continue;
106 
107     if (*I == Op0)
108       return true;
109   }
110   return false;
111 }
112 
113 /// If From has a conditional jump to To, add the condition to Conditions,
114 /// if it is relevant to any argument at CS.
115 static void
116 recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To,
117                 SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
118   auto *BI = dyn_cast<BranchInst>(From->getTerminator());
119   if (!BI || !BI->isConditional())
120     return;
121 
122   CmpInst::Predicate Pred;
123   Value *Cond = BI->getCondition();
124   if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
125     return;
126 
127   ICmpInst *Cmp = cast<ICmpInst>(Cond);
128   if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
129     if (isCondRelevantToAnyCallArgument(Cmp, CS))
130       Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
131                                      ? Pred
132                                      : Cmp->getInversePredicate()});
133 }
134 
135 /// Record ICmp conditions relevant to any argument in CS following Pred's
136 /// single successors. If there are conflicting conditions along a path, like
137 /// x == 1 and x == 0, the first condition will be used.
138 static void
139 recordConditions(const CallSite &CS, BasicBlock *Pred,
140                  SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
141   recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions);
142   BasicBlock *From = Pred;
143   BasicBlock *To = Pred;
144   SmallPtrSet<BasicBlock *, 4> Visited = {From};
145   while (!Visited.count(From->getSinglePredecessor()) &&
146          (From = From->getSinglePredecessor())) {
147     recordCondition(CS, From, To, Conditions);
148     To = From;
149   }
150 }
151 
152 static Instruction *
153 addConditions(CallSite &CS,
154               SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
155   if (Conditions.empty())
156     return nullptr;
157 
158   Instruction *NewCI = CS.getInstruction()->clone();
159   for (auto &Cond : Conditions) {
160     Value *Arg = Cond.first->getOperand(0);
161     Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1));
162     if (Cond.second == ICmpInst::ICMP_EQ)
163       setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal);
164     else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
165       assert(Cond.second == ICmpInst::ICMP_NE);
166       addNonNullAttribute(CS.getInstruction(), NewCI, Arg);
167     }
168   }
169   return NewCI;
170 }
171 
172 static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
173   SmallVector<BasicBlock *, 2> Preds(predecessors((BB)));
174   assert(Preds.size() == 2 && "Expected exactly 2 predecessors!");
175   return Preds;
176 }
177 
178 static bool canSplitCallSite(CallSite CS) {
179   // FIXME: As of now we handle only CallInst. InvokeInst could be handled
180   // without too much effort.
181   Instruction *Instr = CS.getInstruction();
182   if (!isa<CallInst>(Instr))
183     return false;
184 
185   // Allow splitting a call-site only when there is no instruction before the
186   // call-site in the basic block. Based on this constraint, we only clone the
187   // call instruction, and we do not move a call-site across any other
188   // instruction.
189   BasicBlock *CallSiteBB = Instr->getParent();
190   if (Instr != CallSiteBB->getFirstNonPHIOrDbg())
191     return false;
192 
193   // Need 2 predecessors and cannot split an edge from an IndirectBrInst.
194   SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB));
195   if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) ||
196       isa<IndirectBrInst>(Preds[1]->getTerminator()))
197     return false;
198 
199   return CallSiteBB->canSplitPredecessors();
200 }
201 
202 /// Return true if the CS is split into its new predecessors which are directly
203 /// hooked to each of its orignial predecessors pointed by PredBB1 and PredBB2.
204 /// In OR predicated case, PredBB1 will point the header, and PredBB2 will point
205 /// to the second compare block. CallInst1 and CallInst2 will be the new
206 /// call-sites placed in the new predecessors split for PredBB1 and PredBB2,
207 /// repectively. Therefore, CallInst1 will be the call-site placed
208 /// between Header and Tail, and CallInst2 will be the call-site between TBB and
209 /// Tail. For example, in the IR below with an OR condition, the call-site can
210 /// be split
211 ///
212 /// from :
213 ///
214 ///   Header:
215 ///     %c = icmp eq i32* %a, null
216 ///     br i1 %c %Tail, %TBB
217 ///   TBB:
218 ///     %c2 = icmp eq i32* %b, null
219 ///     br i1 %c %Tail, %End
220 ///   Tail:
221 ///     %ca = call i1  @callee (i32* %a, i32* %b)
222 ///
223 ///  to :
224 ///
225 ///   Header:                          // PredBB1 is Header
226 ///     %c = icmp eq i32* %a, null
227 ///     br i1 %c %Tail-split1, %TBB
228 ///   TBB:                             // PredBB2 is TBB
229 ///     %c2 = icmp eq i32* %b, null
230 ///     br i1 %c %Tail-split2, %End
231 ///   Tail-split1:
232 ///     %ca1 = call @callee (i32* null, i32* %b)         // CallInst1
233 ///    br %Tail
234 ///   Tail-split2:
235 ///     %ca2 = call @callee (i32* nonnull %a, i32* null) // CallInst2
236 ///    br %Tail
237 ///   Tail:
238 ///    %p = phi i1 [%ca1, %Tail-split1],[%ca2, %Tail-split2]
239 ///
240 /// Note that for an OR predicated case, CallInst1 and CallInst2 should be
241 /// created with more constrained arguments in
242 /// createCallSitesOnOrPredicatedArgument().
243 static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
244                           Instruction *CallInst1, Instruction *CallInst2) {
245   Instruction *Instr = CS.getInstruction();
246   BasicBlock *TailBB = Instr->getParent();
247   assert(Instr == (TailBB->getFirstNonPHIOrDbg()) && "Unexpected call-site");
248 
249   BasicBlock *SplitBlock1 =
250       SplitBlockPredecessors(TailBB, PredBB1, ".predBB1.split");
251   BasicBlock *SplitBlock2 =
252       SplitBlockPredecessors(TailBB, PredBB2, ".predBB2.split");
253 
254   assert((SplitBlock1 && SplitBlock2) && "Unexpected new basic block split.");
255 
256   if (!CallInst1)
257     CallInst1 = Instr->clone();
258   if (!CallInst2)
259     CallInst2 = Instr->clone();
260 
261   CallInst1->insertBefore(&*SplitBlock1->getFirstInsertionPt());
262   CallInst2->insertBefore(&*SplitBlock2->getFirstInsertionPt());
263 
264   CallSite CS1(CallInst1);
265   CallSite CS2(CallInst2);
266 
267   // Handle PHIs used as arguments in the call-site.
268   for (auto &PI : *TailBB) {
269     PHINode *PN = dyn_cast<PHINode>(&PI);
270     if (!PN)
271       break;
272     unsigned ArgNo = 0;
273     for (auto &CI : CS.args()) {
274       if (&*CI == PN) {
275         CS1.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock1));
276         CS2.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock2));
277       }
278       ++ArgNo;
279     }
280   }
281 
282   // Replace users of the original call with a PHI mering call-sites split.
283   if (Instr->getNumUses()) {
284     PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call",
285                                   TailBB->getFirstNonPHI());
286     PN->addIncoming(CallInst1, SplitBlock1);
287     PN->addIncoming(CallInst2, SplitBlock2);
288     Instr->replaceAllUsesWith(PN);
289   }
290   DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
291   DEBUG(dbgs() << "    " << *CallInst1 << " in " << SplitBlock1->getName()
292                << "\n");
293   DEBUG(dbgs() << "    " << *CallInst2 << " in " << SplitBlock2->getName()
294                << "\n");
295   Instr->eraseFromParent();
296   NumCallSiteSplit++;
297 }
298 
299 // Return true if the call-site has an argument which is a PHI with only
300 // constant incoming values.
301 static bool isPredicatedOnPHI(CallSite CS) {
302   Instruction *Instr = CS.getInstruction();
303   BasicBlock *Parent = Instr->getParent();
304   if (Instr != Parent->getFirstNonPHIOrDbg())
305     return false;
306 
307   for (auto &BI : *Parent) {
308     if (PHINode *PN = dyn_cast<PHINode>(&BI)) {
309       for (auto &I : CS.args())
310         if (&*I == PN) {
311           assert(PN->getNumIncomingValues() == 2 &&
312                  "Unexpected number of incoming values");
313           if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1))
314             return false;
315           if (PN->getIncomingValue(0) == PN->getIncomingValue(1))
316             continue;
317           if (isa<Constant>(PN->getIncomingValue(0)) &&
318               isa<Constant>(PN->getIncomingValue(1)))
319             return true;
320         }
321     }
322     break;
323   }
324   return false;
325 }
326 
327 static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) {
328   if (!isPredicatedOnPHI(CS))
329     return false;
330 
331   auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
332   splitCallSite(CS, Preds[0], Preds[1], nullptr, nullptr);
333   return true;
334 }
335 // Check if one of the predecessors is a single predecessors of the other.
336 // This is a requirement for control flow modeling an OR. HeaderBB points to
337 // the single predecessor and OrBB points to other node. HeaderBB potentially
338 // contains the first compare of the OR and OrBB the second.
339 static bool isOrHeader(BasicBlock *HeaderBB, BasicBlock *OrBB) {
340   return OrBB->getSinglePredecessor() == HeaderBB &&
341          HeaderBB->getTerminator()->getNumSuccessors() == 2;
342 }
343 
344 static bool tryToSplitOnOrPredicatedArgument(CallSite CS) {
345   auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
346   if (!isOrHeader(Preds[0], Preds[1]) && !isOrHeader(Preds[1], Preds[0]))
347     return false;
348 
349   SmallVector<std::pair<ICmpInst *, unsigned>, 2> C1, C2;
350   recordConditions(CS, Preds[0], C1);
351   recordConditions(CS, Preds[1], C2);
352 
353   Instruction *CallInst1 = addConditions(CS, C1);
354   Instruction *CallInst2 = addConditions(CS, C2);
355   if (!CallInst1 && !CallInst2)
356     return false;
357 
358   splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1);
359   return true;
360 }
361 
362 static bool tryToSplitCallSite(CallSite CS) {
363   if (!CS.arg_size() || !canSplitCallSite(CS))
364     return false;
365   return tryToSplitOnOrPredicatedArgument(CS) ||
366          tryToSplitOnPHIPredicatedArgument(CS);
367 }
368 
369 static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) {
370   bool Changed = false;
371   for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) {
372     BasicBlock &BB = *BI++;
373     for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) {
374       Instruction *I = &*II++;
375       CallSite CS(cast<Value>(I));
376       if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI))
377         continue;
378 
379       Function *Callee = CS.getCalledFunction();
380       if (!Callee || Callee->isDeclaration())
381         continue;
382       Changed |= tryToSplitCallSite(CS);
383     }
384   }
385   return Changed;
386 }
387 
388 namespace {
389 struct CallSiteSplittingLegacyPass : public FunctionPass {
390   static char ID;
391   CallSiteSplittingLegacyPass() : FunctionPass(ID) {
392     initializeCallSiteSplittingLegacyPassPass(*PassRegistry::getPassRegistry());
393   }
394 
395   void getAnalysisUsage(AnalysisUsage &AU) const override {
396     AU.addRequired<TargetLibraryInfoWrapperPass>();
397     FunctionPass::getAnalysisUsage(AU);
398   }
399 
400   bool runOnFunction(Function &F) override {
401     if (skipFunction(F))
402       return false;
403 
404     auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
405     return doCallSiteSplitting(F, TLI);
406   }
407 };
408 } // namespace
409 
410 char CallSiteSplittingLegacyPass::ID = 0;
411 INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting",
412                       "Call-site splitting", false, false)
413 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
414 INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting",
415                     "Call-site splitting", false, false)
416 FunctionPass *llvm::createCallSiteSplittingPass() {
417   return new CallSiteSplittingLegacyPass();
418 }
419 
420 PreservedAnalyses CallSiteSplittingPass::run(Function &F,
421                                              FunctionAnalysisManager &AM) {
422   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
423 
424   if (!doCallSiteSplitting(F, TLI))
425     return PreservedAnalyses::all();
426   PreservedAnalyses PA;
427   return PA;
428 }
429