xref: /llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp (revision ca268ed28520cbe05e5a9f006cb7b615301a4aa1)
1 //===-- ConstraintElimination.cpp - Eliminate conds using constraints. ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Eliminate conditions based on constraints collected from dominating
10 // conditions.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "llvm/Transforms/Scalar/ConstraintElimination.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstraintSystem.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Transforms/Scalar.h"

#include <cstdint>
#include <limits>
#include <string>
#include <tuple>
#include <utility>
33 
34 using namespace llvm;
35 using namespace PatternMatch;
36 
37 #define DEBUG_TYPE "constraint-elimination"
38 
39 STATISTIC(NumCondsRemoved, "Number of instructions removed");
40 DEBUG_COUNTER(EliminatedCounter, "conds-eliminated",
41               "Controls which conditions are eliminated");
42 
43 static int64_t MaxConstraintValue = std::numeric_limits<int64_t>::max();
44 
45 // Decomposes \p V into a vector of pairs of the form { c, X } where c * X. The
46 // sum of the pairs equals \p V.  The first pair is the constant-factor and X
47 // must be nullptr. If the expression cannot be decomposed, returns an empty
48 // vector.
49 static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) {
50   if (auto *CI = dyn_cast<ConstantInt>(V)) {
51     if (CI->isNegative() || CI->uge(MaxConstraintValue))
52       return {};
53     return {{CI->getSExtValue(), nullptr}};
54   }
55   auto *GEP = dyn_cast<GetElementPtrInst>(V);
56   if (GEP && GEP->getNumOperands() == 2 && GEP->isInBounds()) {
57     if (isa<ConstantInt>(GEP->getOperand(GEP->getNumOperands() - 1))) {
58       return {{cast<ConstantInt>(GEP->getOperand(GEP->getNumOperands() - 1))
59                    ->getSExtValue(),
60                nullptr},
61               {1, GEP->getPointerOperand()}};
62     }
63     Value *Op0;
64     ConstantInt *CI;
65     if (match(GEP->getOperand(GEP->getNumOperands() - 1),
66               m_NUWShl(m_Value(Op0), m_ConstantInt(CI))))
67       return {{0, nullptr},
68               {1, GEP->getPointerOperand()},
69               {std::pow(int64_t(2), CI->getSExtValue()), Op0}};
70     if (match(GEP->getOperand(GEP->getNumOperands() - 1),
71               m_ZExt(m_NUWShl(m_Value(Op0), m_ConstantInt(CI)))))
72       return {{0, nullptr},
73               {1, GEP->getPointerOperand()},
74               {std::pow(int64_t(2), CI->getSExtValue()), Op0}};
75 
76     return {{0, nullptr},
77             {1, GEP->getPointerOperand()},
78             {1, GEP->getOperand(GEP->getNumOperands() - 1)}};
79   }
80 
81   Value *Op0;
82   if (match(V, m_ZExt(m_Value(Op0))))
83     V = Op0;
84 
85   Value *Op1;
86   ConstantInt *CI;
87   if (match(V, m_NUWAdd(m_Value(Op0), m_ConstantInt(CI))))
88     return {{CI->getSExtValue(), nullptr}, {1, Op0}};
89   if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1))))
90     return {{0, nullptr}, {1, Op0}, {1, Op1}};
91 
92   if (match(V, m_NUWSub(m_Value(Op0), m_ConstantInt(CI))))
93     return {{-1 * CI->getSExtValue(), nullptr}, {1, Op0}};
94   if (match(V, m_NUWSub(m_Value(Op0), m_Value(Op1))))
95     return {{0, nullptr}, {1, Op0}, {1, Op1}};
96 
97   return {{0, nullptr}, {1, V}};
98 }
99 
100 struct ConstraintTy {
101   SmallVector<int64_t, 8> Coefficients;
102 
103   ConstraintTy(SmallVector<int64_t, 8> Coefficients)
104       : Coefficients(Coefficients) {}
105 
106   unsigned size() const { return Coefficients.size(); }
107 };
108 
109 /// Turn a condition \p CmpI into a constraint vector, using indices from \p
110 /// Value2Index. If \p ShouldAdd is true, new indices are added for values not
111 /// yet in \p Value2Index.
112 static SmallVector<ConstraintTy, 4>
113 getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
114               DenseMap<Value *, unsigned> &Value2Index, bool ShouldAdd) {
115   int64_t Offset1 = 0;
116   int64_t Offset2 = 0;
117 
118   auto TryToGetIndex = [ShouldAdd,
119                         &Value2Index](Value *V) -> Optional<unsigned> {
120     if (ShouldAdd) {
121       Value2Index.insert({V, Value2Index.size() + 1});
122       return Value2Index[V];
123     }
124     auto I = Value2Index.find(V);
125     if (I == Value2Index.end())
126       return None;
127     return I->second;
128   };
129 
130   if (Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_UGE)
131     return getConstraint(CmpInst::getSwappedPredicate(Pred), Op1, Op0,
132                          Value2Index, ShouldAdd);
133 
134   if (Pred == CmpInst::ICMP_EQ) {
135     auto A = getConstraint(CmpInst::ICMP_UGE, Op0, Op1, Value2Index, ShouldAdd);
136     auto B = getConstraint(CmpInst::ICMP_ULE, Op0, Op1, Value2Index, ShouldAdd);
137     append_range(A, B);
138     return A;
139   }
140 
141   // Only ULE and ULT predicates are supported at the moment.
142   if (Pred != CmpInst::ICMP_ULE && Pred != CmpInst::ICMP_ULT)
143     return {};
144 
145   auto ADec = decompose(Op0->stripPointerCasts());
146   auto BDec = decompose(Op1->stripPointerCasts());
147   // Skip if decomposing either of the values failed.
148   if (ADec.empty() || BDec.empty())
149     return {};
150 
151   // Skip trivial constraints without any variables.
152   if (ADec.size() == 1 && BDec.size() == 1)
153     return {};
154 
155   Offset1 = ADec[0].first;
156   Offset2 = BDec[0].first;
157   Offset1 *= -1;
158 
159   // Create iterator ranges that skip the constant-factor.
160   auto VariablesA = make_range(std::next(ADec.begin()), ADec.end());
161   auto VariablesB = make_range(std::next(BDec.begin()), BDec.end());
162 
163   // Check if each referenced value in the constraint is already in the system
164   // or can be added (if ShouldAdd is true).
165   for (const auto &KV :
166        concat<std::pair<int64_t, Value *>>(VariablesA, VariablesB))
167     if (!TryToGetIndex(KV.second))
168       return {};
169 
170   // Build result constraint, by first adding all coefficients from A and then
171   // subtracting all coefficients from B.
172   SmallVector<int64_t, 8> R(Value2Index.size() + 1, 0);
173   for (const auto &KV : VariablesA)
174     R[Value2Index[KV.second]] += KV.first;
175 
176   for (const auto &KV : VariablesB)
177     R[Value2Index[KV.second]] -= KV.first;
178 
179   R[0] = Offset1 + Offset2 + (Pred == CmpInst::ICMP_ULT ? -1 : 0);
180   return {R};
181 }
182 
183 static SmallVector<ConstraintTy, 4>
184 getConstraint(CmpInst *Cmp, DenseMap<Value *, unsigned> &Value2Index,
185               bool ShouldAdd) {
186   return getConstraint(Cmp->getPredicate(), Cmp->getOperand(0),
187                        Cmp->getOperand(1), Value2Index, ShouldAdd);
188 }
189 
190 namespace {
191 /// Represents either a condition that holds on entry to a block or a basic
192 /// block, with their respective Dominator DFS in and out numbers.
193 struct ConstraintOrBlock {
194   unsigned NumIn;
195   unsigned NumOut;
196   bool IsBlock;
197   bool Not;
198   union {
199     BasicBlock *BB;
200     CmpInst *Condition;
201   };
202 
203   ConstraintOrBlock(DomTreeNode *DTN)
204       : NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), IsBlock(true),
205         BB(DTN->getBlock()) {}
206   ConstraintOrBlock(DomTreeNode *DTN, CmpInst *Condition, bool Not)
207       : NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), IsBlock(false),
208         Not(Not), Condition(Condition) {}
209 };
210 
211 struct StackEntry {
212   unsigned NumIn;
213   unsigned NumOut;
214   CmpInst *Condition;
215   bool IsNot;
216 
217   StackEntry(unsigned NumIn, unsigned NumOut, CmpInst *Condition, bool IsNot)
218       : NumIn(NumIn), NumOut(NumOut), Condition(Condition), IsNot(IsNot) {}
219 };
220 } // namespace
221 
222 #ifndef NDEBUG
223 static void dumpWithNames(ConstraintTy &C,
224                           DenseMap<Value *, unsigned> &Value2Index) {
225   SmallVector<std::string> Names(Value2Index.size(), "");
226   for (auto &KV : Value2Index) {
227     Names[KV.second - 1] = std::string("%") + KV.first->getName().str();
228   }
229   ConstraintSystem CS;
230   CS.addVariableRowFill(C.Coefficients);
231   CS.dump(Names);
232 }
233 #endif
234 
235 static bool eliminateConstraints(Function &F, DominatorTree &DT) {
236   bool Changed = false;
237   DT.updateDFSNumbers();
238   ConstraintSystem CS;
239 
240   SmallVector<ConstraintOrBlock, 64> WorkList;
241 
242   // First, collect conditions implied by branches and blocks with their
243   // Dominator DFS in and out numbers.
244   for (BasicBlock &BB : F) {
245     if (!DT.getNode(&BB))
246       continue;
247     WorkList.emplace_back(DT.getNode(&BB));
248 
249     auto *Br = dyn_cast<BranchInst>(BB.getTerminator());
250     if (!Br || !Br->isConditional())
251       continue;
252 
253     // Returns true if we can add a known condition from BB to its successor
254     // block Succ. Each predecessor of Succ can either be BB or be dominated by
255     // Succ (e.g. the case when adding a condition from a pre-header to a loop
256     // header).
257     auto CanAdd = [&BB, &DT](BasicBlock *Succ) {
258       return all_of(predecessors(Succ), [&BB, &DT, Succ](BasicBlock *Pred) {
259         return Pred == &BB || DT.dominates(Succ, Pred);
260       });
261     };
262     // If the condition is an OR of 2 compares and the false successor only has
263     // the current block as predecessor, queue both negated conditions for the
264     // false successor.
265     Value *Op0, *Op1;
266     if (match(Br->getCondition(), m_LogicalOr(m_Value(Op0), m_Value(Op1))) &&
267         match(Op0, m_Cmp()) && match(Op1, m_Cmp())) {
268       BasicBlock *FalseSuccessor = Br->getSuccessor(1);
269       if (CanAdd(FalseSuccessor)) {
270         WorkList.emplace_back(DT.getNode(FalseSuccessor), cast<CmpInst>(Op0),
271                               true);
272         WorkList.emplace_back(DT.getNode(FalseSuccessor), cast<CmpInst>(Op1),
273                               true);
274       }
275       continue;
276     }
277 
278     // If the condition is an AND of 2 compares and the true successor only has
279     // the current block as predecessor, queue both conditions for the true
280     // successor.
281     if (match(Br->getCondition(), m_LogicalAnd(m_Value(Op0), m_Value(Op1))) &&
282         match(Op0, m_Cmp()) && match(Op1, m_Cmp())) {
283       BasicBlock *TrueSuccessor = Br->getSuccessor(0);
284       if (CanAdd(TrueSuccessor)) {
285         WorkList.emplace_back(DT.getNode(TrueSuccessor), cast<CmpInst>(Op0),
286                               false);
287         WorkList.emplace_back(DT.getNode(TrueSuccessor), cast<CmpInst>(Op1),
288                               false);
289       }
290       continue;
291     }
292 
293     auto *CmpI = dyn_cast<CmpInst>(Br->getCondition());
294     if (!CmpI)
295       continue;
296     if (CanAdd(Br->getSuccessor(0)))
297       WorkList.emplace_back(DT.getNode(Br->getSuccessor(0)), CmpI, false);
298     if (CanAdd(Br->getSuccessor(1)))
299       WorkList.emplace_back(DT.getNode(Br->getSuccessor(1)), CmpI, true);
300   }
301 
302   // Next, sort worklist by dominance, so that dominating blocks and conditions
303   // come before blocks and conditions dominated by them. If a block and a
304   // condition have the same numbers, the condition comes before the block, as
305   // it holds on entry to the block.
306   sort(WorkList, [](const ConstraintOrBlock &A, const ConstraintOrBlock &B) {
307     return std::tie(A.NumIn, A.IsBlock) < std::tie(B.NumIn, B.IsBlock);
308   });
309 
310   // Finally, process ordered worklist and eliminate implied conditions.
311   SmallVector<StackEntry, 16> DFSInStack;
312   DenseMap<Value *, unsigned> Value2Index;
313   for (ConstraintOrBlock &CB : WorkList) {
314     // First, pop entries from the stack that are out-of-scope for CB. Remove
315     // the corresponding entry from the constraint system.
316     while (!DFSInStack.empty()) {
317       auto &E = DFSInStack.back();
318       LLVM_DEBUG(dbgs() << "Top of stack : " << E.NumIn << " " << E.NumOut
319                         << "\n");
320       LLVM_DEBUG(dbgs() << "CB: " << CB.NumIn << " " << CB.NumOut << "\n");
321       assert(E.NumIn <= CB.NumIn);
322       if (CB.NumOut <= E.NumOut)
323         break;
324       LLVM_DEBUG(dbgs() << "Removing " << *E.Condition << " " << E.IsNot
325                         << "\n");
326       DFSInStack.pop_back();
327       CS.popLastConstraint();
328     }
329 
330     LLVM_DEBUG({
331       dbgs() << "Processing ";
332       if (CB.IsBlock)
333         dbgs() << *CB.BB;
334       else
335         dbgs() << *CB.Condition;
336       dbgs() << "\n";
337     });
338 
339     // For a block, check if any CmpInsts become known based on the current set
340     // of constraints.
341     if (CB.IsBlock) {
342       for (Instruction &I : *CB.BB) {
343         auto *Cmp = dyn_cast<CmpInst>(&I);
344         if (!Cmp)
345           continue;
346         auto R = getConstraint(Cmp, Value2Index, false);
347         if (R.size() != 1 || R[0].size() == 1)
348           continue;
349         if (CS.isConditionImplied(R[0].Coefficients)) {
350           if (!DebugCounter::shouldExecute(EliminatedCounter))
351             continue;
352 
353           LLVM_DEBUG(dbgs() << "Condition " << *Cmp
354                             << " implied by dominating constraints\n");
355           LLVM_DEBUG({
356             for (auto &E : reverse(DFSInStack))
357               dbgs() << "   C " << *E.Condition << " " << E.IsNot << "\n";
358           });
359           Cmp->replaceAllUsesWith(
360               ConstantInt::getTrue(F.getParent()->getContext()));
361           NumCondsRemoved++;
362           Changed = true;
363         }
364         if (CS.isConditionImplied(
365                 ConstraintSystem::negate(R[0].Coefficients))) {
366           if (!DebugCounter::shouldExecute(EliminatedCounter))
367             continue;
368 
369           LLVM_DEBUG(dbgs() << "Condition !" << *Cmp
370                             << " implied by dominating constraints\n");
371           LLVM_DEBUG({
372             for (auto &E : reverse(DFSInStack))
373               dbgs() << "   C " << *E.Condition << " " << E.IsNot << "\n";
374           });
375           Cmp->replaceAllUsesWith(
376               ConstantInt::getFalse(F.getParent()->getContext()));
377           NumCondsRemoved++;
378           Changed = true;
379         }
380       }
381       continue;
382     }
383 
384     // Set up a function to restore the predicate at the end of the scope if it
385     // has been negated. Negate the predicate in-place, if required.
386     auto *CI = dyn_cast<CmpInst>(CB.Condition);
387     auto PredicateRestorer = make_scope_exit([CI, &CB]() {
388       if (CB.Not && CI)
389         CI->setPredicate(CI->getInversePredicate());
390     });
391     if (CB.Not) {
392       if (CI) {
393         CI->setPredicate(CI->getInversePredicate());
394       } else {
395         LLVM_DEBUG(dbgs() << "Can only negate compares so far.\n");
396         continue;
397       }
398     }
399 
400     // Otherwise, add the condition to the system and stack, if we can transform
401     // it into a constraint.
402     auto R = getConstraint(CB.Condition, Value2Index, true);
403     if (R.empty())
404       continue;
405 
406     LLVM_DEBUG(dbgs() << "Adding " << *CB.Condition << " " << CB.Not << "\n");
407     bool Added = false;
408     for (auto &C : R) {
409       auto Coeffs = C.Coefficients;
410       LLVM_DEBUG({
411         dbgs() << "  constraint: ";
412         dumpWithNames(C, Value2Index);
413       });
414       Added |= CS.addVariableRowFill(Coeffs);
415       // If R has been added to the system, queue it for removal once it goes
416       // out-of-scope.
417       if (Added)
418         DFSInStack.emplace_back(CB.NumIn, CB.NumOut, CB.Condition, CB.Not);
419     }
420   }
421 
422   assert(CS.size() == DFSInStack.size() &&
423          "updates to CS and DFSInStack are out of sync");
424   return Changed;
425 }
426 
427 PreservedAnalyses ConstraintEliminationPass::run(Function &F,
428                                                  FunctionAnalysisManager &AM) {
429   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
430   if (!eliminateConstraints(F, DT))
431     return PreservedAnalyses::all();
432 
433   PreservedAnalyses PA;
434   PA.preserve<DominatorTreeAnalysis>();
435   PA.preserve<GlobalsAA>();
436   PA.preserveSet<CFGAnalyses>();
437   return PA;
438 }
439 
440 namespace {
441 
442 class ConstraintElimination : public FunctionPass {
443 public:
444   static char ID;
445 
446   ConstraintElimination() : FunctionPass(ID) {
447     initializeConstraintEliminationPass(*PassRegistry::getPassRegistry());
448   }
449 
450   bool runOnFunction(Function &F) override {
451     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
452     return eliminateConstraints(F, DT);
453   }
454 
455   void getAnalysisUsage(AnalysisUsage &AU) const override {
456     AU.setPreservesCFG();
457     AU.addRequired<DominatorTreeWrapperPass>();
458     AU.addPreserved<GlobalsAAWrapperPass>();
459     AU.addPreserved<DominatorTreeWrapperPass>();
460   }
461 };
462 
463 } // end anonymous namespace
464 
465 char ConstraintElimination::ID = 0;
466 
467 INITIALIZE_PASS_BEGIN(ConstraintElimination, "constraint-elimination",
468                       "Constraint Elimination", false, false)
469 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
470 INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
471 INITIALIZE_PASS_END(ConstraintElimination, "constraint-elimination",
472                     "Constraint Elimination", false, false)
473 
474 FunctionPass *llvm::createConstraintEliminationPass() {
475   return new ConstraintElimination();
476 }
477